Skip to content

Commit 1789529

Browse files
xerialclaude
andauthored
feature: Add ListMap support to ObjectWeaver (#133)
## Summary - Add serialization/deserialization support for `scala.collection.immutable.ListMap` - ListMap preserves insertion order of key-value pairs, unlike regular Map - Includes comprehensive tests covering basic operations, JSON serialization, nested structures, and error handling ## Test plan - [x] Basic ListMap serialization/deserialization - [x] Empty ListMap handling - [x] JSON round-trip conversion - [x] Nested ListMap structures - [x] Insertion order preservation verification - [x] Error handling for malformed data - [x] All existing tests pass 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude <[email protected]>
1 parent 161b6b7 commit 1789529

File tree

2 files changed

+225
-44
lines changed

2 files changed

+225
-44
lines changed

ai-core/src/main/scala/wvlet/ai/core/weaver/codec/PrimitiveWeaver.scala

Lines changed: 87 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,61 @@ object PrimitiveWeaver:
7575
context.setError(e)
7676
None
7777

78+
private def unpackMapToBuffer[K, V](
79+
u: Unpacker,
80+
context: WeaverContext,
81+
keyWeaver: ObjectWeaver[K],
82+
valueWeaver: ObjectWeaver[V]
83+
): Option[ListBuffer[(K, V)]] =
84+
try
85+
val mapSize = u.unpackMapHeader
86+
val buffer = ListBuffer.empty[(K, V)]
87+
88+
var i = 0
89+
var hasError = false
90+
while i < mapSize && !hasError do
91+
// Unpack key
92+
val keyContext = WeaverContext(context.config)
93+
keyWeaver.unpack(u, keyContext)
94+
95+
if keyContext.hasError then
96+
context.setError(keyContext.getError.get)
97+
hasError = true
98+
// Skip remaining pairs to keep unpacker in consistent state
99+
while i < mapSize do
100+
u.skipValue // Skip key
101+
u.skipValue // Skip value
102+
i += 1
103+
else
104+
val key = keyContext.getLastValue.asInstanceOf[K]
105+
106+
// Unpack value
107+
val valueContext = WeaverContext(context.config)
108+
valueWeaver.unpack(u, valueContext)
109+
110+
if valueContext.hasError then
111+
context.setError(valueContext.getError.get)
112+
hasError = true
113+
// Skip remaining pairs to keep unpacker in consistent state
114+
while i + 1 < mapSize do
115+
u.skipValue // Skip key
116+
u.skipValue // Skip value
117+
i += 1
118+
else
119+
val value = valueContext.getLastValue.asInstanceOf[V]
120+
buffer += (key -> value)
121+
i += 1
122+
end while
123+
124+
if hasError then
125+
None
126+
else
127+
Some(buffer)
128+
catch
129+
case e: Exception =>
130+
context.setError(e)
131+
None
132+
78133
given intWeaver: ObjectWeaver[Int] =
79134
new ObjectWeaver[Int]:
80135
override def pack(p: Packer, v: Int, config: WeaverConfig): Unit = p.packInt(v)
@@ -457,51 +512,10 @@ object PrimitiveWeaver:
457512
override def unpack(u: Unpacker, context: WeaverContext): Unit =
458513
u.getNextValueType match
459514
case ValueType.MAP =>
460-
try
461-
val mapSize = u.unpackMapHeader
462-
val buffer = scala.collection.mutable.Map.empty[K, V]
463-
464-
var i = 0
465-
var hasError = false
466-
while i < mapSize && !hasError do
467-
// Unpack key
468-
val keyContext = WeaverContext(context.config)
469-
keyWeaver.unpack(u, keyContext)
470-
471-
if keyContext.hasError then
472-
context.setError(keyContext.getError.get)
473-
hasError = true
474-
// Skip remaining pairs to keep unpacker in consistent state
475-
while i < mapSize do
476-
u.skipValue // Skip key
477-
u.skipValue // Skip value
478-
i += 1
479-
else
480-
val key = keyContext.getLastValue.asInstanceOf[K]
481-
482-
// Unpack value
483-
val valueContext = WeaverContext(context.config)
484-
valueWeaver.unpack(u, valueContext)
485-
486-
if valueContext.hasError then
487-
context.setError(valueContext.getError.get)
488-
hasError = true
489-
// Skip remaining pairs to keep unpacker in consistent state
490-
while i + 1 < mapSize do
491-
u.skipValue // Skip key
492-
u.skipValue // Skip value
493-
i += 1
494-
else
495-
val value = valueContext.getLastValue.asInstanceOf[V]
496-
buffer += (key -> value)
497-
i += 1
498-
end while
499-
500-
if !hasError then
515+
unpackMapToBuffer(u, context, keyWeaver, valueWeaver) match
516+
case Some(buffer) =>
501517
context.setObject(buffer.toMap)
502-
catch
503-
case e: Exception =>
504-
context.setError(e)
518+
case None => // Error already set in unpackMapToBuffer
505519
case ValueType.NIL =>
506520
safeUnpackNil(context, u)
507521
case other =>
@@ -567,4 +581,33 @@ object PrimitiveWeaver:
567581
new IllegalArgumentException(s"Cannot convert ${other} to java.util.List")
568582
)
569583

