diff --git a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
index a39421d32..e626db137 100644
--- a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
+++ b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala
@@ -57,13 +57,12 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri
   private[flink] def getOutputSchema: StructType = {
     // before we do anything, run our setup statements.
     // in order to create the output schema, we'll evaluate expressions
-    // TODO handle UDFs
-    new CatalystUtil(transforms, chrononSchema, filters).getOutputSparkSchema
+    new CatalystUtil(transforms, chrononSchema, filters, groupBy.setups).getOutputSparkSchema
   }
 
   override def open(configuration: Configuration): Unit = {
     super.open(configuration)
-    catalystUtil = new CatalystUtil(transforms, chrononSchema, filters)
+    catalystUtil = new CatalystUtil(transforms, chrononSchema, filters, groupBy.setups)
 
     val eventExprEncoder = encoder.asInstanceOf[ExpressionEncoder[T]]
     rowSerializer = eventExprEncoder.createSerializer()
diff --git a/online/src/main/scala/ai/chronon/online/CatalystUtil.scala b/online/src/main/scala/ai/chronon/online/CatalystUtil.scala
index 6e3c52442..c16918abf 100644
--- a/online/src/main/scala/ai/chronon/online/CatalystUtil.scala
+++ b/online/src/main/scala/ai/chronon/online/CatalystUtil.scala
@@ -20,6 +20,7 @@ import ai.chronon.api.{DataType, StructType}
 import ai.chronon.online.CatalystUtil.{IteratorWrapper, PoolKey, poolMap}
 import ai.chronon.online.Extensions.StructTypeOps
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.FunctionAlreadyExistsException
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
 import org.apache.spark.sql.execution.{
@@ -55,6 +56,7 @@ object CatalystUtil {
       .config("spark.sql.session.timeZone", "UTC")
       .config("spark.sql.adaptive.enabled", "false")
       .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
+      .enableHiveSupport()
       .getOrCreate()
     assert(spark.sessionState.conf.wholeStageEnabled)
     spark
@@ -126,7 +128,8 @@ class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSch
 // This class by itself it not thread safe because of the transformBuffer
 class CatalystUtil(expressions: collection.Seq[(String, String)],
                    inputSchema: StructType,
-                   filters: collection.Seq[String] = Seq.empty) {
+                   filters: collection.Seq[String] = Seq.empty,
+                   setups: collection.Seq[String] = Seq.empty) {
   private val selectClauses = expressions.map { case (name, expr) => s"$expr as $name" }
   private val sessionTable =
     s"q${math.abs(selectClauses.mkString(", ").hashCode)}_f${math.abs(inputSparkSchema.pretty.hashCode)}"
@@ -165,6 +168,18 @@ class CatalystUtil(expressions: collection.Seq[(String, String)],
   private def initialize(): (InternalRow => Option[InternalRow], types.StructType) = {
     val session = CatalystUtil.session
 
+    // run through and execute the setup statements
+    setups.foreach { statement =>
+      try {
+        session.sql(statement)
+      } catch {
+        case _: FunctionAlreadyExistsException =>
+        // ignore - this crops up in unit tests on occasion
+        case e: Exception =>
+          throw new RuntimeException(s"Error executing setup statement: $statement", e)
+      }
+    }
+
     // create dummy df with sql query and schema
     val emptyRowRdd = session.emptyDataFrame.rdd
     val inputSparkSchema = SparkConversions.fromChrononSchema(inputSchema)
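For reviewers, a minimal usage sketch of the new `setups` parameter. The schema helpers (`StructType`/`StructField`/`IntType` from `ai.chronon.api`) are assumed from the existing API, and the column and function names are hypothetical; only the `CatalystUtil` constructor shape comes from this diff.

```scala
import ai.chronon.api.{IntType, StructField, StructType}
import ai.chronon.online.CatalystUtil

// Hypothetical input schema with a single int column named "views".
val inputSchema = StructType("input", Array(StructField("views", IntType)))

// Setup statements execute inside initialize() before the select/filter
// expressions are planned, so MINUS_ONE is already registered by the time
// Catalyst compiles the transform.
val cu = new CatalystUtil(
  expressions = Seq("views_minus_one" -> "MINUS_ONE(views)"),
  inputSchema = inputSchema,
  setups = Seq("CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'")
)
```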
diff --git a/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala b/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala
new file mode 100644
index 000000000..7def8407c
--- /dev/null
+++ b/online/src/test/scala/ai/chronon/online/test/CatalystUtilHiveUDFTest.scala
@@ -0,0 +1,26 @@
+package ai.chronon.online.test
+
+import ai.chronon.online.CatalystUtil
+import junit.framework.TestCase
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+class CatalystUtilHiveUDFTest extends TestCase with CatalystUtilTestSparkSQLStructs {
+
+  @Test
+  def testHiveUDFsViaSetupsShouldWork(): Unit = {
+    val setups = Seq(
+      "CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'",
+      "CREATE FUNCTION CAT_STR AS 'ai.chronon.online.test.Cat_Str'",
+    )
+    val selects = Seq(
+      "a" -> "MINUS_ONE(int32_x)",
+      "b" -> "CAT_STR(string_x)"
+    )
+    val cu = new CatalystUtil(expressions = selects, inputSchema = CommonScalarsStruct, setups = setups)
+    val res = cu.sqlTransform(CommonScalarsRow)
+    assertEquals(res.get.size, 2)
+    assertEquals(res.get("a"), Int.MaxValue - 1)
+    assertEquals(res.get("b"), "hello123")
+  }
+}
diff --git a/online/src/test/scala/ai/chronon/online/test/ExampleUDFs.scala b/online/src/test/scala/ai/chronon/online/test/ExampleUDFs.scala
new file mode 100644
index 000000000..608386c2c
--- /dev/null
+++ b/online/src/test/scala/ai/chronon/online/test/ExampleUDFs.scala
@@ -0,0 +1,14 @@
+package ai.chronon.online.test
+
+// A couple of toy UDFs to help test Hive UDF registration in CatalystUtil
+class Minus_One extends org.apache.hadoop.hive.ql.exec.UDF {
+  def evaluate(x: Integer): Integer = {
+    x - 1
+  }
+}
+
+class Cat_Str extends org.apache.hadoop.hive.ql.exec.UDF {
+  def evaluate(x: String): String = {
+    x + "123"
+  }
+}
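A note on the swallowed `FunctionAlreadyExistsException`: every `CatalystUtil` instance shares the single Hive-enabled session built in the companion object, so a second instance registering the same function name trips the exception on its `CREATE FUNCTION`, which is the unit-test scenario the catch comment refers to. A sketch of that interaction, reusing the test fixtures above and assuming `initialize()` runs during construction:

```scala
// Both instances run the same CREATE FUNCTION against the shared
// CatalystUtil.session; the second registration raises
// FunctionAlreadyExistsException, which initialize() deliberately ignores
// instead of failing the whole setup pass.
val setups = Seq("CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'")
val first = new CatalystUtil(Seq("a" -> "MINUS_ONE(int32_x)"), CommonScalarsStruct, setups = setups)
val second = new CatalystUtil(Seq("a" -> "MINUS_ONE(int32_x)"), CommonScalarsStruct, setups = setups)
```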