Skip to content

refactor: reorganize message handling for better type safety and clarity #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 13, 2025
Prev Previous commit
refactor: rename root to message
  • Loading branch information
dsp-ant committed Mar 13, 2025
commit 1c53fc208beaa8fa13bccfaa41d2d56908563239
2 changes: 1 addition & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def sse_reader(
case "message":
try:
message = MessageFrame(
root=types.JSONRPCMessage.model_validate_json( # noqa: E501
message=types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
),
raw=sse,
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def ws_reader():
raw_text
)
# Create MessageFrame with JSON message as root
message = MessageFrame(root=json_message, raw=raw_text)
message = MessageFrame(message=json_message, raw=raw_text)
await read_stream_writer.send(message)
except ValidationError as exc:
# If JSON parse or model validation fails, send the exception
Expand All @@ -72,7 +72,7 @@ async def ws_writer():
async with write_stream_reader:
async for message in write_stream_reader:
# Extract the JSON-RPC message from MessageFrame and convert to JSON
msg_dict = message.root.model_dump(
msg_dict = message.message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await ws.send(json.dumps(msg_dict))
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,4 +176,4 @@ async def handle_post_message(
logger.debug(f"Sending message to writer: {message}")
response = Response("Accepted", status_code=202)
await response(scope, receive, send)
await writer.send(MessageFrame(root=message, raw=request))
await writer.send(MessageFrame(message=message, raw=request))
4 changes: 3 additions & 1 deletion src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ async def stdin_reader():
await read_stream_writer.send(exc)
continue

await read_stream_writer.send(MessageFrame(root=message, raw=line))
await read_stream_writer.send(
MessageFrame(message=message, raw=line)
)
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def ws_reader():
continue

await read_stream_writer.send(
MessageFrame(root=client_message, raw=message)
MessageFrame(message=client_message, raw=message)
)
except anyio.ClosedResourceError:
await websocket.close()
Expand Down
10 changes: 5 additions & 5 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ async def send_request(
# TODO: Support progress callbacks

await self._write_stream.send(
MessageFrame(root=JSONRPCMessage(jsonrpc_request), raw=None)
MessageFrame(message=JSONRPCMessage(jsonrpc_request), raw=None)
)

try:
Expand Down Expand Up @@ -287,7 +287,7 @@ async def send_notification(self, notification: SendNotificationT) -> None:
)

await self._write_stream.send(
MessageFrame(root=JSONRPCMessage(jsonrpc_notification), raw=None)
MessageFrame(message=JSONRPCMessage(jsonrpc_notification), raw=None)
)

async def _send_response(
Expand All @@ -296,7 +296,7 @@ async def _send_response(
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
await self._write_stream.send(
MessageFrame(root=JSONRPCMessage(jsonrpc_error), raw=None)
MessageFrame(message=JSONRPCMessage(jsonrpc_error), raw=None)
)
else:
jsonrpc_response = JSONRPCResponse(
Expand All @@ -307,7 +307,7 @@ async def _send_response(
),
)
await self._write_stream.send(
MessageFrame(root=JSONRPCMessage(jsonrpc_response), raw=None)
MessageFrame(message=JSONRPCMessage(jsonrpc_response), raw=None)
)

async def _receive_loop(self) -> None:
Expand All @@ -321,7 +321,7 @@ async def _receive_loop(self) -> None:
await self._incoming_message_stream_writer.send(raw_message)
continue

message = raw_message.root
message = raw_message.message
if isinstance(message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.root.model_dump(
Expand Down
34 changes: 31 additions & 3 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,43 @@ class JSONRPCMessage(


class MessageFrame(BaseModel, Generic[RawT]):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know how fastidious we are about this in general, but fwiw the other classes here are docstringed up pretty thoroughly (and the point of raw in particular might be worth noting here)

root: JSONRPCMessage
"""
A wrapper around the general message received that contains both the parsed message
and the raw message.

This class serves as an encapsulation for JSON-RPC messages, providing access to
both the parsed structure (root) and the original raw data. This design is
particularly useful for Server-Sent Events (SSE) consumers who may need to access
additional metadata or headers associated with the message.

The 'root' attribute contains the parsed JSONRPCMessage, which could be a request,
notification, response, or error. The 'raw' attribute preserves the original
message as received, allowing access to any additional context or metadata that
might be lost in parsing.

This dual representation allows for flexible handling of messages, where consumers
can work with the structured data for standard operations, but still have the
option to examine or utilize the raw data when needed, such as for debugging,
logging, or accessing transport-specific information.
"""

message: JSONRPCMessage
raw: RawT | None = None
model_config = ConfigDict(extra="allow")

def model_dump(self, *args, **kwargs):
return self.root.model_dump(*args, **kwargs)
"""
Dumps the model to a dictionary, delegating to the root JSONRPCMessage.
This method allows for consistent serialization of the parsed message.
"""
return self.message.model_dump(*args, **kwargs)

def model_dump_json(self, *args, **kwargs):
return self.root.model_dump_json(*args, **kwargs)
"""
Dumps the model to a JSON string, delegating to the root JSONRPCMessage.
This method provides a convenient way to serialize the parsed message to JSON.
"""
return self.message.model_dump_json(*args, **kwargs)


class EmptyResult(Result):
Expand Down
10 changes: 5 additions & 5 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ async def mock_server():
)

async with server_to_client_send:
assert isinstance(jsonrpc_request.root.root, JSONRPCRequest)
assert isinstance(jsonrpc_request.message.root, JSONRPCRequest)
await server_to_client_send.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.root.id,
id=jsonrpc_request.message.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
Expand All @@ -74,9 +74,9 @@ async def mock_server():
)
)
jsonrpc_notification = await client_to_server_receive.receive()
assert isinstance(jsonrpc_notification.root, JSONRPCMessage)
assert isinstance(jsonrpc_notification.message, JSONRPCMessage)
initialized_notification = ClientNotification.model_validate(
jsonrpc_notification.root.model_dump(
jsonrpc_notification.message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
Expand Down
10 changes: 6 additions & 4 deletions tests/issues/test_192_request_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def run_server():
)

await client_writer.send(
MessageFrame(root=JSONRPCMessage(root=init_req), raw=None)
MessageFrame(message=JSONRPCMessage(root=init_req), raw=None)
)
await server_reader.receive() # Get init response but don't need to check it

Expand All @@ -77,7 +77,9 @@ async def run_server():
jsonrpc="2.0",
)
await client_writer.send(
MessageFrame(root=JSONRPCMessage(root=initialized_notification), raw=None)
MessageFrame(
message=JSONRPCMessage(root=initialized_notification), raw=None
)
)

# Send ping request with custom ID
Expand All @@ -86,15 +88,15 @@ async def run_server():
)

await client_writer.send(
MessageFrame(root=JSONRPCMessage(root=ping_request), raw=None)
MessageFrame(message=JSONRPCMessage(root=ping_request), raw=None)
)

# Read response
response = await server_reader.receive()

# Verify response ID matches request ID
assert (
response.root.root.id == custom_request_id
response.message.root.id == custom_request_id
), "Response ID should match request ID"

# Cancel server task
Expand Down
16 changes: 8 additions & 8 deletions tests/server/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def run_server():
)
await send_stream1.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
Expand All @@ -100,7 +100,7 @@ async def run_server():
# Send initialized notification
await send_stream1.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
Expand All @@ -113,7 +113,7 @@ async def run_server():
# Call the tool to verify lifespan context
await send_stream1.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
Expand All @@ -127,7 +127,7 @@ async def run_server():

# Get response and verify
response = await receive_stream2.receive()
assert response.root.root.result["content"][0]["text"] == "true"
assert response.message.root.result["content"][0]["text"] == "true"

# Cancel server task
tg.cancel_scope.cancel()
Expand Down Expand Up @@ -189,7 +189,7 @@ async def run_server():
)
await send_stream1.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=1,
Expand All @@ -205,7 +205,7 @@ async def run_server():
# Send initialized notification
await send_stream1.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCNotification(
jsonrpc="2.0",
method="notifications/initialized",
Expand All @@ -218,7 +218,7 @@ async def run_server():
# Call the tool to verify lifespan context
await send_stream1.send(
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest(
jsonrpc="2.0",
id=2,
Expand All @@ -232,7 +232,7 @@ async def run_server():

# Get response and verify
response = await receive_stream2.receive()
assert response.root.root.result["content"][0]["text"] == "true"
assert response.message.root.result["content"][0]["text"] == "true"

# Cancel server task
tg.cancel_scope.cancel()
18 changes: 9 additions & 9 deletions tests/server/test_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@ async def test_stdio_server():

# Verify received messages
assert len(received_messages) == 2
assert isinstance(received_messages[0].root, JSONRPCMessage)
assert isinstance(received_messages[0].root.root, JSONRPCRequest)
assert received_messages[0].root.root.id == 1
assert received_messages[0].root.root.method == "ping"
assert isinstance(received_messages[0].message, JSONRPCMessage)
assert isinstance(received_messages[0].message.root, JSONRPCRequest)
assert received_messages[0].message.root.id == 1
assert received_messages[0].message.root.method == "ping"

assert isinstance(received_messages[1].root, JSONRPCMessage)
assert isinstance(received_messages[1].root.root, JSONRPCResponse)
assert received_messages[1].root.root.id == 2
assert isinstance(received_messages[1].message, JSONRPCMessage)
assert isinstance(received_messages[1].message.root, JSONRPCResponse)
assert received_messages[1].message.root.id == 2

# Test sending responses from the server
responses = [
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")
),
raw=None,
),
MessageFrame(
root=JSONRPCMessage(
message=JSONRPCMessage(
root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})
),
raw=None,
Expand Down