Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import logging
import os
from datetime import datetime, timezone
from datetime import date, datetime, timezone
from importlib.metadata import version
from typing import Any, Dict, Mapping, Optional

Expand All @@ -30,21 +30,49 @@
class JSONEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles non-serializable types."""

def default(self, obj: Any) -> Any:
"""Handle non-serializable types.
def encode(self, obj: Any) -> str:
"""Recursively encode objects, preserving structure and only replacing unserializable values.

Args:
obj: The object to serialize
obj: The object to encode

Returns:
A JSON serializable version of the object
JSON string representation of the object
"""
value = ""
try:
value = super().default(obj)
except TypeError:
value = "<replaced>"
return value
# Process the object to handle non-serializable values
processed_obj = self._process_value(obj)
# Use the parent class to encode the processed object
return super().encode(processed_obj)

def _process_value(self, value: Any) -> Any:
"""Process any value, handling containers recursively.

Args:
value: The value to process

Returns:
Processed value with unserializable parts replaced
"""
# Handle datetime objects directly
if isinstance(value, (datetime, date)):
return value.isoformat()

# Handle dictionaries
elif isinstance(value, dict):
return {k: self._process_value(v) for k, v in value.items()}

# Handle lists
elif isinstance(value, list):
return [self._process_value(item) for item in value]

# Handle all other values
else:
try:
# Test if the value is JSON serializable
json.dumps(value)
return value
except (TypeError, OverflowError, ValueError):
return "<replaced>"


class Tracer:
Expand Down Expand Up @@ -332,6 +360,7 @@ def start_tool_call_span(
The created span, or None if tracing is not enabled.
"""
attributes: Dict[str, AttributeValue] = {
"gen_ai.prompt": json.dumps(tool, cls=JSONEncoder),
"tool.name": tool["name"],
"tool.id": tool["toolUseId"],
"tool.parameters": json.dumps(tool["input"], cls=JSONEncoder),
Expand All @@ -358,10 +387,11 @@ def end_tool_call_span(
status = tool_result.get("status")
status_str = str(status) if status is not None else ""

tool_result_content_json = json.dumps(tool_result.get("content"), cls=JSONEncoder)
attributes.update(
{
"tool.result": json.dumps(tool_result.get("content"), cls=JSONEncoder),
"gen_ai.completion": json.dumps(tool_result.get("content"), cls=JSONEncoder),
"tool.result": tool_result_content_json,
"gen_ai.completion": tool_result_content_json,
"tool.status": status_str,
}
)
Expand Down Expand Up @@ -492,7 +522,7 @@ def end_agent_span(
if response:
attributes.update(
{
"gen_ai.completion": json.dumps(response, cls=JSONEncoder),
"gen_ai.completion": str(response),
}
)

Expand Down
142 changes: 140 additions & 2 deletions tests/strands/telemetry/test_tracer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import os
from datetime import date, datetime, timezone
from unittest import mock

import pytest
from opentelemetry.trace import StatusCode # type: ignore

from strands.telemetry.tracer import Tracer, get_tracer
from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer
from strands.types.streaming import Usage


Expand Down Expand Up @@ -268,6 +269,9 @@ def test_start_tool_call_span(mock_tracer):

mock_tracer.start_span.assert_called_once()
assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool"
mock_span.set_attribute.assert_any_call(
"gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}})
)
mock_span.set_attribute.assert_any_call("tool.name", "test-tool")
mock_span.set_attribute.assert_any_call("tool.id", "123")
mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"}))
Expand Down Expand Up @@ -369,7 +373,7 @@ def test_end_agent_span(mock_span):

tracer.end_agent_span(mock_span, mock_response)

mock_span.set_attribute.assert_any_call("gen_ai.completion", '"<replaced>"')
mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response")
mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50)
mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100)
mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150)
Expand Down Expand Up @@ -497,3 +501,137 @@ def test_start_model_invoke_span_with_parent(mock_tracer):

# Verify span was returned
assert span is mock_span


@pytest.mark.parametrize(
"input_data, expected_result",
[
("test string", '"test string"'),
(1234, "1234"),
(13.37, "13.37"),
(False, "false"),
(None, "null"),
],
)
def test_json_encoder_serializable(input_data, expected_result):
"""Test encoding of serializable values."""
encoder = JSONEncoder()

result = encoder.encode(input_data)
assert result == expected_result


def test_json_encoder_datetime():
"""Test encoding datetime and date objects."""
encoder = JSONEncoder()

dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
result = encoder.encode(dt)
assert result == f'"{dt.isoformat()}"'

d = date(2025, 1, 1)
result = encoder.encode(d)
assert result == f'"{d.isoformat()}"'


def test_json_encoder_list():
"""Test encoding a list with mixed content."""
encoder = JSONEncoder()

non_serializable = lambda x: x # noqa: E731

data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]]

result = json.loads(encoder.encode(data))
assert result == ["value", 42, 13.37, "<replaced>", None, {"key": True}, ["value here"]]


def test_json_encoder_dict():
"""Test encoding a dict with mixed content."""
encoder = JSONEncoder()

class UnserializableClass:
def __str__(self):
return "Unserializable Object"

non_serializable = lambda x: x # noqa: E731

now = datetime.now(timezone.utc)

data = {
"metadata": {
"timestamp": now,
"version": "1.0",
"debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731
},
"content": [
{"type": "text", "value": "Hello world"},
{"type": "binary", "value": non_serializable},
{"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]},
],
"statistics": {
"processed": 100,
"failed": 5,
"details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}],
},
"list": [
non_serializable,
1234,
13.37,
True,
None,
"string here",
],
}

expected = {
"metadata": {
"timestamp": now.isoformat(),
"version": "1.0",
"debug_info": {"object": "<replaced>", "callable": "<replaced>"},
},
"content": [
{"type": "text", "value": "Hello world"},
{"type": "binary", "value": "<replaced>"},
{"type": "mixed", "values": [1, "text", "<replaced>", {"nested": "<replaced>"}]},
],
"statistics": {
"processed": 100,
"failed": 5,
"details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": "<replaced>"}],
},
"list": [
"<replaced>",
1234,
13.37,
True,
None,
"string here",
],
}

result = json.loads(encoder.encode(data))

assert result == expected


def test_json_encoder_value_error():
"""Test encoding values that cause ValueError."""
encoder = JSONEncoder()

# A very large integer that exceeds JSON limits and throws ValueError
huge_number = 2**100000

# Test in a dictionary
dict_data = {"normal": 42, "huge": huge_number}
result = json.loads(encoder.encode(dict_data))
assert result == {"normal": 42, "huge": "<replaced>"}

# Test in a list
list_data = [42, huge_number]
result = json.loads(encoder.encode(list_data))
assert result == [42, "<replaced>"]

# Test just the value
result = json.loads(encoder.encode(huge_number))
assert result == "<replaced>"