Commit 3acc1ef

Author: Tristan Nixon
restoring dlt asofjoin fix from #334

1 parent: d62697a

1 file changed (+53, -53 lines)


python/tempo/tsdf.py (+53, -53)
@@ -750,63 +750,63 @@ def asofJoin(
         left_df = self.df
         right_df = right_tsdf.df

-        spark = SparkSession.builder.getOrCreate()
-        left_bytes = self.__getBytesFromPlan(left_df, spark)
-        right_bytes = self.__getBytesFromPlan(right_df, spark)
-
-        # choose 30MB as the cutoff for the broadcast
-        bytes_threshold = 30 * 1024 * 1024
-        if sql_join_opt & (
-            (left_bytes < bytes_threshold) | (right_bytes < bytes_threshold)
-        ):
-            spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
-            partition_cols = right_tsdf.partitionCols
-            left_cols = list(set(left_df.columns).difference(set(self.partitionCols)))
-            right_cols = list(
-                set(right_df.columns).difference(set(right_tsdf.partitionCols))
-            )
-
-            left_prefix = (
-                ""
-                if not left_prefix  # use truthiness of None and ""
-                else left_prefix + "_"
-            )
-            right_prefix = (
-                ""
-                if not right_prefix  # use truthiness of None and ""
-                else right_prefix + "_"
-            )
+        # test if the broadcast join will be efficient
+        if sql_join_opt:
+            spark = SparkSession.builder.getOrCreate()
+            left_bytes = self.__getBytesFromPlan(left_df, spark)
+            right_bytes = self.__getBytesFromPlan(right_df, spark)
+
+            # choose 30MB as the cutoff for the broadcast
+            bytes_threshold = 30 * 1024 * 1024
+            if (left_bytes < bytes_threshold) or (right_bytes < bytes_threshold):
+                spark.conf.set("spark.databricks.optimizer.rangeJoin.binSize", 60)
+                partition_cols = right_tsdf.partitionCols
+                left_cols = list(set(left_df.columns) - set(self.partitionCols))
+                right_cols = list(set(right_df.columns) - set(right_tsdf.partitionCols))
+
+                left_prefix = (
+                    ""
+                    if ((left_prefix is None) | (left_prefix == ""))
+                    else left_prefix + "_"
+                )
+                right_prefix = (
+                    ""
+                    if ((right_prefix is None) | (right_prefix == ""))
+                    else right_prefix + "_"
+                )

-            w = Window.partitionBy(*partition_cols).orderBy(
-                right_prefix + right_tsdf.ts_col
-            )
+                w = Window.partitionBy(*partition_cols).orderBy(
+                    right_prefix + right_tsdf.ts_col
+                )

-            new_left_ts_col = left_prefix + self.ts_col
-            new_left_cols = [
-                sfn.col(c).alias(left_prefix + c) for c in left_cols
-            ] + partition_cols
-            new_right_cols = [
-                sfn.col(c).alias(right_prefix + c) for c in right_cols
-            ] + partition_cols
-            quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
-                "lead_" + right_tsdf.ts_col,
-                sfn.lead(right_prefix + right_tsdf.ts_col).over(w),
-            )
-            left_df = left_df.select(*new_left_cols)
-            res = (
-                left_df.join(quotes_df_w_lag, partition_cols)
-                .where(
-                    left_df[new_left_ts_col].between(
-                        sfn.col(right_prefix + right_tsdf.ts_col),
-                        sfn.coalesce(
-                            sfn.col("lead_" + right_tsdf.ts_col),
-                            sfn.lit("2099-01-01").cast("timestamp"),
-                        ),
+                new_left_ts_col = left_prefix + self.ts_col
+                new_left_cols = [
+                    sfn.col(c).alias(left_prefix + c) for c in left_cols
+                ] + partition_cols
+                new_right_cols = [
+                    sfn.col(c).alias(right_prefix + c) for c in right_cols
+                ] + partition_cols
+                quotes_df_w_lag = right_df.select(*new_right_cols).withColumn(
+                    "lead_" + right_tsdf.ts_col,
+                    sfn.lead(right_prefix + right_tsdf.ts_col).over(w),
+                )
+                left_df = left_df.select(*new_left_cols)
+                res = (
+                    left_df.join(quotes_df_w_lag, partition_cols)
+                    .where(
+                        left_df[new_left_ts_col].between(
+                            sfn.col(right_prefix + right_tsdf.ts_col),
+                            sfn.coalesce(
+                                sfn.col("lead_" + right_tsdf.ts_col),
+                                sfn.lit("2099-01-01").cast("timestamp"),
+                            ),
+                        )
                     )
-                )
+                    .drop("lead_" + right_tsdf.ts_col)
+                )
+                return TSDF(
+                    res, partition_cols=self.partitionCols, ts_col=new_left_ts_col
                 )
-                .drop("lead_" + right_tsdf.ts_col)
-            )
-            return TSDF(res, partition_cols=self.partitionCols, ts_col=new_left_ts_col)

         # end of block checking to see if standard Spark SQL join will work

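Review note: the substance of the restored fix is evaluation order, not the join itself. The replaced code called self.__getBytesFromPlan on both DataFrames before testing sql_join_opt, and combined the checks with the bitwise operators & and |, which never short-circuit; plan-size estimation therefore ran even when the optimization was disabled, which is what the DLT fix from #334 addressed. The restored version guards the whole check behind "if sql_join_opt:" and uses a short-circuiting "or". Below is a minimal, self-contained sketch of that control flow; choose_join_strategy and estimate_bytes are hypothetical names, with estimate_bytes standing in for the private TSDF.__getBytesFromPlan helper.

# Sketch only: illustrates the guard restored by this commit, not tempo's
# actual API surface.

BYTES_THRESHOLD = 30 * 1024 * 1024  # the 30MB broadcast cutoff from the diff


def choose_join_strategy(left_df, right_df, sql_join_opt, estimate_bytes):
    """Return "broadcast" when the range-join fast path applies, else "asof"."""
    if sql_join_opt:
        # Plan-size estimation runs only on the opt-in path; the replaced
        # code evaluated both estimates unconditionally, because
        # a & (b | c) needs all of its operands up front.
        left_bytes = estimate_bytes(left_df)
        right_bytes = estimate_bytes(right_df)
        if (left_bytes < BYTES_THRESHOLD) or (right_bytes < BYTES_THRESHOLD):
            return "broadcast"
    return "asof"


# With the flag off, the estimator is never invoked; this call would raise
# ZeroDivisionError under the old, unguarded ordering.
assert choose_join_strategy(None, None, False, lambda df: 1 / 0) == "asof"

The prefix normalization, by contrast, is behaviorally unchanged: for str-or-None inputs, "not left_prefix" and "(left_prefix is None) | (left_prefix == '')" pick the same branch; the restored spelling just makes the two cases explicit.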
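For context, a hedged usage sketch of the code path this commit touches, assuming a running SparkSession and the tempo package installed; the input data and column names are illustrative, while left_prefix, right_prefix, and sql_join_opt are the parameters visible in the diff above.

from pyspark.sql import SparkSession
from pyspark.sql import functions as sfn

from tempo import TSDF

spark = SparkSession.builder.getOrCreate()

# Tiny illustrative inputs; a real workload would read tables instead.
trades_df = spark.createDataFrame(
    [("AAPL", "2024-01-01 09:30:05", 189.10)],
    ["symbol", "event_ts", "trade_pr"],
).withColumn("event_ts", sfn.to_timestamp("event_ts"))
quotes_df = spark.createDataFrame(
    [("AAPL", "2024-01-01 09:30:00", 189.05, 189.15)],
    ["symbol", "event_ts", "bid_pr", "ask_pr"],
).withColumn("event_ts", sfn.to_timestamp("event_ts"))

trades = TSDF(trades_df, ts_col="event_ts", partition_cols=["symbol"])
quotes = TSDF(quotes_df, ts_col="event_ts", partition_cols=["symbol"])

# sql_join_opt=True opts in to the broadcast/range-join fast path that the
# restored "if sql_join_opt:" block guards; when either side's plan estimate
# falls under the 30MB threshold, that path is taken.
joined = trades.asofJoin(
    quotes,
    left_prefix="left",
    right_prefix="right",
    sql_join_opt=True,
)
joined.df.show(truncate=False)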