Skip to content

Commit 0010767

Browse files
SeanChinJunKaimichaellatman
authored andcommitted
feat: add StreamableHttpTransport for server
1 parent 55666b2 commit 0010767

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
package io.modelcontextprotocol.kotlin.sdk.server
2+
3+
import io.ktor.http.*
4+
import io.ktor.server.application.*
5+
import io.ktor.server.request.*
6+
import io.ktor.server.response.*
7+
import io.ktor.server.sse.*
8+
import io.modelcontextprotocol.kotlin.sdk.*
9+
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
10+
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
11+
import kotlinx.serialization.encodeToString
12+
import kotlin.collections.HashMap
13+
import kotlin.concurrent.atomics.AtomicBoolean
14+
import kotlin.concurrent.atomics.ExperimentalAtomicApi
15+
import kotlin.uuid.ExperimentalUuidApi
16+
import kotlin.uuid.Uuid
17+
18+
@OptIn(ExperimentalAtomicApi::class)
19+
public class StreamableHttpServerTransport(
20+
private val isStateful: Boolean = false,
21+
private val enableJSONResponse: Boolean = false,
22+
): AbstractTransport() {
23+
private val standalone = "standalone"
24+
private val streamMapping: HashMap<String, ServerSSESession> = hashMapOf()
25+
private val requestToStreamMapping: HashMap<RequestId, String> = hashMapOf()
26+
private val requestResponseMapping: HashMap<RequestId, JSONRPCMessage> = hashMapOf()
27+
private val callMapping: HashMap<String, ApplicationCall> = hashMapOf()
28+
private val started: AtomicBoolean = AtomicBoolean(false)
29+
private val initialized: AtomicBoolean = AtomicBoolean(false)
30+
31+
public var sessionId: String? = null
32+
private set
33+
34+
override suspend fun start() {
35+
if (!started.compareAndSet(false, true)) {
36+
error("StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically.")
37+
}
38+
}
39+
40+
override suspend fun send(message: JSONRPCMessage) {
41+
var requestId: RequestId? = null
42+
43+
if (message is JSONRPCResponse) {
44+
requestId = message.id
45+
}
46+
47+
if (requestId == null) {
48+
val standaloneSSE = streamMapping[standalone] ?: return
49+
50+
standaloneSSE.send(
51+
event = "message",
52+
data = McpJson.encodeToString(message),
53+
)
54+
return
55+
}
56+
57+
val streamId = requestToStreamMapping[requestId] ?: error("No connection established for request id $requestId")
58+
val correspondingStream = streamMapping[streamId] ?: error("No connection established for request id $requestId")
59+
val correspondingCall = callMapping[streamId] ?: error("No connection established for request id $requestId")
60+
61+
if (!enableJSONResponse) {
62+
correspondingStream.send(
63+
event = "message",
64+
data = McpJson.encodeToString(message),
65+
)
66+
}
67+
68+
requestResponseMapping[requestId] = message
69+
val relatedIds = requestToStreamMapping.entries.filter { streamMapping[it.value] == correspondingStream }.map { it.key }
70+
val allResponsesReady = relatedIds.all { requestResponseMapping[it] != null }
71+
72+
if (allResponsesReady) {
73+
if (enableJSONResponse) {
74+
correspondingCall.response.headers.append(ContentType.toString(), ContentType.Application.Json.toString())
75+
correspondingCall.response.status(HttpStatusCode.OK)
76+
if (sessionId != null) {
77+
correspondingCall.response.header("Mcp-Session-Id", sessionId!!)
78+
}
79+
val responses = relatedIds.map{ requestResponseMapping[it] }
80+
if (responses.size == 1) {
81+
correspondingCall.respond(responses[0]!!)
82+
} else {
83+
correspondingCall.respond(responses)
84+
}
85+
callMapping.remove(streamId)
86+
} else {
87+
correspondingStream.close()
88+
streamMapping.remove(streamId)
89+
}
90+
91+
for (id in relatedIds) {
92+
requestToStreamMapping.remove(id)
93+
requestResponseMapping.remove(id)
94+
}
95+
}
96+
97+
}
98+
99+
override suspend fun close() {
100+
streamMapping.values.forEach {
101+
it.close()
102+
}
103+
streamMapping.clear()
104+
requestToStreamMapping.clear()
105+
requestResponseMapping.clear()
106+
// TODO Check if we need to clear the callMapping or if call timeout after awhile
107+
_onClose.invoke()
108+
}
109+
110+
@OptIn(ExperimentalUuidApi::class)
111+
public suspend fun handlePostRequest(call: ApplicationCall, session: ServerSSESession) {
112+
try {
113+
val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf()
114+
115+
if (!acceptHeader.contains("text/event-stream") || !acceptHeader.contains("application/json")) {
116+
call.response.status(HttpStatusCode.NotAcceptable)
117+
call.respond(
118+
JSONRPCResponse(
119+
id = null,
120+
error = JSONRPCError(
121+
code = ErrorCode.Unknown(-32000),
122+
message = "Not Acceptable: Client must accept both application/json and text/event-stream"
123+
)
124+
)
125+
)
126+
return
127+
}
128+
129+
val contentType = call.request.contentType()
130+
if (contentType != ContentType.Application.Json) {
131+
call.response.status(HttpStatusCode.UnsupportedMediaType)
132+
call.respond(
133+
JSONRPCResponse(
134+
id = null,
135+
error = JSONRPCError(
136+
code = ErrorCode.Unknown(-32000),
137+
message = "Unsupported Media Type: Content-Type must be application/json"
138+
)
139+
)
140+
)
141+
return
142+
}
143+
144+
val body = call.receiveText()
145+
val messages = mutableListOf<JSONRPCMessage>()
146+
147+
if (body.startsWith("[")) {
148+
messages.addAll(McpJson.decodeFromString<List<JSONRPCMessage>>(body))
149+
} else {
150+
messages.add(McpJson.decodeFromString(body))
151+
}
152+
153+
val hasInitializationRequest = messages.any { it is JSONRPCRequest && it.method == "initialize" }
154+
if (hasInitializationRequest) {
155+
if (initialized.load() && sessionId != null) {
156+
call.response.status(HttpStatusCode.BadRequest)
157+
call.respond(
158+
JSONRPCResponse(
159+
id = null,
160+
error = JSONRPCError(
161+
code = ErrorCode.Defined.InvalidRequest,
162+
message = "Invalid Request: Server already initialized"
163+
)
164+
)
165+
)
166+
return
167+
}
168+
169+
if (messages.size > 1) {
170+
call.response.status(HttpStatusCode.BadRequest)
171+
call.respond(
172+
JSONRPCResponse(
173+
id = null,
174+
error = JSONRPCError(
175+
code = ErrorCode.Defined.InvalidRequest,
176+
message = "Invalid Request: Only one initialization request is allowed"
177+
)
178+
)
179+
)
180+
return
181+
}
182+
183+
if (isStateful) {
184+
sessionId = Uuid.random().toString()
185+
}
186+
initialized.store(true)
187+
188+
if (!validateSession(call)) {
189+
return
190+
}
191+
192+
val hasRequests = messages.any { it is JSONRPCRequest }
193+
val streamId = Uuid.random().toString()
194+
195+
if (!hasRequests){
196+
call.respondNullable(HttpStatusCode.Accepted)
197+
} else {
198+
if (!enableJSONResponse) {
199+
call.response.headers.append(ContentType.toString(), ContentType.Text.EventStream.toString())
200+
201+
if (sessionId != null) {
202+
call.response.header("Mcp-Session-Id", sessionId!!)
203+
}
204+
}
205+
206+
for (message in messages) {
207+
if (message is JSONRPCRequest) {
208+
streamMapping[streamId] = session
209+
callMapping[streamId] = call
210+
requestToStreamMapping[message.id] = streamId
211+
}
212+
}
213+
}
214+
for (message in messages) {
215+
_onMessage.invoke(message)
216+
}
217+
}
218+
219+
} catch (e: Exception) {
220+
call.response.status(HttpStatusCode.BadRequest)
221+
call.respond(
222+
JSONRPCResponse(
223+
id = null,
224+
error = JSONRPCError(
225+
code = ErrorCode.Unknown(-32000),
226+
message = e.message ?: "Parse error"
227+
)
228+
)
229+
)
230+
_onError.invoke(e)
231+
}
232+
}
233+
234+
public suspend fun handleGetRequest(call: ApplicationCall, session: ServerSSESession) {
235+
val acceptHeader = call.request.headers["Accept"]?.split(",") ?: listOf()
236+
if (!acceptHeader.contains("text/event-stream")) {
237+
call.response.status(HttpStatusCode.NotAcceptable)
238+
call.respond(
239+
JSONRPCResponse(
240+
id = null,
241+
error = JSONRPCError(
242+
code = ErrorCode.Unknown(-32000),
243+
message = "Not Acceptable: Client must accept text/event-stream"
244+
)
245+
)
246+
)
247+
}
248+
249+
if (!validateSession(call)) {
250+
return
251+
}
252+
253+
if (sessionId != null) {
254+
call.response.header("Mcp-Session-Id", sessionId!!)
255+
}
256+
257+
if (streamMapping[standalone] != null) {
258+
call.response.status(HttpStatusCode.Conflict)
259+
call.respond(
260+
JSONRPCResponse(
261+
id = null,
262+
error = JSONRPCError(
263+
code = ErrorCode.Unknown(-32000),
264+
message = "Conflict: Only one SSE stream is allowed per session"
265+
)
266+
)
267+
)
268+
session.close()
269+
return
270+
}
271+
272+
// TODO: Equivalent of typescript res.writeHead(200, headers).flushHeaders();
273+
streamMapping[standalone] = session
274+
}
275+
276+
public suspend fun handleDeleteRequest(call: ApplicationCall) {
277+
if (!validateSession(call)) {
278+
return
279+
}
280+
close()
281+
call.respondNullable(HttpStatusCode.OK)
282+
}
283+
284+
public suspend fun validateSession(call: ApplicationCall): Boolean {
285+
if (sessionId == null) {
286+
return true
287+
}
288+
289+
if (!initialized.load()) {
290+
call.response.status(HttpStatusCode.BadRequest)
291+
call.respond(
292+
JSONRPCResponse(
293+
id = null,
294+
error = JSONRPCError(
295+
code = ErrorCode.Unknown(-32000),
296+
message = "Bad Request: Server not initialized"
297+
)
298+
)
299+
)
300+
return false
301+
}
302+
return true
303+
}
304+
}

0 commit comments

Comments
 (0)