Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Interceptor for collecting Cloud Spanner metrics."""

import re
from typing import Dict
from typing import Any, Dict

from grpc_interceptor import ClientInterceptor

Expand Down Expand Up @@ -122,10 +122,115 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
tracer.set_method(method_name)
tracer.record_attempt_start()
response = invoked_method(request_or_iterator, call_details)
tracer.record_attempt_completion()

# Process and send GFE metrics if enabled
if tracer.gfe_enabled:
metadata = response.initial_metadata()
tracer.record_gfe_metrics(metadata)
return _wrap_response(response, tracer)


def _wrap_response(response: Any, tracer: Any) -> Any:
"""Wraps the response if it is streaming, or records metrics immediately if unary."""
if hasattr(response, "__anext__") or hasattr(response, "__aiter__"):
return _AsyncStreamingResponseWrapper(response, tracer)
elif hasattr(response, "__next__") or hasattr(response, "__iter__"):
return _StreamingResponseWrapper(response, tracer)
Comment on lines +131 to +134

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Checking hasattr(response, "__iter__") or hasattr(response, "__aiter__") is too broad because many standard unary response types (such as lists, dicts, tuples, or custom iterables) are iterable but are not streaming responses. If a unary RPC returns an iterable, it will be incorrectly wrapped in _StreamingResponseWrapper, which will break the caller's expectations.

Additionally, in unit tests, MagicMock objects have __iter__ by default, which causes them to be incorrectly wrapped.

To correctly identify streaming responses, we should only check for the iterator protocol methods __next__ and __anext__. Any gRPC streaming response is an iterator and must implement these methods.

Suggested change
if hasattr(response, "__anext__") or hasattr(response, "__aiter__"):
return _AsyncStreamingResponseWrapper(response, tracer)
elif hasattr(response, "__next__") or hasattr(response, "__iter__"):
return _StreamingResponseWrapper(response, tracer)
if hasattr(response, "__anext__"):
return _AsyncStreamingResponseWrapper(response, tracer)
elif hasattr(response, "__next__"):
return _StreamingResponseWrapper(response, tracer)

else:
# Unary call: execute completion and record metrics immediately
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception:
pass
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception:
pass
tracer.record_gfe_metrics(metadata)
return response
Comment on lines +135 to 150

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The metrics recording block for unary calls is not wrapped in a try-except block. If tracer.record_attempt_completion() or tracer.record_gfe_metrics(metadata) raises an exception (e.g., due to OpenTelemetry configuration issues or unexpected metadata formats), it will crash the entire unary RPC call and prevent the response from being returned. Telemetry and metrics collection should be non-blocking and fail-safe, meaning they should never disrupt the main application flow. Avoid broad except Exception: blocks that silently pass; instead, log the exception to aid in debugging.

Suggested change
else:
# Unary call: execute completion and record metrics immediately
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception:
pass
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception:
pass
tracer.record_gfe_metrics(metadata)
return response
else:
# Unary call: execute completion and record metrics immediately
try:
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
return response
References
  1. Avoid broad except Exception: blocks that silently return None. Instead, log the exception (e.g., using logger.warning) to aid in debugging and prevent masking underlying issues.



class _StreamingResponseWrapper:
"""Wrapper for streaming RPC response iterators to defer metrics recording."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False

def __iter__(self):
return self

def __next__(self):
try:
return next(self._response)
except StopIteration:
self._record_metrics()
raise
except Exception:
self._record_metrics()
raise

def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception:
pass
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception:
pass
self._tracer.record_gfe_metrics(metadata)
Comment on lines +174 to +190

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the unary call metrics recording, if any exception is raised during _record_metrics (e.g., in record_attempt_completion or record_gfe_metrics), it will propagate to the caller of __next__. This can mask the original StopIteration or other exceptions, or crash the application during stream consumption. We should wrap the entire metrics recording logic in a try-except block to ensure telemetry failures are safe, and log any exceptions to avoid silent failures.

Suggested change
def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception:
pass
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception:
pass
self._tracer.record_gfe_metrics(metadata)
def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
References
  1. Avoid broad except Exception: blocks that silently return None. Instead, log the exception (e.g., using logger.warning) to aid in debugging and prevent masking underlying issues.


def __getattr__(self, name):
return getattr(self._response, name)


