Skip to content

Commit 88b4218

Browse files
jyeminrozza
andauthored
Allow generic base classes for POJOs (#1423)
This change fixes a regression which prevents the driver from encoding and decoding concrete classes which extend generic base classes, when the base class is specified as the generic type of the MongoCollection. JAVA-5173 Co-authored-by: Ross Lawley <[email protected]>
1 parent d8503c3 commit 88b4218

File tree

9 files changed

+248
-21
lines changed

9 files changed

+248
-21
lines changed

bson/src/main/org/bson/codecs/pojo/LazyPropertyModelCodec.java

+29-4
Original file line numberDiff line numberDiff line change
@@ -163,19 +163,44 @@ private <V> PropertyModel<V> getSpecializedPropertyModel(final PropertyModel<V>
163163
static final class NeedSpecializationCodec<T> extends PojoCodec<T> {
164164
private final ClassModel<T> classModel;
165165
private final DiscriminatorLookup discriminatorLookup;
166+
private final CodecRegistry codecRegistry;
166167

167-
NeedSpecializationCodec(final ClassModel<T> classModel, final DiscriminatorLookup discriminatorLookup) {
168+
NeedSpecializationCodec(final ClassModel<T> classModel, final DiscriminatorLookup discriminatorLookup, final CodecRegistry codecRegistry) {
168169
this.classModel = classModel;
169170
this.discriminatorLookup = discriminatorLookup;
171+
this.codecRegistry = codecRegistry;
170172
}
171173

172174
@Override
173-
public T decode(final BsonReader reader, final DecoderContext decoderContext) {
174-
throw exception();
175+
public void encode(final BsonWriter writer, final T value, final EncoderContext encoderContext) {
176+
if (value.getClass().equals(classModel.getType())) {
177+
throw exception();
178+
}
179+
tryEncode(codecRegistry.get(value.getClass()), writer, value, encoderContext);
175180
}
176181

177182
@Override
178-
public void encode(final BsonWriter writer, final T value, final EncoderContext encoderContext) {
183+
public T decode(final BsonReader reader, final DecoderContext decoderContext) {
184+
return tryDecode(reader, decoderContext);
185+
}
186+
187+
@SuppressWarnings("unchecked")
188+
private <A> void tryEncode(final Codec<A> codec, final BsonWriter writer, final T value, final EncoderContext encoderContext) {
189+
try {
190+
codec.encode(writer, (A) value, encoderContext);
191+
} catch (Exception e) {
192+
throw exception();
193+
}
194+
}
195+
196+
@SuppressWarnings("unchecked")
197+
public T tryDecode(final BsonReader reader, final DecoderContext decoderContext) {
198+
Codec<T> codec = PojoCodecImpl.<T>getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.getDiscriminatorKey(),
199+
codecRegistry, discriminatorLookup, null, classModel.getName());
200+
if (codec != null) {
201+
return codec.decode(reader, decoderContext);
202+
}
203+
179204
throw exception();
180205
}
181206

bson/src/main/org/bson/codecs/pojo/PojoCodecImpl.java

+10-8
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ public T decode(final BsonReader reader, final DecoderContext decoderContext) {
101101
return instanceCreator.getInstance();
102102
} else {
103103
return getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.getDiscriminatorKey(), registry,
104-
discriminatorLookup, this).decode(reader, DecoderContext.builder().checkedDiscriminator(true).build());
104+
discriminatorLookup, this, classModel.getName())
105+
.decode(reader, DecoderContext.builder().checkedDiscriminator(true).build());
105106
}
106107
}
107108

@@ -275,10 +276,11 @@ private <S, V> boolean areEquivalentTypes(final Class<S> t1, final Class<V> t2)
275276
}
276277

277278
@SuppressWarnings("unchecked")
278-
private Codec<T> getCodecFromDocument(final BsonReader reader, final boolean useDiscriminator, final String discriminatorKey,
279-
final CodecRegistry registry, final DiscriminatorLookup discriminatorLookup,
280-
final Codec<T> defaultCodec) {
281-
Codec<T> codec = defaultCodec;
279+
@Nullable
280+
static <C> Codec<C> getCodecFromDocument(final BsonReader reader, final boolean useDiscriminator, final String discriminatorKey,
281+
final CodecRegistry registry, final DiscriminatorLookup discriminatorLookup, @Nullable final Codec<C> defaultCodec,
282+
final String simpleClassName) {
283+
Codec<C> codec = defaultCodec;
282284
if (useDiscriminator) {
283285
BsonReaderMark mark = reader.getMark();
284286
reader.readStartDocument();
@@ -289,12 +291,12 @@ private Codec<T> getCodecFromDocument(final BsonReader reader, final boolean use
289291
discriminatorKeyFound = true;
290292
try {
291293
Class<?> discriminatorClass = discriminatorLookup.lookup(reader.readString());
292-
if (!codec.getEncoderClass().equals(discriminatorClass)) {
293-
codec = (Codec<T>) registry.get(discriminatorClass);
294+
if (codec == null || !codec.getEncoderClass().equals(discriminatorClass)) {
295+
codec = (Codec<C>) registry.get(discriminatorClass);
294296
}
295297
} catch (Exception e) {
296298
throw new CodecConfigurationException(format("Failed to decode '%s'. Decoding errored with: %s",
297-
classModel.getName(), e.getMessage()), e);
299+
simpleClassName, e.getMessage()), e);
298300
}
299301
} else {
300302
reader.skipValue();

bson/src/main/org/bson/codecs/pojo/PojoCodecProvider.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ private static <T> PojoCodec<T> createCodec(final ClassModel<T> classModel, fina
9797
final List<PropertyCodecProvider> propertyCodecProviders, final DiscriminatorLookup discriminatorLookup) {
9898
return shouldSpecialize(classModel)
9999
? new PojoCodecImpl<>(classModel, codecRegistry, propertyCodecProviders, discriminatorLookup)
100-
: new LazyPropertyModelCodec.NeedSpecializationCodec<>(classModel, discriminatorLookup);
100+
: new LazyPropertyModelCodec.NeedSpecializationCodec<>(classModel, discriminatorLookup, codecRegistry);
101101
}
102102

103103
/**

bson/src/test/unit/org/bson/codecs/pojo/PojoCustomTest.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@
3838
import org.bson.codecs.pojo.entities.BsonRepresentationUnsupportedString;
3939
import org.bson.codecs.pojo.entities.ConcreteAndNestedAbstractInterfaceModel;
4040
import org.bson.codecs.pojo.entities.ConcreteCollectionsModel;
41+
import org.bson.codecs.pojo.entities.ConcreteModel;
42+
import org.bson.codecs.pojo.entities.ConcreteField;
4143
import org.bson.codecs.pojo.entities.ConcreteStandAloneAbstractInterfaceModel;
4244
import org.bson.codecs.pojo.entities.ConstructorNotPublicModel;
4345
import org.bson.codecs.pojo.entities.ConventionModel;
4446
import org.bson.codecs.pojo.entities.ConverterModel;
4547
import org.bson.codecs.pojo.entities.CustomPropertyCodecOptionalModel;
48+
import org.bson.codecs.pojo.entities.GenericBaseModel;
4649
import org.bson.codecs.pojo.entities.GenericHolderModel;
4750
import org.bson.codecs.pojo.entities.GenericTreeModel;
4851
import org.bson.codecs.pojo.entities.InterfaceBasedModel;
@@ -545,6 +548,17 @@ public void testInvalidDiscriminatorInNestedModel() {
545548
+ "'simple': {'_t': 'FakeModel', 'integerField': 42, 'stringField': 'myString'}}"));
546549
}
547550

551+
@Test
552+
public void testGenericBaseClass() {
553+
CodecRegistry registry = fromProviders(new ValueCodecProvider(), PojoCodecProvider.builder().automatic(true).build());
554+
555+
ConcreteModel model = new ConcreteModel(new ConcreteField("name1"));
556+
557+
String json = "{\"_t\": \"org.bson.codecs.pojo.entities.ConcreteModel\", \"field\": {\"name\": \"name1\"}}";
558+
roundTrip(PojoCodecProvider.builder().automatic(true), GenericBaseModel.class, model, json);
559+
}
560+
561+
548562
@Test
549563
public void testCannotEncodeUnspecializedClasses() {
550564
CodecRegistry registry = fromProviders(getPojoCodecProviderBuilder(GenericTreeModel.class).build());
@@ -553,7 +567,7 @@ public void testCannotEncodeUnspecializedClasses() {
553567
}
554568

555569
@Test
556-
public void testCannotDecodeUnspecializedClasses() {
570+
public void testCannotDecodeUnspecializedClassesWithoutADiscriminator() {
557571
assertThrows(CodecConfigurationException.class, () ->
558572
decodingShouldFail(getCodec(GenericTreeModel.class),
559573
"{'field1': 'top', 'field2': 1, "

bson/src/test/unit/org/bson/codecs/pojo/PojoTestCase.java

+25-7
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,12 @@ <T> void roundTrip(final T value, final String json) {
9090
}
9191

9292
<T> void roundTrip(final PojoCodecProvider.Builder builder, final T value, final String json) {
93-
encodesTo(getCodecRegistry(builder), value, json);
94-
decodesTo(getCodecRegistry(builder), json, value);
93+
roundTrip(builder, value.getClass(), value, json);
94+
}
95+
96+
<T> void roundTrip(final PojoCodecProvider.Builder builder, final Class<?> clazz, final T value, final String json) {
97+
encodesTo(getCodecRegistry(builder), clazz, value, json);
98+
decodesTo(getCodecRegistry(builder), clazz, json, value);
9599
}
96100

97101
<T> void threadedRoundTrip(final PojoCodecProvider.Builder builder, final T value, final String json) {
@@ -109,21 +113,30 @@ <T> void roundTrip(final CodecRegistry registry, final T value, final String jso
109113
decodesTo(registry, json, value);
110114
}
111115

116+
<T> void roundTrip(final CodecRegistry registry, final Class<T> clazz, final T value, final String json) {
117+
encodesTo(registry, clazz, value, json);
118+
decodesTo(registry, clazz, json, value);
119+
}
120+
112121
<T> void encodesTo(final PojoCodecProvider.Builder builder, final T value, final String json) {
113122
encodesTo(builder, value, json, false);
114123
}
115124

116125
<T> void encodesTo(final PojoCodecProvider.Builder builder, final T value, final String json, final boolean collectible) {
117-
encodesTo(getCodecRegistry(builder), value, json, collectible);
126+
encodesTo(getCodecRegistry(builder), value.getClass(), value, json, collectible);
118127
}
119128

120129
<T> void encodesTo(final CodecRegistry registry, final T value, final String json) {
121-
encodesTo(registry, value, json, false);
130+
encodesTo(registry, value.getClass(), value, json, false);
131+
}
132+
133+
<T> void encodesTo(final CodecRegistry registry, final Class<?> clazz, final T value, final String json) {
134+
encodesTo(registry, clazz, value, json, false);
122135
}
123136

124137
@SuppressWarnings("unchecked")
125-
<T> void encodesTo(final CodecRegistry registry, final T value, final String json, final boolean collectible) {
126-
Codec<T> codec = (Codec<T>) registry.get(value.getClass());
138+
<T> void encodesTo(final CodecRegistry registry, final Class<?> clazz, final T value, final String json, final boolean collectible) {
139+
Codec<T> codec = (Codec<T>) registry.get(clazz);
127140
encodesTo(codec, value, json, collectible);
128141
}
129142

@@ -144,7 +157,12 @@ <T> void decodesTo(final PojoCodecProvider.Builder builder, final String json, f
144157

145158
@SuppressWarnings("unchecked")
146159
<T> void decodesTo(final CodecRegistry registry, final String json, final T expected) {
147-
Codec<T> codec = (Codec<T>) registry.get(expected.getClass());
160+
decodesTo(registry, expected.getClass(), json, expected);
161+
}
162+
163+
@SuppressWarnings("unchecked")
164+
<T> void decodesTo(final CodecRegistry registry, final Class<?> clazz, final String json, final T expected) {
165+
Codec<T> codec = (Codec<T>) registry.get(clazz);
148166
decodesTo(codec, json, expected);
149167
}
150168

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.bson.codecs.pojo.entities;
18+
19+
import java.util.Objects;
20+
21+
public abstract class BaseField {
22+
private String name;
23+
24+
public BaseField(final String name) {
25+
this.name = name;
26+
}
27+
28+
protected BaseField() {
29+
}
30+
31+
public String getName() {
32+
return name;
33+
}
34+
35+
public void setName(final String name) {
36+
this.name = name;
37+
}
38+
39+
@Override
40+
public boolean equals(final Object o) {
41+
if (this == o) {
42+
return true;
43+
}
44+
if (o == null || getClass() != o.getClass()) {
45+
return false;
46+
}
47+
BaseField baseField = (BaseField) o;
48+
return Objects.equals(name, baseField.name);
49+
}
50+
51+
@Override
52+
public int hashCode() {
53+
return Objects.hashCode(name);
54+
}
55+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.bson.codecs.pojo.entities;
18+
19+
public class ConcreteField extends BaseField {
20+
21+
public ConcreteField() {
22+
}
23+
24+
public ConcreteField(final String name) {
25+
super(name);
26+
}
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.bson.codecs.pojo.entities;
18+
19+
public class ConcreteModel extends GenericBaseModel<ConcreteField> {
20+
21+
public ConcreteModel() {
22+
}
23+
24+
public ConcreteModel(final ConcreteField field) {
25+
super(field);
26+
}
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.bson.codecs.pojo.entities;
18+
19+
import org.bson.codecs.pojo.annotations.BsonDiscriminator;
20+
21+
import java.util.Objects;
22+
23+
@BsonDiscriminator()
24+
public class GenericBaseModel<T extends BaseField> {
25+
26+
private T field;
27+
28+
public GenericBaseModel(final T field) {
29+
this.field = field;
30+
}
31+
32+
public GenericBaseModel() {
33+
}
34+
35+
public T getField() {
36+
return field;
37+
}
38+
39+
public void setField(final T field) {
40+
this.field = field;
41+
}
42+
43+
@Override
44+
public boolean equals(final Object o) {
45+
if (this == o) {
46+
return true;
47+
}
48+
if (o == null || getClass() != o.getClass()) {
49+
return false;
50+
}
51+
GenericBaseModel<?> that = (GenericBaseModel<?>) o;
52+
return Objects.equals(field, that.field);
53+
}
54+
55+
@Override
56+
public int hashCode() {
57+
return Objects.hashCode(field);
58+
}
59+
}

0 commit comments

Comments
 (0)