Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions sentry_sdk/ai/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from sentry_sdk.consts import SPANDATA
import sentry_sdk.utils
from sentry_sdk import start_span
from sentry_sdk import capture_event, get_client, start_span
from sentry_sdk.tracing import Span
from sentry_sdk.utils import ContextVar
from sentry_sdk.utils import event_from_exception, ContextVar

from typing import TYPE_CHECKING

Expand All @@ -29,18 +29,35 @@ def get_ai_pipeline_name():

def ai_track(description, **span_kwargs):
# type: (str, Any) -> Callable[[F], F]
"""
Decorator to track AI pipeline/operation spans and capture exceptions.
"""

def decorator(f):
# type: (F) -> F
def sync_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")

with start_span(name=description, op=op, **span_kwargs) as span:
for k, v in kwargs.pop("sentry_tags", {}).items():
def _set_span_tags_and_data(span, kwargs):
# Avoids repeated lookups and pop logic for both sync/async branches.
tags = kwargs.pop("sentry_tags", None)
if tags:
for k, v in tags.items():
span.set_tag(k, v)
for k, v in kwargs.pop("sentry_data", {}).items():
data = kwargs.pop("sentry_data", None)
if data:
for k, v in data.items():
span.set_data(k, v)

def sync_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
# Avoids modifying the decorator-level span_kwargs in concurrent/multi use.
local_span_kwargs = span_kwargs.copy()
op = local_span_kwargs.pop(
"op", "ai.run" if curr_pipeline else "ai.pipeline"
)

with start_span(name=description, op=op, **local_span_kwargs) as span:
_set_span_tags_and_data(span, kwargs)
if curr_pipeline:
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
return f(*args, **kwargs)
Expand All @@ -49,12 +66,12 @@ def sync_wrapped(*args, **kwargs):
try:
res = f(*args, **kwargs)
except Exception as e:
event, hint = sentry_sdk.utils.event_from_exception(
event, hint = event_from_exception(
e,
client_options=sentry_sdk.get_client().options,
client_options=get_client().options,
mechanism={"type": "ai_monitoring", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
capture_event(event, hint=hint)
raise e from None
finally:
_ai_pipeline_name.set(None)
Expand All @@ -63,13 +80,13 @@ def sync_wrapped(*args, **kwargs):
async def async_wrapped(*args, **kwargs):
# type: (Any, Any) -> Any
curr_pipeline = _ai_pipeline_name.get()
op = span_kwargs.pop("op", "ai.run" if curr_pipeline else "ai.pipeline")
local_span_kwargs = span_kwargs.copy()
op = local_span_kwargs.pop(
"op", "ai.run" if curr_pipeline else "ai.pipeline"
)

with start_span(name=description, op=op, **span_kwargs) as span:
for k, v in kwargs.pop("sentry_tags", {}).items():
span.set_tag(k, v)
for k, v in kwargs.pop("sentry_data", {}).items():
span.set_data(k, v)
with start_span(name=description, op=op, **local_span_kwargs) as span:
_set_span_tags_and_data(span, kwargs)
if curr_pipeline:
span.set_data(SPANDATA.GEN_AI_PIPELINE_NAME, curr_pipeline)
return await f(*args, **kwargs)
Expand All @@ -78,12 +95,12 @@ async def async_wrapped(*args, **kwargs):
try:
res = await f(*args, **kwargs)
except Exception as e:
event, hint = sentry_sdk.utils.event_from_exception(
event, hint = event_from_exception(
e,
client_options=sentry_sdk.get_client().options,
client_options=get_client().options,
mechanism={"type": "ai_monitoring", "handled": False},
)
sentry_sdk.capture_event(event, hint=hint)
capture_event(event, hint=hint)
raise e from None
finally:
_ai_pipeline_name.set(None)
Expand Down