class _AsyncStreamingResponseWrapper:
"""Wrapper for async streaming RPC response iterators to defer metrics recording."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False

def __aiter__(self):
return self

async def __anext__(self):
try:
return await self._response.__anext__()
except StopAsyncIteration:
self._record_metrics()
raise
except Exception:
self._record_metrics()
raise

def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception:
pass
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception:
pass
self._tracer.record_gfe_metrics(metadata)
Comment on lines +217 to +233

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For the asynchronous streaming wrapper, we should also wrap the metrics recording logic in a try-except block to prevent telemetry failures from crashing the async stream or masking StopAsyncIteration. Ensure exceptions are logged rather than silently ignored.

Suggested change
def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception:
pass
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception:
pass
self._tracer.record_gfe_metrics(metadata)
def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
References
  1. Avoid broad except Exception: blocks that silently return None. Instead, log the exception (e.g., using logger.warning) to aid in debugging and prevent masking underlying issues.


def __getattr__(self, name):
return getattr(self._response, name)
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
while the helper classes provide additional functionality and context for the metrics being traced.
"""

import re
from datetime import datetime
from typing import Dict
from typing import Any, Dict, Optional

from grpc import StatusCode

Expand Down Expand Up @@ -198,6 +199,8 @@ def __init__(
instrument_operation_counter: "Counter",
client_attributes: Dict[str, str],
gfe_enabled: bool = False,
instrument_gfe_latency: Optional["Histogram"] = None,
instrument_gfe_missing_header_count: Optional["Counter"] = None,
):
"""
Initialize a MetricsTracer instance with the given parameters.
Expand All @@ -214,15 +217,19 @@ def __init__(
instrument_operation_counter (Counter): Instrument for counting operations.
client_attributes (Dict[str, str]): Dictionary of client attributes used for metrics tracing.
gfe_enabled (bool, optional): Indicates if GFE metrics are enabled. Defaults to False.
instrument_gfe_latency (Histogram, optional): Instrument for measuring GFE latency.
instrument_gfe_missing_header_count (Counter, optional): Instrument for counting missing GFE headers.
"""
self.current_op = MetricOpTracer()
self._client_attributes = client_attributes
self._instrument_attempt_latency = instrument_attempt_latency
self._instrument_attempt_counter = instrument_attempt_counter
self._instrument_operation_latency = instrument_operation_latency
self._instrument_operation_counter = instrument_operation_counter
self._instrument_gfe_latency = instrument_gfe_latency
self._instrument_gfe_missing_header_count = instrument_gfe_missing_header_count
self.enabled = enabled
self.gfe_enabled = gfe_enabled
self.gfe_enabled = True

