Skip to content

Support using an integer as the class/case discriminator in polymorphic serialization #2587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add some tests for serial polymorphic numbers which fail for now
Some miscellaneous changes:
1. Add a `getSerialPolymorphicNumberByBaseClass` function in `SerialDescriptor` to throw the appropriate exception.
1. Add `defaultDeserializerForNumber` which was missing in `PolymorphicModuleBuilder`.
  • Loading branch information
ShreckYe committed Feb 24, 2024
commit 51d9a75f55424324fe38f7a3441aab2907e85e8f
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public class SealedClassSerializer<T : Any>(
private val serialPolymorphicNumber2Serializer: Map<Int, KSerializer<out T>>? by lazy(LazyThreadSafetyMode.PUBLICATION) {
if (descriptor.useSerialPolymorphicNumbers)
class2Serializer.entries.groupingBy {
it.value.descriptor.serialPolymorphicNumberByBaseClass.getValue(baseClass)
it.value.descriptor.getSerialPolymorphicNumberByBaseClass(baseClass)
}
.aggregate<Map.Entry<KClass<out T>, KSerializer<out T>>, Int, Map.Entry<KClass<*>, KSerializer<out T>>>
{ key, accumulator, element, _ ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,13 @@ public interface SerialDescriptor {
@ExperimentalSerializationApi
public val serialPolymorphicNumberByBaseClass: Map<KClass<*>, Int> get() = emptyMap()

@ExperimentalSerializationApi
public fun getSerialPolymorphicNumberByBaseClass(baseClass: KClass<*>): Int =
serialPolymorphicNumberByBaseClass.getOrElse(baseClass) {
throw SerializationException("The serial polymorphic number for $serialName in the scope of ${baseClass.simpleName} is not found. " +
"Please annotate the class with `@SerialPolymorphicNumber` with the first argument ${baseClass.simpleName}.")
}

/**
* Returns serial annotations of the associated class.
* Serial annotations can be used to specify an additional metadata that may be used during serialization.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ public abstract class AbstractPolymorphicSerializer<T : Any> internal constructo
encoder.encodeStructure(descriptor) {
if (descriptor.useSerialPolymorphicNumbers)
encodeIntElement(
descriptor,
0,
actualSerializer.descriptor.serialPolymorphicNumberByBaseClass.getValue(baseClass)
descriptor, 0, actualSerializer.descriptor.getSerialPolymorphicNumberByBaseClass(baseClass)
)
else
encodeStringElement(descriptor, 0, actualSerializer.descriptor.serialName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class PolymorphicModuleBuilder<in Base : Any> @PublishedApi internal cons
private val subclasses: MutableList<Pair<KClass<out Base>, KSerializer<out Base>>> = mutableListOf()
private var defaultSerializerProvider: ((Base) -> SerializationStrategy<Base>?)? = null
private var defaultDeserializerProvider: PolymorphicDeserializerProvider<Base>? = null
private var defaultDeserializerProviderForNumber: PolymorphicDeserializerProviderForNumber<Base>? = null

/*
// TODO implement this or remove?
Expand Down Expand Up @@ -75,6 +76,16 @@ public class PolymorphicModuleBuilder<in Base : Any> @PublishedApi internal cons
this.defaultDeserializerProvider = defaultDeserializerProvider
}

/**
* TODO
*/
public fun defaultDeserializerForNumber(defaultDeserializerProviderForNumber: (serialPolymorphicNumber: Int?) -> DeserializationStrategy<Base>?) {
require(this.defaultDeserializerProviderForNumber == null) {
"Default deserializer provider for number is already registered for class $baseClass: ${this.defaultDeserializerProvider}"
}
this.defaultDeserializerProviderForNumber = defaultDeserializerProviderForNumber
}

/**
* Adds a default deserializers provider associated with the given [baseClass] to the resulting module.
* This function affect only deserialization process. To avoid confusion, it was deprecated and replaced with [defaultDeserializer].
Expand Down Expand Up @@ -123,6 +134,10 @@ public class PolymorphicModuleBuilder<in Base : Any> @PublishedApi internal cons
if (defaultDeserializer != null) {
builder.registerDefaultPolymorphicDeserializer(baseClass, defaultDeserializer, false)
}

defaultDeserializerProviderForNumber?.let {
builder.registerDefaultPolymorphicDeserializerForNumber(baseClass, it, false)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ public class SerializersModuleBuilder @PublishedApi internal constructor() : Ser
// Remove previous serializers from name mapping
if (previousSerializer != null) {
names.remove(previousSerializer.descriptor.serialName)
numbers.remove(previousSerializer.descriptor.serialPolymorphicNumberByBaseClass[baseClass])
previousSerializer.descriptor.serialPolymorphicNumberByBaseClass[baseClass]?.let { numbers.remove(it) }
}
// Update mappings
baseClassSerializers[concreteClass] = concreteSerializer
Expand All @@ -245,7 +245,7 @@ public class SerializersModuleBuilder @PublishedApi internal constructor() : Ser
} else {
// Cleanup name mapping
names.remove(previousSerializer.descriptor.serialName)
numbers.remove(previousSerializer.descriptor.serialPolymorphicNumberByBaseClass[baseClass])
previousSerializer.descriptor.serialPolymorphicNumberByBaseClass[baseClass]?.let { numbers.remove(it) }
}
}
val previousByName = names[name]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization

import kotlinx.serialization.json.*
import kotlinx.serialization.modules.*
import kotlinx.serialization.test.*
import kotlin.test.*

class SerialPolymorphicNumberTest {
@Serializable
@UseSerialPolymorphicNumbers
sealed class Sealed1 {
@Serializable
@SerialPolymorphicNumber(Sealed1::class, 1)
class Case : Sealed1()
}

@Serializable
sealed class Sealed2 {
@Serializable
@SerialPolymorphicNumber(Sealed2::class, 1)
class Case : Sealed2()
}

@Serializable
@UseSerialPolymorphicNumbers
sealed class Sealed3 {
@Serializable
class Case : Sealed3()
}

@Test
fun testSealed() {
testConversion<Sealed1>(Sealed1.Case(), """{"type":1}""")
testConversion<Sealed2>(Sealed2.Case(), """{"type":"kotlinx.serialization.SerialPolymorphicNumberTest.Case"}""")
assertFailsWith(SerializationException::class) {
Json.encodeToString<Sealed3>(Sealed3.Case())
}
assertFailsWith(SerializationException::class) {
Json.decodeFromString<Sealed3>("{}")
}
}

@Serializable
@UseSerialPolymorphicNumbers
sealed class Abstract {
@Serializable
@SerialPolymorphicNumber(Abstract::class, 1)
class Case : Abstract()

@Serializable
class Default(val type: Int?):Abstract()
}

val json = Json {
serializersModule = SerializersModule {
polymorphic(Abstract::class) {
subclass(Abstract.Case::class)
defaultDeserializerForNumber {
Abstract.Default.serializer()
}
}
}
}

@Test
fun testPolymorphicModule() {
testConversion<Abstract>(json, Abstract.Case(), """{"type":1}""")
testConversion<Abstract>(json, Abstract.Default(0), """{"type":0}""")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright 2017-2024 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.serialization.test

import kotlinx.serialization.*
import kotlinx.serialization.json.*
import kotlin.test.*

inline fun <reified T> testConversion(json: Json, data: T, expectedHexString: String) {
val string = json.encodeToString(data)
assertEquals(expectedHexString, string)
assertEquals(data, json.decodeFromString(string))
}

inline fun <reified T> testConversion(data: T, expectedHexString: String) =
testConversion(Json, data, expectedHexString)