Commit c8d2b0c (1 parent: 14cc871)

Add support to register UDFs in Flink

4 files changed (+58 / -4 lines)


flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala

Lines changed: 2 additions & 3 deletions

@@ -57,13 +57,12 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri
   private[flink] def getOutputSchema: StructType = {
     // before we do anything, run our setup statements.
     // in order to create the output schema, we'll evaluate expressions
-    // TODO handle UDFs
-    new CatalystUtil(transforms, chrononSchema, filters).getOutputSparkSchema
+    new CatalystUtil(transforms, chrononSchema, filters, groupBy.setups).getOutputSparkSchema
   }

   override def open(configuration: Configuration): Unit = {
     super.open(configuration)
-    catalystUtil = new CatalystUtil(transforms, chrononSchema, filters)
+    catalystUtil = new CatalystUtil(transforms, chrononSchema, filters, groupBy.setups)
     val eventExprEncoder = encoder.asInstanceOf[ExpressionEncoder[T]]
     rowSerializer = eventExprEncoder.createSerializer()
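
Both construction sites above now forward groupBy.setups, so any UDFs referenced by the select expressions are registered before the expressions are parsed: once on the client when deriving the output schema, and once per task in open(). As a hedged sketch (identifiers illustrative, mirroring the test added below in this commit), the two inputs pair up like this:

// Setup statements register the UDFs; the transforms may then reference them.
// Names are illustrative -- real values come from the GroupBy config.
val setups: Seq[String] = Seq(
  "CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'"
)
val transforms: Seq[(String, String)] = Seq(
  "a" -> "MINUS_ONE(int32_x)"
)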

online/src/main/scala/ai/chronon/online/CatalystUtil.scala

Lines changed: 16 additions & 1 deletion

@@ -20,6 +20,7 @@ import ai.chronon.api.{DataType, StructType}
 import ai.chronon.online.CatalystUtil.{IteratorWrapper, PoolKey, poolMap}
 import ai.chronon.online.Extensions.StructTypeOps
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
 import org.apache.spark.sql.execution.{

@@ -55,6 +56,7 @@ object CatalystUtil {
       .config("spark.sql.session.timeZone", "UTC")
       .config("spark.sql.adaptive.enabled", "false")
       .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
+      .enableHiveSupport()
       .getOrCreate()
     assert(spark.sessionState.conf.wholeStageEnabled)
     spark

@@ -126,7 +128,8 @@ class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSch
 // This class by itself it not thread safe because of the transformBuffer
 class CatalystUtil(expressions: collection.Seq[(String, String)],
                    inputSchema: StructType,
-                   filters: collection.Seq[String] = Seq.empty) {
+                   filters: collection.Seq[String] = Seq.empty,
+                   setups: collection.Seq[String] = Seq.empty) {
   private val selectClauses = expressions.map { case (name, expr) => s"$expr as $name" }
   private val sessionTable =
     s"q${math.abs(selectClauses.mkString(", ").hashCode)}_f${math.abs(inputSparkSchema.pretty.hashCode)}"

@@ -165,6 +168,18 @@ class CatalystUtil(expressions: collection.Seq[(String, String)],
   private def initialize(): (InternalRow => Option[InternalRow], types.StructType) = {
     val session = CatalystUtil.session

+    // run through and execute the setup statements
+    setups.foreach { statement =>
+      try {
+        session.sql(statement)
+      } catch {
+        case _: FunctionAlreadyExistsException =>
+          // ignore - this crops up in unit tests on occasion
+        case e: Exception =>
+          throw new RuntimeException(s"Error executing setup statement: $statement", e)
+      }
+    }
+
     // create dummy df with sql query and schema
     val emptyRowRdd = session.emptyDataFrame.rdd
     val inputSparkSchema = SparkConversions.fromChrononSchema(inputSchema)
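
A note on the swallowed FunctionAlreadyExistsException: CatalystUtil.session is shared process-wide, so multiple CatalystUtil instances carrying the same setups replay the same CREATE FUNCTION statements against one session, and every replay after the first would otherwise throw. A minimal sketch of that situation (CommonScalarsStruct is borrowed from the existing test fixtures; assumes initialize() runs once per instance):

val setups  = Seq("CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'")
val selects = Seq("a" -> "MINUS_ONE(int32_x)")

// Both instances execute the same setups against the one shared session; the
// second registration raises FunctionAlreadyExistsException internally, which
// initialize() now ignores instead of failing.
val first  = new CatalystUtil(selects, CommonScalarsStruct, setups = setups)
val second = new CatalystUtil(selects, CommonScalarsStruct, setups = setups)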
online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala (new file; path inferred from package and class name)

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+package ai.chronon.online.test
+
+import ai.chronon.online.CatalystUtil
+import junit.framework.TestCase
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+class CatalystUtilHiveUDFTest extends TestCase with CatalystUtilTestSparkSQLStructs {
+
+  @Test
+  def testHiveUDFsViaSetupsShouldWork(): Unit = {
+    val setups = Seq(
+      "CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'",
+      "CREATE FUNCTION CAT_STR AS 'ai.chronon.online.test.Cat_Str'",
+    )
+    val selects = Seq(
+      "a" -> "MINUS_ONE(int32_x)",
+      "b" -> "CAT_STR(string_x)"
+    )
+    val cu = new CatalystUtil(expressions = selects, inputSchema = CommonScalarsStruct, setups = setups)
+    val res = cu.sqlTransform(CommonScalarsRow)
+    assertEquals(res.get.size, 2)
+    assertEquals(res.get("a"), Int.MaxValue - 1)
+    assertEquals(res.get("b"), "hello123")
+  }
+}
online/src/test/scala/ai/chronon/online/test/… (new file defining the toy UDFs; filename not shown in this view)

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+package ai.chronon.online.test
+
+// A couple of toy UDFs to help test Hive UDF registration in CatalystUtil
+class Minus_One extends org.apache.hadoop.hive.ql.exec.UDF {
+  def evaluate(x: Integer): Integer = {
+    x - 1
+  }
+}
+
+class Cat_Str extends org.apache.hadoop.hive.ql.exec.UDF {
+  def evaluate(x: String): String = {
+    x + "123"
+  }
+}
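
As a sanity check outside CatalystUtil, the same registration can be exercised against any Hive-enabled SparkSession. The session construction below is illustrative (local mode, Hive classes assumed on the classpath), not how CatalystUtil builds its session:

import org.apache.spark.sql.SparkSession

object UDFSanityCheck {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; requires Hive support on the classpath.
    val spark = SparkSession.builder()
      .appName("udf-sanity-check")
      .master("local[1]")
      .enableHiveSupport()
      .getOrCreate()

    // Same statement shape the setups use: CREATE FUNCTION <name> AS '<udf class>'.
    spark.sql("CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'")
    spark.sql("SELECT MINUS_ONE(10) AS a").show() // expect a = 9
  }
}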