@staticmethod
def _get_ms_time_diff(start: datetime, end: datetime) -> float:
Expand Down Expand Up @@ -399,7 +406,11 @@ def record_gfe_latency(self, latency: int) -> None:
Args:
latency (int): The latency duration to be recorded.
"""
if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED or not self.gfe_enabled:
if (
not self.enabled
or not HAS_OPENTELEMETRY_INSTALLED
or not getattr(self, "_instrument_gfe_latency", None)
):
return
self._instrument_gfe_latency.record(
amount=latency, attributes=self.client_attributes
Expand All @@ -409,12 +420,65 @@ def record_gfe_missing_header_count(self) -> None:
"""
Increments the counter for missing GFE headers.
"""
if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED or not self.gfe_enabled:
if (
not self.enabled
or not HAS_OPENTELEMETRY_INSTALLED
or not getattr(self, "_instrument_gfe_missing_header_count", None)
):
return
self._instrument_gfe_missing_header_count.add(
amount=1, attributes=self.client_attributes
)

@staticmethod
def extract_gfe_latency(metadata: Any) -> Optional[int]:
"""
Extracts the GFE latency value (in milliseconds) from response metadata.
"""
if not metadata:
return None

header_vals = []
if isinstance(metadata, dict):
for key, val in metadata.items():
if key and str(key).lower() in ("server-timing", "server_timing"):
if isinstance(val, (list, tuple)):
header_vals.extend(val)
else:
header_vals.append(val)
elif isinstance(metadata, (list, tuple)):
for key, val in metadata:
if key and str(key).lower() in ("server-timing", "server_timing"):
if isinstance(val, (list, tuple)):
header_vals.extend(val)
else:
header_vals.append(val)
Comment on lines +449 to +455

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When iterating over metadata as a list or tuple, unpacking for key, val in metadata assumes that every element in metadata is a sequence of exactly two elements. If metadata contains any malformed elements (e.g., a 1-tuple, a string, or None), this will raise a ValueError and crash the metrics extraction. We should defensively verify that each item is a sequence of length 2 before unpacking.

Suggested change
elif isinstance(metadata, (list, tuple)):
for key, val in metadata:
if key and str(key).lower() in ("server-timing", "server_timing"):
if isinstance(val, (list, tuple)):
header_vals.extend(val)
else:
header_vals.append(val)
elif isinstance(metadata, (list, tuple)):
for item in metadata:
if isinstance(item, (list, tuple)) and len(item) == 2:
key, val = item
if key and str(key).lower() in ("server-timing", "server_timing"):
if isinstance(val, (list, tuple)):
header_vals.extend(val)
else:
header_vals.append(val)


for header_val in header_vals:
if not header_val:
continue
if not isinstance(header_val, str):
header_val = str(header_val)
Comment on lines +460 to +461

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If header_val is of type bytes (which is common for gRPC metadata values), calling str(header_val) in Python 3 will produce the string representation b'...' (including the literal b and quotes). This will prevent the regex from matching correctly or cause unexpected behavior. We should decode bytes to str using .decode("utf-8") first.

            if isinstance(header_val, bytes):
                try:
                    header_val = header_val.decode("utf-8")
                except Exception:
                    header_val = str(header_val)
            elif not isinstance(header_val, str):
                header_val = str(header_val)

match = re.search(r"gfet4t7;\s*dur=([0-9]+)", header_val)
if match:
try:
return int(match.group(1))
except ValueError:
pass
return None

def record_gfe_metrics(self, metadata: Any) -> None:
"""
Extracts and records GFE metrics from the RPC response metadata.
"""
if not self.enabled or not HAS_OPENTELEMETRY_INSTALLED:
return
latency = self.extract_gfe_latency(metadata)
if latency is not None:
self.record_gfe_latency(latency)
else:
self.record_gfe_missing_header_count()

def _create_operation_otel_attributes(self) -> dict:
"""
Create additional attributes for operation metrics tracing.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(self, enabled: bool, service_name: str):
project (str): The project ID for the monitored resource.
"""
self.enabled = enabled
self.gfe_enabled = True
self._create_metric_instruments(service_name)
self._client_attributes = {}

Expand Down Expand Up @@ -268,6 +269,11 @@ def create_metrics_tracer(self) -> MetricsTracer:
instrument_operation_latency=self._instrument_operation_latency,
instrument_operation_counter=self._instrument_operation_counter,
client_attributes=self._client_attributes.copy(),
gfe_enabled=True,
instrument_gfe_latency=getattr(self, "_instrument_gfe_latency", None),
instrument_gfe_missing_header_count=getattr(
self, "_instrument_gfe_missing_header_count", None
),
)
return metrics_tracer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ class SpannerMetricsTracerFactory(MetricsTracerFactory):
"current_metrics_tracer", default=None
)

def __new__(
cls, enabled: bool = True, gfe_enabled: bool = False
) -> "SpannerMetricsTracerFactory":
def __new__(cls, enabled: bool = True) -> "SpannerMetricsTracerFactory":
"""
Create a new instance of SpannerMetricsTracerFactory if it doesn't already exist.

Expand All @@ -63,7 +61,6 @@ def __new__(

Args:
enabled (bool): A flag indicating whether metrics tracing is enabled. Defaults to True.
gfe_enabled (bool): A flag indicating whether GFE metrics are enabled. Defaults to False.

Returns:
SpannerMetricsTracerFactory: The singleton instance of SpannerMetricsTracerFactory.
Expand All @@ -83,7 +80,7 @@ def __new__(
cls._generate_client_hash(client_uid)
)
cls._metrics_tracer_factory.set_location(_get_cloud_region())
cls._metrics_tracer_factory.gfe_enabled = gfe_enabled
cls._metrics_tracer_factory.gfe_enabled = True

if cls._metrics_tracer_factory.enabled != enabled:
cls._metrics_tracer_factory.enabled = enabled
Expand Down
Loading
Loading