584+
given listMapWeaver[K, V](using
585+
keyWeaver: ObjectWeaver[K],
586+
valueWeaver: ObjectWeaver[V]
587+
): ObjectWeaver[scala.collection.immutable.ListMap[K, V]] =
588+
new ObjectWeaver[scala.collection.immutable.ListMap[K, V]]:
589+
override def pack(
590+
p: Packer,
591+
v: scala.collection.immutable.ListMap[K, V],
592+
config: WeaverConfig
593+
): Unit =
594+
p.packMapHeader(v.size)
595+
v.foreach { case (key, value) =>
596+
keyWeaver.pack(p, key, config)
597+
valueWeaver.pack(p, value, config)
598+
}
599+
600+
override def unpack(u: Unpacker, context: WeaverContext): Unit =
601+
u.getNextValueType match
602+
case ValueType.MAP =>
603+
unpackMapToBuffer(u, context, keyWeaver, valueWeaver) match
604+
case Some(buffer) =>
605+
context.setObject(scala.collection.immutable.ListMap.from(buffer))
606+
case None => // Error already set in unpackMapToBuffer
607+
case ValueType.NIL =>
608+
safeUnpackNil(context, u)
609+
case other =>
610+
u.skipValue
611+
context.setError(new IllegalArgumentException(s"Cannot convert ${other} to ListMap"))
612+
570613
end PrimitiveWeaver

ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,4 +359,142 @@ class WeaverTest extends AirSpec:
359359
result.get.getMessage.contains("Cannot convert") shouldBe true
360360
}
361361

