Skip to content

Commit 2edc1d7

Browse files
Rewrite encoding to follow the contract of polymorphic serializers
1 parent 4901447 commit 2edc1d7

File tree

2 files changed

+79
-51
lines changed

2 files changed

+79
-51
lines changed

formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/internal/ProtobufEncoding.kt

Lines changed: 63 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,15 @@ internal open class ProtobufEncoder(
6060
}
6161
StructureKind.CLASS, StructureKind.OBJECT, is PolymorphicKind -> {
6262
val tag = currentTagOrDefault
63-
if (tag == MISSING_TAG && descriptor == this.descriptor) this
64-
else ObjectEncoder(proto, currentTagOrDefault, writer, descriptor = descriptor)
63+
if (tag == MISSING_TAG && descriptor == this.descriptor) {
64+
this
65+
} else if (tag.isOneOf) {
66+
OneOfPolymorphicEncoder(proto = proto, parentWriter = writer, descriptor = descriptor)
67+
} else {
68+
ObjectEncoder(proto, currentTagOrDefault, writer, descriptor = descriptor)
69+
}
6570
}
71+
6672
StructureKind.MAP -> MapRepeatedEncoder(proto, currentTagOrDefault, writer, descriptor)
6773
else -> throw SerializationException("This serial kind is not supported as structure: $descriptor")
6874
}
@@ -137,7 +143,6 @@ internal open class ProtobufEncoder(
137143
serializeMap(serializer as SerializationStrategy<T>, value)
138144
}
139145
serializer.descriptor == ByteArraySerializer().descriptor -> serializeByteArray(value as ByteArray)
140-
(currentTagOrDefault.isOneOf) -> encodeOneOfValue(serializer, value)
141146
else -> serializer.serialize(this, value)
142147
}
143148

@@ -150,32 +155,11 @@ internal open class ProtobufEncoder(
150155
}
151156
}
152157

153-
private fun <T> encodeOneOfValue(serializer: SerializationStrategy<T>, value: T) {
154-
if (serializer is AbstractPolymorphicSerializer) {
155-
val actual = serializer.findPolymorphicSerializerOrNull(this, value)
156-
if (actual != null) {
157-
actual.serialize(
158-
OneOfClassEncoder(
159-
proto,
160-
currentTag,
161-
writer,
162-
descriptor = actual.descriptor
163-
),
164-
value
165-
)
166-
} else {
167-
throw SerializationException("Cannot find available serializer for one-of field $value")
168-
}
169-
} else {
170-
throw SerializationException("Polymorphic class serializer expected for one-of field $value")
171-
}
172-
}
173-
174158
@Suppress("UNCHECKED_CAST")
175159
private fun <T> serializeMap(serializer: SerializationStrategy<T>, value: T) {
176160
// encode maps as collection of map entries, not merged collection of key-values
177161
val casted = (serializer as MapLikeSerializer<Any?, Any?, T, *>)
178-
val mapEntrySerial = kotlinx.serialization.builtins.MapEntrySerializer(casted.keySerializer, casted.valueSerializer)
162+
val mapEntrySerial = MapEntrySerializer(casted.keySerializer, casted.valueSerializer)
179163
SetSerializer(mapEntrySerial).serialize(this, (value as Map<*, *>).entries)
180164
}
181165
}
@@ -197,52 +181,80 @@ private open class ObjectEncoder(
197181
}
198182
}
199183

