Skip to content

Commit 6d31905

Browse files
authored
fix: inject part into extension when handing init req (#275)
1 parent a62c6d1 commit 6d31905

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

crates/rmcp/src/transport/common/server_side_http.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@ pub(crate) const fn internal_error_response<E: Display>(
103103
}
104104
}
105105

106+
pub(crate) fn unexpected_message_response(
107+
expect: &str,
108+
) -> Response<UnsyncBoxBody<Bytes, Infallible>> {
109+
Response::builder()
110+
.status(http::StatusCode::UNPROCESSABLE_ENTITY)
111+
.body(Full::new(Bytes::from(format!("Unexpected message, expect {expect}"))).boxed_unsync())
112+
.expect("valid response")
113+
}
114+
106115
pub(crate) async fn expect_json<B>(
107116
body: B,
108117
) -> Result<ClientJsonRpcMessage, Response<UnsyncBoxBody<Bytes, Infallible>>>

crates/rmcp/src/transport/streamable_http_server/tower.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use tokio_stream::wrappers::ReceiverStream;
1010
use super::session::SessionManager;
1111
use crate::{
1212
RoleServer,
13-
model::{ClientJsonRpcMessage, GetExtensions},
13+
model::{ClientJsonRpcMessage, ClientRequest, GetExtensions},
1414
serve_server,
1515
service::serve_directly,
1616
transport::{
@@ -21,7 +21,7 @@ use crate::{
2121
},
2222
server_side_http::{
2323
BoxResponse, ServerSseMessage, accepted_response, expect_json,
24-
internal_error_response, sse_stream_response,
24+
internal_error_response, sse_stream_response, unexpected_message_response,
2525
},
2626
},
2727
},
@@ -318,6 +318,15 @@ where
318318
.create_session()
319319
.await
320320
.map_err(internal_error_response("create session"))?;
321+
if let ClientJsonRpcMessage::Request(req) = &mut message {
322+
if !matches!(req.request, ClientRequest::InitializeRequest(_)) {
323+
return Err(unexpected_message_response("initialize request"));
324+
}
325+
// inject request part to extensions
326+
req.request.extensions_mut().insert(part);
327+
} else {
328+
return Err(unexpected_message_response("initialize request"));
329+
}
321330
let service = self
322331
.get_service()
323332
.map_err(internal_error_response("get service"))?;
@@ -378,7 +387,8 @@ where
378387
.get_service()
379388
.map_err(internal_error_response("get service"))?;
380389
match message {
381-
ClientJsonRpcMessage::Request(request) => {
390+
ClientJsonRpcMessage::Request(mut request) => {
391+
request.request.extensions_mut().insert(part);
382392
let (transport, receiver) =
383393
OneshotTransport::<RoleServer>::new(ClientJsonRpcMessage::Request(request));
384394
let service = serve_directly(service, transport, None);

0 commit comments

Comments
 (0)