362+
test("weave ListMap[String, Int]") {
363+
val v = scala.collection.immutable.ListMap("a" -> 1, "b" -> 2, "c" -> 3)
364+
val msgpack = ObjectWeaver.weave(v)
365+
val v2 = ObjectWeaver.unweave[scala.collection.immutable.ListMap[String, Int]](msgpack)
366+
v shouldBe v2
367+
// Verify order is preserved
368+
v.keys.toList shouldBe v2.keys.toList
369+
v.values.toList shouldBe v2.values.toList
370+
}
371+
372+
test("weave empty ListMap[String, Int]") {
373+
val v = scala.collection.immutable.ListMap.empty[String, Int]
374+
val msgpack = ObjectWeaver.weave(v)
375+
val v2 = ObjectWeaver.unweave[scala.collection.immutable.ListMap[String, Int]](msgpack)
376+
v shouldBe v2
377+
}
378+
379+
test("weave ListMap[Int, String]") {
380+
val v = scala.collection.immutable.ListMap(1 -> "one", 2 -> "two", 3 -> "three")
381+
val msgpack = ObjectWeaver.weave(v)
382+
val v2 = ObjectWeaver.unweave[scala.collection.immutable.ListMap[Int, String]](msgpack)
383+
v shouldBe v2
384+
// Verify order is preserved
385+
v.keys.toList shouldBe v2.keys.toList
386+
v.values.toList shouldBe v2.values.toList
387+
}
388+
389+
test("ListMap[String, Int] toJson") {
390+
val v = scala.collection.immutable.ListMap("x" -> 10, "y" -> 20, "z" -> 30)
391+
val json = ObjectWeaver.toJson(v)
392+
val v2 = ObjectWeaver.fromJson[scala.collection.immutable.ListMap[String, Int]](json)
393+
v shouldBe v2
394+
// Verify order is preserved
395+
v.keys.toList shouldBe v2.keys.toList
396+
v.values.toList shouldBe v2.values.toList
397+
}
398+
399+
test("nested ListMap[String, List[Int]]") {
400+
val v = scala
401+
.collection
402+
.immutable
403+
.ListMap("numbers" -> List(1, 2, 3), "more" -> List(4, 5), "empty" -> List.empty[Int])
404+
val msgpack = ObjectWeaver.weave(v)
405+
val v2 = ObjectWeaver.unweave[scala.collection.immutable.ListMap[String, List[Int]]](msgpack)
406+
v shouldBe v2
407+
// Verify order is preserved
408+
v.keys.toList shouldBe v2.keys.toList
409+
}
410+
411+
test("nested ListMap[String, ListMap[String, Int]]") {
412+
val v = scala
413+
.collection
414+
.immutable
415+
.ListMap(
416+
"group1" -> scala.collection.immutable.ListMap("a" -> 1, "b" -> 2),
417+
"group2" -> scala.collection.immutable.ListMap("x" -> 10, "y" -> 20),
418+
"empty" -> scala.collection.immutable.ListMap.empty[String, Int]
419+
)
420+
val msgpack = ObjectWeaver.weave(v)
421+
val v2 = ObjectWeaver.unweave[
422+
scala.collection.immutable.ListMap[String, scala.collection.immutable.ListMap[String, Int]]
423+
](msgpack)
424+
v shouldBe v2
425+
// Verify order is preserved for outer map
426+
v.keys.toList shouldBe v2.keys.toList
427+
// Verify order is preserved for inner maps
428+
v("group1").keys.toList shouldBe v2("group1").keys.toList
429+
v("group2").keys.toList shouldBe v2("group2").keys.toList
430+
}
431+
432+
test("ListMap preserves insertion order") {
433+
// Create ListMap with specific order
434+
val builder = scala.collection.immutable.ListMap.newBuilder[String, Int]
435+
builder += ("third" -> 3)
436+
builder += ("first" -> 1)
437+
builder += ("second" -> 2)
438+
val v = builder.result()
439+
440+
val msgpack = ObjectWeaver.weave(v)
441+
val v2 = ObjectWeaver.unweave[scala.collection.immutable.ListMap[String, Int]](msgpack)
442+
443+
// Verify values are correct
444+
v shouldBe v2
445+
// Verify insertion order is preserved
446+
v.keys.toList shouldBe List("third", "first", "second")
447+
v2.keys.toList shouldBe List("third", "first", "second")
448+
v.values.toList shouldBe List(3, 1, 2)
449+
v2.values.toList shouldBe List(3, 1, 2)
450+
}
451+
452+
test("handle malformed ListMap data gracefully") {
453+
import wvlet.ai.core.msgpack.spi.MessagePack
454+
// Create a malformed msgpack where we claim there are more pairs than we provide
455+
val packer = MessagePack.newPacker()
456+
packer.packMapHeader(3) // Say we have 3 key-value pairs
457+
packer.packString("key1") // Valid first key
458+
packer.packInt(1) // Valid first value
459+
packer.packString("key2") // Valid second key
460+
packer.packInt(2) // Valid second value
461+
// Missing third key-value pair!
462+
463+
val malformedMsgpack = packer.toByteArray
464+
465+
val result =
466+
try
467+
ObjectWeaver.unweave[scala.collection.immutable.ListMap[String, Int]](malformedMsgpack)
468+
None
469+
catch
470+
case e: Exception =>
471+
Some(e)
472+
473+
result.isDefined shouldBe true
474+
}
475+
476+
test("handle malformed ListMap value gracefully") {
477+
import wvlet.ai.core.msgpack.spi.MessagePack
478+
// Create a malformed msgpack map with wrong value type
479+
val packer = MessagePack.newPacker()
480+
packer.packMapHeader(2) // Say we have 2 key-value pairs
481+
packer.packString("key1") // Valid first key
482+
packer.packInt(1) // Valid first value
483+
packer.packString("key2") // Valid second key
484+
packer.packString("invalid") // Invalid second value for ListMap[String, Int]
485+
486+
val malformedMsgpack = packer.toByteArray
487+
488+
val result =
489+
try
490+
ObjectWeaver.unweave[scala.collection.immutable.ListMap[String, Int]](malformedMsgpack)
491+
None
492+
catch
493+
case e: Exception =>
494+
Some(e)
495+
496+
result.isDefined shouldBe true
497+
result.get.getMessage.contains("Cannot convert") shouldBe true
498+
}
499+
362500
end WeaverTest

0 commit comments

Comments
 (0)