Skip to content

Commit c5bf418

Browse files
committed
Refactor handlers and collections to use atomic references: use persistent collections for thread safety
1 parent e9eb109 commit c5bf418

File tree

5 files changed

+174
-109
lines changed

5 files changed

+174
-109
lines changed

build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import org.jreleaser.model.Active
1414
plugins {
1515
alias(libs.plugins.kotlin.multiplatform)
1616
alias(libs.plugins.kotlin.serialization)
17+
alias(libs.plugins.kotlin.atomicfu)
1718
alias(libs.plugins.dokka)
1819
alias(libs.plugins.jreleaser)
1920
`maven-publish`
@@ -246,6 +247,7 @@ kotlin {
246247
kotlin.srcDir(generateLibVersionTask.map { it.sourcesDir })
247248
dependencies {
248249
api(libs.kotlinx.serialization.json)
250+
api(libs.kotlinx.collections.immutable)
249251
api(libs.ktor.client.cio)
250252
api(libs.ktor.server.cio)
251253
api(libs.ktor.server.sse)

gradle/libs.versions.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
# plugins version
33
kotlin = "2.2.0"
44
dokka = "2.0.0"
5+
atomicfu = "0.29.0"
56

67
# libraries version
78
serialization = "1.9.0"
9+
collections-immutable = "0.4.0"
810
coroutines = "1.10.2"
911
ktor = "3.2.1"
1012
mockk = "1.14.4"
@@ -17,6 +19,7 @@ kotest = "5.9.1"
1719
[libraries]
1820
# Kotlinx libraries
1921
kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" }
22+
kotlinx-collections-immutable = { group = "org.jetbrains.kotlinx", name = "kotlinx-collections-immutable", version.ref = "collections-immutable" }
2023
kotlin-logging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "logging" }
2124

2225
# Ktor
@@ -36,6 +39,7 @@ kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json",
3639
[plugins]
3740
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
3841
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
42+
kotlin-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" }
3943
dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" }
4044
jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"}
4145
kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" }

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
4141
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
4242
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
4343
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
44+
import kotlinx.atomicfu.atomic
45+
import kotlinx.atomicfu.getAndUpdate
46+
import kotlinx.atomicfu.update
47+
import kotlinx.collections.immutable.persistentMapOf
4448
import kotlinx.serialization.json.JsonElement
4549
import kotlinx.serialization.json.JsonNull
4650
import kotlinx.serialization.json.JsonObject
@@ -94,7 +98,7 @@ public open class Client(
9498

9599
private val capabilities: ClientCapabilities = options.capabilities
96100

97-
private val roots = mutableMapOf<String, Root>()
101+
private val roots = atomic(persistentMapOf<String, Root>())
98102

99103
init {
100104
logger.debug { "Initializing MCP client with capabilities: $capabilities" }
@@ -483,7 +487,7 @@ public open class Client(
483487
throw IllegalStateException("Client does not support roots capability.")
484488
}
485489
logger.info { "Adding root: $name ($uri)" }
486-
roots[uri] = Root(uri, name)
490+
roots.update { current -> current.put(uri, Root(uri, name)) }
487491
}
488492

489493
/**
@@ -498,10 +502,7 @@ public open class Client(
498502
throw IllegalStateException("Client does not support roots capability.")
499503
}
500504
logger.info { "Adding ${rootsToAdd.size} roots" }
501-
for (r in rootsToAdd) {
502-
logger.info { "Adding root: ${r.name} (${r.uri})" }
503-
roots[r.uri] = r
504-
}
505+
roots.update { current -> current.putAll(rootsToAdd.associateBy { it.uri }) }
505506
}
506507

507508
/**
@@ -517,7 +518,8 @@ public open class Client(
517518
throw IllegalStateException("Client does not support roots capability.")
518519
}
519520
logger.info { "Removing root: $uri" }
520-
val removed = roots.remove(uri) != null
521+
val oldMap = roots.getAndUpdate { current -> current.remove(uri) }
522+
val removed = uri in oldMap
521523
logger.debug {
522524
if (removed) {
523525
"Root removed: $uri"
@@ -541,13 +543,16 @@ public open class Client(
541543
throw IllegalStateException("Client does not support roots capability.")
542544
}
543545
logger.info { "Removing ${uris.size} roots" }
544-
var removedCount = 0
545-
for (uri in uris) {
546-
logger.debug { "Removing root: $uri" }
547-
if (roots.remove(uri) != null) {
548-
removedCount++
546+
547+
val oldMap = roots.getAndUpdate { current ->
548+
uris.fold(current) { map, uri ->
549+
logger.debug { "Removing root: $uri" }
550+
map.remove(uri)
549551
}
550552
}
553+
554+
val removedCount = uris.count { it in oldMap }
555+
551556
logger.info {
552557
if (removedCount > 0) {
553558
"Removed $removedCount roots"
@@ -571,7 +576,7 @@ public open class Client(
571576
// --- Internal Handlers ---
572577

573578
private suspend fun handleListRoots(): ListRootsResult {
574-
val rootList = roots.values.toList()
579+
val rootList = roots.value.values.toList()
575580
return ListRootsResult(rootList)
576581
}
577582
}

src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt

Lines changed: 63 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification
4444
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
4545
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
4646
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
47+
import kotlinx.atomicfu.atomic
48+
import kotlinx.atomicfu.getAndUpdate
49+
import kotlinx.atomicfu.update
50+
import kotlinx.collections.immutable.persistentMapOf
4751
import kotlinx.coroutines.CompletableDeferred
4852
import kotlinx.serialization.json.JsonObject
4953

@@ -91,9 +95,15 @@ public open class Server(
9195

9296
private val capabilities: ServerCapabilities = options.capabilities
9397

94-
private val tools = mutableMapOf<String, RegisteredTool>()
95-
private val prompts = mutableMapOf<String, RegisteredPrompt>()
96-
private val resources = mutableMapOf<String, RegisteredResource>()
98+
private val _tools = atomic(persistentMapOf<String, RegisteredTool>())
99+
private val _prompts = atomic(persistentMapOf<String, RegisteredPrompt>())
100+
private val _resources = atomic(persistentMapOf<String, RegisteredResource>())
101+
private val tools: Map<String, RegisteredTool>
102+
get() = _tools.value
103+
private val prompts: Map<String, RegisteredPrompt>
104+
get() = _prompts.value
105+
private val resources: Map<String, RegisteredResource>
106+
get() = _resources.value
97107

98108
init {
99109
logger.debug { "Initializing MCP server with capabilities: $capabilities" }
@@ -192,7 +202,9 @@ public open class Server(
192202
throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.")
193203
}
194204
logger.info { "Registering tool: $name" }
195-
tools[name] = RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler)
205+
_tools.update { current ->
206+
current.put(name, RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler))
207+
}
196208
}
197209

198210
/**
@@ -207,10 +219,7 @@ public open class Server(
207219
throw IllegalStateException("Server does not support tools capability.")
208220
}
209221
logger.info { "Registering ${toolsToAdd.size} tools" }
210-
for (rt in toolsToAdd) {
211-
logger.debug { "Registering tool: ${rt.tool.name}" }
212-
tools[rt.tool.name] = rt
213-
}
222+
_tools.update { current -> current.putAll(toolsToAdd.associateBy { it.tool.name }) }
214223
}
215224

216225
/**
@@ -226,7 +235,10 @@ public open class Server(
226235
throw IllegalStateException("Server does not support tools capability.")
227236
}
228237
logger.info { "Removing tool: $name" }
229-
val removed = tools.remove(name) != null
238+
239+
val oldMap = _tools.getAndUpdate { current -> current.remove(name) }
240+
241+
val removed = name in oldMap
230242
logger.debug {
231243
if (removed) {
232244
"Tool removed: $name"
@@ -250,18 +262,20 @@ public open class Server(
250262
throw IllegalStateException("Server does not support tools capability.")
251263
}
252264
logger.info { "Removing ${toolNames.size} tools" }
253-
var removedCount = 0
254-
for (name in toolNames) {
255-
logger.debug { "Removing tool: $name" }
256-
if (tools.remove(name) != null) {
257-
removedCount++
265+
266+
val oldMap = _tools.getAndUpdate { current ->
267+
toolNames.fold(current) { map, name ->
268+
logger.debug { "Removing tool: $name" }
269+
map.remove(name)
258270
}
259271
}
272+
273+
val removedCount = toolNames.count { it in oldMap }
260274
logger.info {
261275
if (removedCount > 0) {
262-
"Removed $removedCount tools"
276+
"Removed $removedCount tools"
263277
} else {
264-
"No tools were removed"
278+
"No tools were removed"
265279
}
266280
}
267281
return removedCount
@@ -280,7 +294,7 @@ public open class Server(
280294
throw IllegalStateException("Server does not support prompts capability.")
281295
}
282296
logger.info { "Registering prompt: ${prompt.name}" }
283-
prompts[prompt.name] = RegisteredPrompt(prompt, promptProvider)
297+
_prompts.update { current -> current.put(prompt.name, RegisteredPrompt(prompt, promptProvider)) }
284298
}
285299

286300
/**
@@ -314,10 +328,7 @@ public open class Server(
314328
throw IllegalStateException("Server does not support prompts capability.")
315329
}
316330
logger.info { "Registering ${promptsToAdd.size} prompts" }
317-
for (rp in promptsToAdd) {
318-
logger.debug { "Registering prompt: ${rp.prompt.name}" }
319-
prompts[rp.prompt.name] = rp
320-
}
331+
_prompts.update { current -> current.putAll(promptsToAdd.associateBy { it.prompt.name }) }
321332
}
322333

323334
/**
@@ -333,7 +344,10 @@ public open class Server(
333344
throw IllegalStateException("Server does not support prompts capability.")
334345
}
335346
logger.info { "Removing prompt: $name" }
336-
val removed = prompts.remove(name) != null
347+
348+
val oldMap = _prompts.getAndUpdate { current -> current.remove(name) }
349+
350+
val removed = name in oldMap
337351
logger.debug {
338352
if (removed) {
339353
"Prompt removed: $name"
@@ -357,13 +371,16 @@ public open class Server(
357371
throw IllegalStateException("Server does not support prompts capability.")
358372
}
359373
logger.info { "Removing ${promptNames.size} prompts" }
360-
var removedCount = 0
361-
for (name in promptNames) {
362-
logger.debug { "Removing prompt: $name" }
363-
if (prompts.remove(name) != null) {
364-
removedCount++
374+
375+
val oldMap = _prompts.getAndUpdate { current ->
376+
promptNames.fold(current) { map, name ->
377+
logger.debug { "Removing prompt: $name" }
378+
map.remove(name)
365379
}
366380
}
381+
382+
val removedCount = promptNames.count { it in oldMap }
383+
367384
logger.info {
368385
if (removedCount > 0) {
369386
"Removed $removedCount prompts"
@@ -396,7 +413,12 @@ public open class Server(
396413
throw IllegalStateException("Server does not support resources capability.")
397414
}
398415
logger.info { "Registering resource: $name ($uri)" }
399-
resources[uri] = RegisteredResource(Resource(uri, name, description, mimeType), readHandler)
416+
_resources.update { current ->
417+
current.put(
418+
uri,
419+
RegisteredResource(Resource(uri, name, description, mimeType), readHandler)
420+
)
421+
}
400422
}
401423

402424
/**
@@ -411,10 +433,7 @@ public open class Server(
411433
throw IllegalStateException("Server does not support resources capability.")
412434
}
413435
logger.info { "Registering ${resourcesToAdd.size} resources" }
414-
for (r in resourcesToAdd) {
415-
logger.debug { "Registering resource: ${r.resource.name} (${r.resource.uri})" }
416-
resources[r.resource.uri] = r
417-
}
436+
_resources.update { current -> current.putAll(resourcesToAdd.associateBy { it.resource.uri }) }
418437
}
419438

420439
/**
@@ -430,7 +449,10 @@ public open class Server(
430449
throw IllegalStateException("Server does not support resources capability.")
431450
}
432451
logger.info { "Removing resource: $uri" }
433-
val removed = resources.remove(uri) != null
452+
453+
val oldMap = _resources.getAndUpdate { current -> current.remove(uri) }
454+
455+
val removed = uri in oldMap
434456
logger.debug {
435457
if (removed) {
436458
"Resource removed: $uri"
@@ -454,13 +476,16 @@ public open class Server(
454476
throw IllegalStateException("Server does not support resources capability.")
455477
}
456478
logger.info { "Removing ${uris.size} resources" }
457-
var removedCount = 0
458-
for (uri in uris) {
459-
logger.debug { "Removing resource: $uri" }
460-
if (resources.remove(uri) != null) {
461-
removedCount++
479+
480+
val oldMap = _resources.getAndUpdate { current ->
481+
uris.fold(current) { map, uri ->
482+
logger.debug { "Removing resource: $uri" }
483+
map.remove(uri)
462484
}
463485
}
486+
487+
val removedCount = uris.count { it in oldMap }
488+
464489
logger.info {
465490
if (removedCount > 0) {
466491
"Removed $removedCount resources"

0 commit comments

Comments
 (0)