
Commit 776218e

tnixon and R7L208 authored
Standardize imports (#340)

* standardizing package imports
* black reformatting
* simplify the TSDF imports
* reorganize imports
* Revert "simplify the TSDF imports" (this reverts commit 0cefd1569f110c4e7f27db23bfa33db3a1bc730e)
* refactoring sql_fn to sfn based on popular demand
* Describe module import standards
* black formatting
* restoring dlt asofjoin fix from #334
* Update python/tests/intervals_tests.py ("hmmm - guess I missed this one :D") Co-authored-by: Lorin Dawson <[email protected]>
* Update python/tests/intervals_tests.py ("good catch") Co-authored-by: Lorin Dawson <[email protected]>
* Update python/tests/intervals_tests.py Co-authored-by: Lorin Dawson <[email protected]>
* Update python/tests/intervals_tests.py Co-authored-by: Lorin Dawson <[email protected]>

Co-authored-by: Lorin Dawson <[email protected]>
1 parent 7f93b9a commit 776218e

14 files changed (+385, -326 lines)

CONTRIBUTING.md

+30 lines
@@ -71,3 +71,33 @@ These environments are also defined in the `tox.ini` file and skip installing dep
 * lint
 * type-check
 * coverage-report
+
+# Code style & Standards
+
+The tempo project abides by [`black`](https://black.readthedocs.io/en/stable/index.html) formatting standards,
+as well as using [`flake8`](https://flake8.pycqa.org/en/latest/) and [`mypy`](https://mypy.readthedocs.io/en/stable/)
+to check for effective code style, type-checking and common bad practices.
+To test your code against these standards, run the following command:
+```bash
+tox -e lint, type-check
+```
+To have `black` automatically format your code, run the following command:
+```bash
+tox -e format
+```
+
+In addition, we apply some project-specific standards:
+
+## Module imports
+
+We organize import statements at the top of each module in the following order, each section being separated by a blank line:
+1. Standard Python library imports
+2. Third-party library imports
+3. PySpark library imports
+4. Tempo library imports
+
+Within each section, imports are sorted alphabetically. While it is acceptable to directly import classes and some functions that are
+going to be commonly used, for the sake of readability, it is generally preferred to import a package with an alias and then use the alias
+to reference the package's classes and functions.
+
+When importing `pyspark.sql.functions`, we use the convention to alias this package as `sfn`, which is both distinctive and short.
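To make the new convention concrete, here is a minimal, hypothetical module header and helper laid out in that order. The function, its names, and the third-party placeholder are illustrative only and are not part of the tempo codebase; the `TSDF(df, ts_col=..., partition_cols=...)` call assumes the keyword arguments that constructor currently exposes.

```python
# 1. Standard Python library imports
import logging
from typing import List

# 2. Third-party library imports
# (none needed in this sketch; e.g. `import pandas as pd` would go in this section)

# 3. PySpark library imports
import pyspark.sql.functions as sfn
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.window import Window

# 4. Tempo library imports
import tempo.tsdf as t_tsdf

logger = logging.getLogger(__name__)


def latest_per_series(
    df: DataFrame, ts_col: str, series_ids: List[str]
) -> t_tsdf.TSDF:
    """Keep only the most recent row of each series and wrap the result as a TSDF."""
    # every pyspark.sql.functions call goes through the sfn alias
    win = Window.partitionBy(*series_ids).orderBy(sfn.col(ts_col).desc())
    latest = (
        df.withColumn("_rn", sfn.row_number().over(win))
        .filter(sfn.col("_rn") == 1)
        .drop("_rn")
    )
    logger.debug("reduced to one row per combination of %s", series_ids)
    return t_tsdf.TSDF(latest, ts_col=ts_col, partition_cols=series_ids)
```

Note how a reader can tell at a glance that `sfn.row_number` and `sfn.col` come from PySpark rather than from a local helper, which is the readability argument this section makes for aliased imports.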

python/tempo/interpol.py

+46, -41 lines
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
-from typing import List, Optional, Union, Callable
+from typing import Callable, List, Optional, Union
 
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.functions import col, expr, last, lead, lit, when
+import pyspark.sql.functions as sfn
 from pyspark.sql.window import Window
 
-import tempo.utils as t_utils
 import tempo.resample as t_resample
 import tempo.tsdf as t_tsdf
+import tempo.utils as t_utils
 
 # Interpolation fill options
 method_options = ["zero", "null", "bfill", "ffill", "linear"]
@@ -130,56 +130,56 @@ def __interpolate_column(
         END AS is_interpolated_{target_col}
         """
         output_df = output_df.withColumn(
-            f"is_interpolated_{target_col}", expr(flag_expr)
+            f"is_interpolated_{target_col}", sfn.expr(flag_expr)
         )
 
         # Handle zero fill
         if method == "zero":
             output_df = output_df.withColumn(
                 target_col,
-                when(
-                    col(f"is_interpolated_{target_col}") == False,  # noqa: E712
-                    col(target_col),
-                ).otherwise(lit(0)),
+                sfn.when(
+                    sfn.col(f"is_interpolated_{target_col}") == False,  # noqa: E712
+                    sfn.col(target_col),
+                ).otherwise(sfn.lit(0)),
             )
 
         # Handle null fill
         if method == "null":
             output_df = output_df.withColumn(
                 target_col,
-                when(
-                    col(f"is_interpolated_{target_col}") == False,  # noqa: E712
-                    col(target_col),
+                sfn.when(
+                    sfn.col(f"is_interpolated_{target_col}") == False,  # noqa: E712
+                    sfn.col(target_col),
                 ).otherwise(None),
             )
 
         # Handle forward fill
         if method == "ffill":
             output_df = output_df.withColumn(
                 target_col,
-                when(
-                    col(f"is_interpolated_{target_col}") == True,  # noqa: E712
-                    col(f"previous_{target_col}"),
-                ).otherwise(col(target_col)),
+                sfn.when(
+                    sfn.col(f"is_interpolated_{target_col}") == True,  # noqa: E712
+                    sfn.col(f"previous_{target_col}"),
+                ).otherwise(sfn.col(target_col)),
             )
         # Handle backwards fill
         if method == "bfill":
             output_df = output_df.withColumn(
                 target_col,
                 # Handle case when subsequent value is null
-                when(
-                    (col(f"is_interpolated_{target_col}") == True)  # noqa: E712
+                sfn.when(
+                    (sfn.col(f"is_interpolated_{target_col}") == True)  # noqa: E712
                     & (
-                        col(f"next_{target_col}").isNull()
-                        & (col(f"{ts_col}_{target_col}").isNull())
+                        sfn.col(f"next_{target_col}").isNull()
+                        & (sfn.col(f"{ts_col}_{target_col}").isNull())
                     ),
-                    col(f"next_null_{target_col}"),
+                    sfn.col(f"next_null_{target_col}"),
                 ).otherwise(
                     # Handle standard backwards fill
-                    when(
-                        col(f"is_interpolated_{target_col}") == True,  # noqa: E712
-                        col(f"next_{target_col}"),
-                    ).otherwise(col(f"{target_col}"))
+                    sfn.when(
+                        sfn.col(f"is_interpolated_{target_col}") == True,  # noqa: E712
+                        sfn.col(f"next_{target_col}"),
+                    ).otherwise(sfn.col(f"{target_col}"))
                 ),
             )
 
@@ -205,10 +205,12 @@ def __generate_time_series_fill(
         """
         return df.withColumn(
             "previous_timestamp",
-            col(ts_col),
+            sfn.col(ts_col),
        ).withColumn(
             "next_timestamp",
-            lead(df[ts_col]).over(Window.partitionBy(*partition_cols).orderBy(ts_col)),
+            sfn.lead(df[ts_col]).over(
+                Window.partitionBy(*partition_cols).orderBy(ts_col)
+            ),
         )
 
     def __generate_column_time_fill(
@@ -232,13 +234,13 @@ def __generate_column_time_fill(
 
         return df.withColumn(
             f"previous_timestamp_{target_col}",
-            last(col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
+            sfn.last(sfn.col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
                 window.orderBy(ts_col).rowsBetween(Window.unboundedPreceding, 0)
             ),
         ).withColumn(
             f"next_timestamp_{target_col}",
-            last(col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
-                window.orderBy(col(ts_col).desc()).rowsBetween(
+            sfn.last(sfn.col(f"{ts_col}_{target_col}"), ignorenulls=True).over(
+                window.orderBy(sfn.col(ts_col).desc()).rowsBetween(
                     Window.unboundedPreceding, 0
                 )
             ),
@@ -266,21 +268,21 @@ def __generate_target_fill(
         return (
             df.withColumn(
                 f"previous_{target_col}",
-                last(df[target_col], ignorenulls=True).over(
+                sfn.last(df[target_col], ignorenulls=True).over(
                     window.orderBy(ts_col).rowsBetween(Window.unboundedPreceding, 0)
                 ),
             )
             # Handle if subsequent value is null
             .withColumn(
                 f"next_null_{target_col}",
-                last(df[target_col], ignorenulls=True).over(
-                    window.orderBy(col(ts_col).desc()).rowsBetween(
+                sfn.last(df[target_col], ignorenulls=True).over(
+                    window.orderBy(sfn.col(ts_col).desc()).rowsBetween(
                         Window.unboundedPreceding, 0
                     )
                 ),
             ).withColumn(
                 f"next_{target_col}",
-                lead(df[target_col]).over(window.orderBy(ts_col)),
+                sfn.lead(df[target_col]).over(window.orderBy(ts_col)),
             )
         )
 
@@ -356,7 +358,7 @@ def interpolate(
         for column in target_cols:
             add_column_time = add_column_time.withColumn(
                 f"{ts_col}_{column}",
-                when(col(column).isNull(), None).otherwise(col(ts_col)),
+                sfn.when(sfn.col(column).isNull(), None).otherwise(sfn.col(ts_col)),
             )
             add_column_time = self.__generate_column_time_fill(
                 add_column_time, partition_cols, ts_col, column
@@ -365,9 +367,10 @@ def interpolate(
         # Handle edge case if last value (latest) is null
         edge_filled = add_column_time.withColumn(
             "next_timestamp",
-            when(
-                col("next_timestamp").isNull(), expr(f"{ts_col}+ interval {freq}")
-            ).otherwise(col("next_timestamp")),
+            sfn.when(
+                sfn.col("next_timestamp").isNull(),
+                sfn.expr(f"{ts_col}+ interval {freq}"),
+            ).otherwise(sfn.col("next_timestamp")),
         )
 
         # Fill target column for nearest values
@@ -380,7 +383,7 @@ def interpolate(
         # Generate missing timeseries values
         exploded_series = target_column_filled.withColumn(
             f"new_{ts_col}",
-            expr(
+            sfn.expr(
                 f"explode(sequence({ts_col}, next_timestamp - interval {freq}, interval {freq} )) as timestamp"
             ),
         )
@@ -390,10 +393,12 @@ def interpolate(
         flagged_series = (
             exploded_series.withColumn(
                 "is_ts_interpolated",
-                when(col(f"new_{ts_col}") != col(ts_col), True).otherwise(False),
+                sfn.when(sfn.col(f"new_{ts_col}") != sfn.col(ts_col), True).otherwise(
+                    False
+                ),
             )
-            .withColumn(ts_col, col(f"new_{ts_col}"))
-            .drop(col(f"new_{ts_col}"))
+            .withColumn(ts_col, sfn.col(f"new_{ts_col}"))
+            .drop(sfn.col(f"new_{ts_col}"))
         )
 
         # # Perform interpolation on each target column
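The changes above are mechanical: every bare `col`, `expr`, `last`, `lead`, `lit`, and `when` call becomes an attribute lookup on the single `sfn` alias. A small before-and-after sketch of the same refactor, on illustrative code that is not taken from tempo:

```python
import pyspark.sql.functions as sfn
from pyspark.sql.dataframe import DataFrame

# Before (old style, shown as a comment):
#   from pyspark.sql.functions import col, lit, when
#   df.withColumn("flag", when(col("value") > lit(0), "pos").otherwise("neg"))


def label_sign(df: DataFrame) -> DataFrame:
    """Add a 'flag' column marking whether 'value' is positive (new style)."""
    return df.withColumn(
        "flag",
        sfn.when(sfn.col("value") > sfn.lit(0), sfn.lit("pos")).otherwise(
            sfn.lit("neg")
        ),
    )
```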

python/tempo/intervals.py

+26, -26 lines
@@ -1,13 +1,12 @@
 from __future__ import annotations
 
-from typing import Optional
 from functools import cached_property
+from typing import Optional
 
-import pyspark.sql
+import pyspark.sql.functions as sfn
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import NumericType, BooleanType, StructField
-import pyspark.sql.functions as f
-from pyspark.sql.window import Window
+from pyspark.sql.types import BooleanType, NumericType, StructField
+from pyspark.sql.window import Window, WindowSpec
 
 
 def is_metric_col(col: StructField) -> bool:
@@ -105,7 +104,7 @@ def metric_columns(self) -> list[str]:
         return [col.name for col in self.df.schema.fields if is_metric_col(col)]
 
     @cached_property
-    def window(self) -> pyspark.sql.window:
+    def window(self) -> WindowSpec:
         return Window.partitionBy(*self.series_ids).orderBy(*self.interval_boundaries)
 
     @classmethod
@@ -210,10 +209,10 @@ def __get_adjacent_rows(self, df: DataFrame) -> DataFrame:
         for c in self.interval_boundaries + self.metric_columns:
             df = df.withColumn(
                 f"_lead_1_{c}",
-                f.lead(c, 1).over(self.window),
+                sfn.lead(c, 1).over(self.window),
             ).withColumn(
                 f"_lag_1_{c}",
-                f.lag(c, 1).over(self.window),
+                sfn.lag(c, 1).over(self.window),
             )
 
         return df
@@ -236,8 +235,8 @@ def __identify_subset_intervals(self, df: DataFrame) -> tuple[DataFrame, str]:
 
         df = df.withColumn(
             subset_indicator,
-            (f.col(f"_lag_1_{self.start_ts}") <= f.col(self.start_ts))
-            & (f.col(f"_lag_1_{self.end_ts}") >= f.col(self.end_ts)),
+            (sfn.col(f"_lag_1_{self.start_ts}") <= sfn.col(self.start_ts))
+            & (sfn.col(f"_lag_1_{self.end_ts}") >= sfn.col(self.end_ts)),
         )
 
         # NB: the first record cannot be a subset of the previous and
@@ -271,12 +270,12 @@ def __identify_overlaps(self, df: DataFrame) -> tuple[DataFrame, list[str]]:
         for ts in self.interval_boundaries:
             df = df.withColumn(
                 f"_lead_1_{ts}_overlaps",
-                (f.col(f"_lead_1_{ts}") > f.col(self.start_ts))
-                & (f.col(f"_lead_1_{ts}") < f.col(self.end_ts)),
+                (sfn.col(f"_lead_1_{ts}") > sfn.col(self.start_ts))
+                & (sfn.col(f"_lead_1_{ts}") < sfn.col(self.end_ts)),
             ).withColumn(
                 f"_lag_1_{ts}_overlaps",
-                (f.col(f"_lag_1_{ts}") > f.col(self.start_ts))
-                & (f.col(f"_lag_1_{ts}") < f.col(self.end_ts)),
+                (sfn.col(f"_lag_1_{ts}") > sfn.col(self.start_ts))
+                & (sfn.col(f"_lag_1_{ts}") < sfn.col(self.end_ts)),
             )
 
             overlap_indicators.extend(
@@ -321,9 +320,10 @@ def __merge_adjacent_subset_and_superset(
         for c in self.metric_columns:
             df = df.withColumn(
                 c,
-                f.when(
-                    f.col(subset_indicator), f.coalesce(f.col(c), f"_lag_1_{c}")
-                ).otherwise(f.col(c)),
+                sfn.when(
+                    sfn.col(subset_indicator),
+                    sfn.coalesce(sfn.col(c), f"_lag_1_{c}"),
+                ).otherwise(sfn.col(c)),
             )
 
         return df
@@ -385,7 +385,7 @@ def __merge_adjacent_overlaps(
 
         df = df.withColumn(
             new_boundary_col,
-            f.expr(new_interval_boundaries),
+            sfn.expr(new_interval_boundaries),
         )
 
         if how == "left":
@@ -394,13 +394,13 @@ def __merge_adjacent_overlaps(
                     c,
                     # needed when intervals have same start but different ends
                     # in this case, merge metrics since they overlap
-                    f.when(
-                        f.col(f"_lag_1_{self.end_ts}_overlaps"),
-                        f.coalesce(f.col(c), f.col(f"_lag_1_{c}")),
+                    sfn.when(
+                        sfn.col(f"_lag_1_{self.end_ts}_overlaps"),
+                        sfn.coalesce(sfn.col(c), sfn.col(f"_lag_1_{c}")),
                     )
                     # general case when constructing left disjoint interval
                     # just want new boundary without merging metrics
-                    .otherwise(f.col(c)),
+                    .otherwise(sfn.col(c)),
                 )
 
         return df
@@ -423,7 +423,7 @@ def __merge_equal_intervals(self, df: DataFrame) -> DataFrame:
 
         """
 
-        merge_expr = tuple(f.max(c).alias(c) for c in self.metric_columns)
+        merge_expr = tuple(sfn.max(c).alias(c) for c in self.metric_columns)
 
         return df.groupBy(*self.interval_boundaries, *self.series_ids).agg(*merge_expr)
 
@@ -469,7 +469,7 @@ def disjoint(self) -> "IntervalsDF":
 
         (df, subset_indicator) = self.__identify_subset_intervals(df)
 
-        subset_df = df.filter(f.col(subset_indicator))
+        subset_df = df.filter(sfn.col(subset_indicator))
 
         subset_df = self.__merge_adjacent_subset_and_superset(
             subset_df, subset_indicator
@@ -479,7 +479,7 @@ def disjoint(self) -> "IntervalsDF":
             *self.interval_boundaries, *self.series_ids, *self.metric_columns
         )
 
-        non_subset_df = df.filter(~f.col(subset_indicator))
+        non_subset_df = df.filter(~sfn.col(subset_indicator))
 
         (non_subset_df, overlap_indicators) = self.__identify_overlaps(non_subset_df)
 
@@ -611,7 +611,7 @@ def toDF(self, stack: bool = False) -> DataFrame:
             )
 
             return self.df.select(
-                *self.interval_boundaries, *self.series_ids, f.expr(stack_expr)
+                *self.interval_boundaries, *self.series_ids, sfn.expr(stack_expr)
             ).dropna(subset="metric_value")
 
         else:
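Beyond the rename of the alias from `f` to `sfn`, note the type-hint change at the top of this file: the `window` property is now annotated as returning `WindowSpec`, which is what `Window.partitionBy(...).orderBy(...)` actually yields, rather than the `pyspark.sql.window` module object. A small, self-contained sketch of the same pattern, using an illustrative helper name that is not part of tempo:

```python
from typing import List

import pyspark.sql.functions as sfn
from pyspark.sql.window import Window, WindowSpec


def series_window(series_ids: List[str], ts_col: str) -> WindowSpec:
    """Window spec partitioned by the series keys and ordered by the time column."""
    # Window.partitionBy(...).orderBy(...) returns a WindowSpec, so annotating the
    # return type this way lets mypy check call sites such as
    # sfn.lead("metric").over(series_window(["id"], "event_ts")).
    return Window.partitionBy(*series_ids).orderBy(sfn.col(ts_col))
```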