200-
private class OneOfClassEncoder(
184+
/**
185+
* When writing a one-of element with polymorphic serializer,
186+
* use [OneOfPolymorphicEncoder] to skip the first element of type name,
187+
* and then dispatch to [OneOfElementEncoder] when calling [beginStructure]
188+
* to write the content value, with ProtoNumber overridden by class annotation,
189+
* directly back to the output stream.
190+
*/
191+
private class OneOfPolymorphicEncoder(
201192
proto: ProtoBuf,
202-
parentTag: ProtoDesc,
203193
private val parentWriter: ProtobufWriter,
204194
descriptor: SerialDescriptor
205195
) : ProtobufEncoder(proto, parentWriter, descriptor) {
206196

207-
private val classProtoNumber: Int
208-
209197
init {
210-
require(descriptor.elementsCount == 1) {
211-
"Implementation of oneOf type ${descriptor.serialName} should contain only 1 element, but get ${descriptor.elementsCount}"
212-
}
213-
val protoNumber = descriptor.annotations.filterIsInstance<ProtoNumber>().singleOrNull()
214-
require(protoNumber != null) {
215-
"Implementation of oneOf type ${descriptor.serialName} should have @ProtoNumber annotation"
198+
require(descriptor.kind is PolymorphicKind) {
199+
"The serializer of one of type ${descriptor.serialName} should be using generic polymorphic serializer, but got ${descriptor.kind}"
216200
}
217-
classProtoNumber = protoNumber.number
218-
}
219201

220-
private val writeTag: ProtoDesc = parentTag.overrideId(classProtoNumber)
202+
// Do we need this strict check?
203+
require(descriptor.getElementName(0) == "type" && descriptor.getElementDescriptor(0).kind == PrimitiveKind.STRING)
204+
}
221205

222206
override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
223-
val tag = if (currentTagOrDefault == MISSING_TAG) {
224-
writeTag
225-
} else {
226-
currentTagOrDefault
227-
}
228207
return if (descriptor == this.descriptor) {
229208
this
230-
} else if (tag.isOneOf) {
231-
OneOfClassEncoder(proto, descriptor.extractClassDesc(), parentWriter, descriptor = descriptor)
232209
} else {
233-
ObjectEncoder(proto, tag, parentWriter, descriptor = descriptor)
210+
OneOfElementEncoder(
211+
proto = proto,
212+
parentWriter = parentWriter,
213+
descriptor = descriptor
214+
)
234215
}
235216
}
236217

237-
override fun encodeInline(descriptor: SerialDescriptor): Encoder {
238-
return encodeTaggedInline(writeTag, descriptor)
218+
override fun encodeTaggedString(tag: ProtoDesc, value: String) {
219+
// the first element with type string is the discriminator of polymorphic serializer with class name
220+
// just ignore it
221+
if (tag != MISSING_TAG) {
222+
super.encodeTaggedString(tag, value)
223+
}
239224
}
240225

241-
override fun encodeTaggedInline(tag: ProtoDesc, inlineDescriptor: SerialDescriptor): Encoder {
242-
return super.encodeTaggedInline(tag, inlineDescriptor)
226+
override fun SerialDescriptor.getTag(index: Int) = when (index) {
227+
// 0 for discriminator
228+
0 -> MISSING_TAG
229+
1 -> extractParameters(index)
230+
else -> throw SerializationException("Unsupported index: $index in a oneOf type $serialName, which should be using generic polymorphic serializer")
231+
}
232+
}
233+
234+
/**
235+
* A helper encoder for one-of element to write the content value,
236+
* with ProtoNumber overridden by class annotation,
237+
* directly back to the output stream.
238+
*/
239+
private class OneOfElementEncoder(
240+
proto: ProtoBuf,
241+
parentWriter: ProtobufWriter,
242+
descriptor: SerialDescriptor
243+
) : ProtobufEncoder(proto, parentWriter, descriptor) {
244+
private val classId: Int
245+
246+
init {
247+
require(descriptor.elementsCount == 1) {
248+
"Implementation of oneOf type ${descriptor.serialName} should contain only 1 element, but get ${descriptor.elementsCount}"
249+
}
250+
val protoNumber = descriptor.annotations.filterIsInstance<ProtoNumber>().singleOrNull()
251+
require(protoNumber != null) {
252+
"Implementation of oneOf type ${descriptor.serialName} should have @ProtoNumber annotation"
253+
}
254+
classId = protoNumber.number
243255
}
244-
override fun SerialDescriptor.getTag(index: Int) = extractParameters(index).overrideId(classProtoNumber)
245256

257+
override fun SerialDescriptor.getTag(index: Int): ProtoDesc = extractParameters(index).overrideId(classId)
246258
}
247259

248260
private class MapRepeatedEncoder(

formats/protobuf/commonTest/src/kotlinx/serialization/protobuf/ProtobufOneOfTest.kt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,4 +418,20 @@ class ProtobufOneOfTest {
418418
assertEquals("082a", buf.encodeToHexString(CustomOuter.serializer(), data))
419419
}
420420

421+
@Serializable
422+
data class CustomAnyData(@ProtoOneOf(1, 2) @Polymorphic val inner: Any)
423+
424+
@Test
425+
fun testCustomAny() {
426+
val module = SerializersModule {
427+
polymorphic(Any::class) {
428+
subclass(CustomInnerInt::class, CustomerInnerIntSerializer)
429+
}
430+
}
431+
val data = CustomAnyData(CustomInnerInt(42))
432+
val buf = ProtoBuf { serializersModule = module }
433+
assertEquals("082a", buf.encodeToHexString(data))
434+
assertEquals(data, buf.decodeFromHexString<CustomAnyData>("082a"))
435+
}
436+
421437
}

0 commit comments

Comments
 (0)