
Commit db26196

varant-zlai, smcnamara2-stripe, and ezvz authored
[Untested] Vz cherry pick oss avro schema (#170)
## Summary

Cherry picking avro schema parsing improvements

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update
- [x] Untested

## Summary by CodeRabbit

- **New Features**
  - Introduced a new `SchemaTraverser` trait to enhance schema traversal capabilities.
  - Added advanced schema handling for Avro and Spark data conversions.
- **Improvements**
  - Enhanced row conversion methods with more flexible schema processing.
  - Improved support for complex data types and schema representations.
  - Updated encoding and conversion methods across multiple components.
- **Technical Enhancements**
  - Implemented `AvroSchemaTraverser` for more robust Avro schema navigation.
  - Refined data conversion methods to support more flexible schema handling.

---------

Co-authored-by: Spencer McNamara <[email protected]>
Co-authored-by: ezvz <[email protected]>
1 parent 8647407 commit db26196

File tree

- api/src/main/scala/ai/chronon/api/Row.scala
- online/src/main/scala/ai/chronon/online/AvroConversions.scala
- online/src/main/scala/ai/chronon/online/Fetcher.scala
- online/src/main/scala/ai/chronon/online/SparkConversions.scala
- spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala

5 files changed: +166 −41

api/src/main/scala/ai/chronon/api/Row.scala

Lines changed: 108 additions & 30 deletions
@@ -40,6 +40,36 @@ trait Row {
   }
 }
 
+/**
+  * SchemaTraverser aids in the traversal of the given SchemaType.
+  * In some cases (eg avro), it is more performant to create the
+  * top-level schema once and then traverse it top-to-bottom, rather
+  * than recreating at each node.
+  *
+  * This helper trait allows the Row.to function to traverse SchemaType
+  * without leaking details of the SchemaType structure.
+  */
+trait SchemaTraverser[SchemaType] {
+
+  def currentNode: SchemaType
+
+  // Returns the equivalent SchemaType representation of the given field
+  def getField(field: StructField): SchemaTraverser[SchemaType]
+
+  // Returns the inner type of the current collection field type.
+  // Throws if the current type is not a collection.
+  def getCollectionType: SchemaTraverser[SchemaType]
+
+  // Returns the key type of the current map field type.
+  // Throws if the current type is not a map.
+  def getMapKeyType: SchemaTraverser[SchemaType]
+
+  // Returns the value type of the current map field type.
+  // Throws if the current type is not a map.
+  def getMapValueType: SchemaTraverser[SchemaType]
+
+}
+
 object Row {
   // recursively traverse a logical struct, and convert it to chronon's row type
   def from[CompositeType, BinaryType, ArrayType, StringType](
@@ -95,49 +125,71 @@ object Row {
   }
 
   // recursively traverse a chronon dataType value, and convert it to an external type
-  def to[StructType, BinaryType, ListType, MapType](value: Any,
-                                                    dataType: DataType,
-                                                    composer: (Iterator[Any], DataType) => StructType,
-                                                    binarizer: Array[Byte] => BinaryType,
-                                                    collector: (Iterator[Any], Int) => ListType,
-                                                    mapper: (util.Map[Any, Any] => MapType),
-                                                    extraneousRecord: Any => Array[Any] = null): Any = {
+  def to[StructType, BinaryType, ListType, MapType, OutputSchema](
+      value: Any,
+      dataType: DataType,
+      composer: (Iterator[Any], DataType, Option[OutputSchema]) => StructType,
+      binarizer: Array[Byte] => BinaryType,
+      collector: (Iterator[Any], Int) => ListType,
+      mapper: (util.Map[Any, Any] => MapType),
+      extraneousRecord: Any => Array[Any] = null,
+      schemaTraverser: Option[SchemaTraverser[OutputSchema]] = None): Any = {
 
     if (value == null) return null
-    def edit(value: Any, dataType: DataType): Any =
-      to(value, dataType, composer, binarizer, collector, mapper, extraneousRecord)
+
+    def getFieldSchema(f: StructField) = schemaTraverser.map(_.getField(f))
+
+    def edit(value: Any, dataType: DataType, subTreeTraverser: Option[SchemaTraverser[OutputSchema]]): Any =
+      to(value, dataType, composer, binarizer, collector, mapper, extraneousRecord, subTreeTraverser)
+
     dataType match {
       case StructType(_, fields) =>
         value match {
           case arr: Array[Any] =>
-            composer(arr.iterator.zipWithIndex.map { case (value, idx) => edit(value, fields(idx).fieldType) },
-                     dataType)
+            composer(
+              arr.iterator.zipWithIndex.map {
+                case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx)))
+              },
+              dataType,
+              schemaTraverser.map(_.currentNode)
+            )
           case list: util.ArrayList[Any] =>
-            composer(list
-                       .iterator()
-                       .asScala
-                       .zipWithIndex
-                       .map { case (value, idx) => edit(value, fields(idx).fieldType) },
-                     dataType)
-          case list: List[Any] =>
-            composer(list.iterator.zipWithIndex
-                       .map { case (value, idx) => edit(value, fields(idx).fieldType) },
-                     dataType)
+            composer(
+              list
+                .iterator()
+                .asScala
+                .zipWithIndex
+                .map { case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx))) },
+              dataType,
+              schemaTraverser.map(_.currentNode)
+            )
          case value: Any =>
            assert(extraneousRecord != null, s"No handler for $value of class ${value.getClass}")
-            composer(extraneousRecord(value).iterator.zipWithIndex.map {
-                       case (value, idx) => edit(value, fields(idx).fieldType)
-                     },
-                     dataType)
+            composer(
+              extraneousRecord(value).iterator.zipWithIndex.map {
+                case (value, idx) => edit(value, fields(idx).fieldType, getFieldSchema(fields(idx)))
+              },
+              dataType,
+              schemaTraverser.map(_.currentNode)
+            )
        }
      case ListType(elemType) =>
        value match {
          case list: util.ArrayList[Any] =>
-            collector(list.iterator().asScala.map(edit(_, elemType)), list.size())
+            collector(
+              list.iterator().asScala.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
+              list.size()
+            )
          case arr: Array[_] => // avro only recognizes arrayList for its ArrayType/ListType
-            collector(arr.iterator.map(edit(_, elemType)), arr.length)
+            collector(
+              arr.iterator.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
+              arr.length
+            )
          case arr: mutable.WrappedArray[Any] => // handles the wrapped array type from transform function in spark sql
-            collector(arr.iterator.map(edit(_, elemType)), arr.length)
+            collector(
+              arr.iterator.map(edit(_, elemType, schemaTraverser.map(_.getCollectionType))),
+              arr.length
+            )
        }
      case MapType(keyType, valueType) =>
        value match {
@@ -147,12 +199,38 @@ object Row {
              .entrySet()
              .iterator()
              .asScala
-              .foreach { entry => newMap.put(edit(entry.getKey, keyType), edit(entry.getValue, valueType)) }
+              .foreach { entry =>
+                newMap.put(
+                  edit(
+                    entry.getKey,
+                    keyType,
+                    schemaTraverser.map(_.getMapKeyType)
+                  ),
+                  edit(
+                    entry.getValue,
+                    valueType,
+                    schemaTraverser.map(_.getMapValueType)
+                  )
+                )
+              }
            mapper(newMap)
          case map: collection.immutable.Map[Any, Any] =>
            val newMap = new util.HashMap[Any, Any](map.size)
            map
-              .foreach { entry => newMap.put(edit(entry._1, keyType), edit(entry._2, valueType)) }
+              .foreach { entry =>
+                newMap.put(
+                  edit(
+                    entry._1,
+                    keyType,
+                    schemaTraverser.map(_.getMapKeyType)
+                  ),
+                  edit(
+                    entry._2,
+                    valueType,
+                    schemaTraverser.map(_.getMapValueType)
+                  )
+                )
+              }
            mapper(newMap)
        }
      case BinaryType => binarizer(value.asInstanceOf[Array[Byte]])
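For illustration only, the new trait's contract can also be written against Chronon's own `DataType`. The sketch below is hypothetical and not part of this commit; it uses only the `StructType`/`ListType`/`MapType`/`StructField` shapes already visible in the diff above, to show what each traversal method is expected to return for a concrete schema type. The real implementation added here is `AvroSchemaTraverser` in AvroConversions.scala below.

```scala
import ai.chronon.api._

// Hypothetical sketch: a SchemaTraverser over Chronon's own DataType,
// included only to illustrate the trait's contract.
case class DataTypeTraverser(currentNode: DataType) extends SchemaTraverser[DataType] {

  // Descend into the named struct field's type.
  override def getField(field: StructField): SchemaTraverser[DataType] =
    currentNode match {
      case StructType(_, fields) => copy(fields.find(_.name == field.name).get.fieldType)
      case other                 => throw new UnsupportedOperationException(s"$other is not a struct")
    }

  // Descend into a list's element type.
  override def getCollectionType: SchemaTraverser[DataType] =
    currentNode match {
      case ListType(elemType) => copy(elemType)
      case other              => throw new UnsupportedOperationException(s"$other is not a list")
    }

  // Descend into a map's key type.
  override def getMapKeyType: SchemaTraverser[DataType] =
    currentNode match {
      case MapType(keyType, _) => copy(keyType)
      case other               => throw new UnsupportedOperationException(s"$other is not a map")
    }

  // Descend into a map's value type.
  override def getMapValueType: SchemaTraverser[DataType] =
    currentNode match {
      case MapType(_, valueType) => copy(valueType)
      case other                 => throw new UnsupportedOperationException(s"$other is not a map")
    }
}
```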

online/src/main/scala/ai/chronon/online/AvroConversions.scala

Lines changed: 52 additions & 7 deletions
@@ -114,13 +114,16 @@ object AvroConversions {
     }
   }
 
-  def fromChrononRow(value: Any, dataType: DataType, extraneousRecord: Any => Array[Any] = null): Any = {
+  def fromChrononRow(value: Any,
+                     dataType: DataType,
+                     topLevelSchema: Schema,
+                     extraneousRecord: Any => Array[Any] = null): Any = {
     // But this also has to happen at the recursive depth - data type and schema inside the compositor need to
-    Row.to[GenericRecord, ByteBuffer, util.ArrayList[Any], util.Map[Any, Any]](
+    Row.to[GenericRecord, ByteBuffer, util.ArrayList[Any], util.Map[Any, Any], Schema](
       value,
       dataType,
-      { (data: Iterator[Any], elemDataType: DataType) =>
-        val schema = AvroConversions.fromChrononSchema(elemDataType)
+      { (data: Iterator[Any], elemDataType: DataType, providedSchema: Option[Schema]) =>
+        val schema = providedSchema.getOrElse(AvroConversions.fromChrononSchema(elemDataType))
        val record = new GenericData.Record(schema)
        data.zipWithIndex.foreach {
          case (value1, idx) => record.put(idx, value1)
@@ -134,7 +137,8 @@ object AvroConversions {
        result
      },
      { m: util.Map[Any, Any] => m },
-      extraneousRecord
+      extraneousRecord,
+      Some(AvroSchemaTraverser(topLevelSchema))
    )
  }
 
@@ -169,7 +173,8 @@ object AvroConversions {
  def encodeBytes(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => Array[Byte] = {
    val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
    { data: Any =>
-      val record = fromChrononRow(data, codec.chrononSchema, extraneousRecord).asInstanceOf[GenericData.Record]
+      val record =
+        fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
      val bytes = codec.encodeBinary(record)
      bytes
    }
@@ -178,9 +183,49 @@ object AvroConversions {
  def encodeJson(schema: StructType, extraneousRecord: Any => Array[Any] = null): Any => String = {
    val codec: AvroCodec = new AvroCodec(fromChrononSchema(schema).toString(true));
    { data: Any =>
-      val record = fromChrononRow(data, codec.chrononSchema, extraneousRecord).asInstanceOf[GenericData.Record]
+      val record =
+        fromChrononRow(data, codec.chrononSchema, codec.schema, extraneousRecord).asInstanceOf[GenericData.Record]
      val json = codec.encodeJson(record)
      json
    }
  }
 }
+
+case class AvroSchemaTraverser(currentNode: Schema) extends SchemaTraverser[Schema] {
+
+  // We only use union types for nullable fields, and always
+  // unbox them when writing the actual schema out.
+  private def unboxUnion(maybeUnion: Schema): Schema =
+    if (maybeUnion.getType == Schema.Type.UNION) {
+      maybeUnion.getTypes.get(1)
+    } else {
+      maybeUnion
+    }
+
+  override def getField(field: StructField): SchemaTraverser[Schema] =
+    copy(
+      unboxUnion(currentNode.getField(field.name).schema())
+    )
+
+  override def getCollectionType: SchemaTraverser[Schema] =
+    copy(
+      unboxUnion(currentNode.getElementType)
+    )
+
+  // Avro map keys are always strings.
+  override def getMapKeyType: SchemaTraverser[Schema] =
+    if (currentNode.getType == Schema.Type.MAP) {
+      copy(
+        Schema.create(Schema.Type.STRING)
+      )
+    } else {
+      throw new UnsupportedOperationException(
+        s"Current node ${currentNode.getName} is a ${currentNode.getType}, not a ${Schema.Type.MAP}"
+      )
+    }
+
+  override def getMapValueType: SchemaTraverser[Schema] =
+    copy(
+      unboxUnion(currentNode.getValueType)
+    )
+}
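The `unboxUnion` helper above relies on a convention worth spelling out: nullable fields are encoded as a two-branch `["null", T]` union, with the concrete type in the second position. A minimal sketch of that convention, using only the standard `org.apache.avro.Schema` API (the field here is invented for illustration):

```scala
import java.util.Arrays
import org.apache.avro.Schema

// Illustrative only: a nullable long, encoded the way unboxUnion expects it.
val nullableLong: Schema =
  Schema.createUnion(Arrays.asList(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.LONG)))

// unboxUnion picks the second branch, i.e. the concrete (non-null) type.
val concrete: Schema =
  if (nullableLong.getType == Schema.Type.UNION) nullableLong.getTypes.get(1)
  else nullableLong

assert(concrete.getType == Schema.Type.LONG)
```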

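Putting it together, callers now build the top-level Avro schema once and pass it into `fromChrononRow`, which threads an `AvroSchemaTraverser` through `Row.to` so nested records reuse sub-schemas instead of recomputing them at every node. A hedged usage sketch follows; the schema and field names are invented, and it assumes the usual `ai.chronon.api` `StructType`/`StructField`/`StringType`/`LongType` definitions:

```scala
import ai.chronon.api.{LongType, StringType, StructField, StructType}
import ai.chronon.online.AvroConversions
import org.apache.avro.generic.GenericData

// Hypothetical Chronon schema, for illustration only.
val chrononSchema = StructType(
  "user_event",
  Array(
    StructField("user", StringType),
    StructField("amount", LongType)
  )
)

// Build the top-level Avro schema once...
val avroSchema = AvroConversions.fromChrononSchema(chrononSchema)

// ...then convert rows against it; the traverser walks this schema
// top-down rather than rebuilding sub-schemas at each struct node.
val record = AvroConversions
  .fromChrononRow(Array[Any]("alice", 42L), chrononSchema, avroSchema)
  .asInstanceOf[GenericData.Record]
```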
online/src/main/scala/ai/chronon/online/Fetcher.scala

Lines changed: 1 addition & 1 deletion
@@ -301,7 +301,7 @@ class Fetcher(val kvStore: KVStore,
           elem
         }
       }
-      val avroRecord = AvroConversions.fromChrononRow(data, schema).asInstanceOf[GenericRecord]
+      val avroRecord = AvroConversions.fromChrononRow(data, schema, codec.schema).asInstanceOf[GenericRecord]
       codec.encodeBinary(avroRecord)
     }

online/src/main/scala/ai/chronon/online/SparkConversions.scala

Lines changed: 2 additions & 2 deletions
@@ -137,10 +137,10 @@ object SparkConversions {
     })
 
   def toSparkRow(value: Any, dataType: api.DataType, extraneousRecord: Any => Array[Any] = null): Any = {
-    api.Row.to[GenericRow, Array[Byte], Array[Any], mutable.Map[Any, Any]](
+    api.Row.to[GenericRow, Array[Byte], Array[Any], mutable.Map[Any, Any], StructType](
       value,
       dataType,
-      { (data: Iterator[Any], _) => new GenericRow(data.toArray) },
+      { (data: Iterator[Any], _, _) => new GenericRow(data.toArray) },
       { bytes: Array[Byte] => bytes },
       { (elems: Iterator[Any], size: Int) =>
         val result = new Array[Any](size)

spark/src/main/scala/ai/chronon/spark/utils/InMemoryStream.scala

Lines changed: 3 additions & 1 deletion
@@ -100,7 +100,9 @@ class InMemoryStream {
     input.addData(inputDf.collect.map { row: Row =>
       val bytes =
         encodeRecord(avroSchema)(
-          AvroConversions.fromChrononRow(row, schema, GenericRowHandler.func).asInstanceOf[GenericData.Record])
+          AvroConversions
+            .fromChrononRow(row, schema, avroSchema, GenericRowHandler.func)
+            .asInstanceOf[GenericData.Record])
       bytes
     })
     input.toDF

0 commit comments
