
Commit 5ba6ce5

R7L208 Tristan Nixon and Tristan Nixon authored
TSDF: extractStateIntervals method to identify an "event_time" range where state is constant (#231)
* update gitignore
* working changes
* refactor tests & split out
* black & working changes
* comments
* working constantMetricRanges; need to complete test cases
* remove extra select to create event_ts from constantMetricRanges
* implementation and test cases for constantMetricState for PySpark comparison operator
* update constantStateRanges and new test case for state defined by column expressions
* black formatting
* update import statement for Column
* reformatting and added warning message for non-tested state definitions
* move call to logger to correct spot
* fix logger warning output
* update tests and data to new format
* black formatting
* exclude cache and virtual env directories from flake8 check
* flake8 formatting
* revert formatting in examples/
* refactor ConstantMetricState to exclude metric intervals from output
* rename method constantMetricState to extractStateIntervals
* fix method name in warning message

Co-authored-by: Tristan Nixon <[email protected]>
1 parent 891e4bd commit 5ba6ce5
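
For orientation, here is a minimal usage sketch of the method this commit adds. The sensor DataFrame, the device_id partition column, and the TSDF constructor keyword names below are illustrative assumptions, not part of this commit:

import pyspark.sql.functions as F
from tempo.tsdf import TSDF

# wrap an existing Spark DataFrame of readings in a TSDF
# (constructor kwargs assumed to follow tempo's README)
input_tsdf = TSDF(sensor_df, ts_col="event_ts", partition_cols=["device_id"])

# default "=" state definition: intervals where all three metrics hold constant values
constant_intervals_df = input_tsdf.extractStateIntervals(
    "metric_1", "metric_2", "metric_3"
).df

# alternatively, a boolean Column expression can define the "constant state" condition
approx_intervals_df = input_tsdf.extractStateIntervals(
    "metric_1",
    "metric_2",
    state_definition=F.abs(F.col("metric_1") - F.col("metric_2")) < F.lit(10),
).df

Each row of the result carries the partition columns plus an event_ts struct holding the interval's start and end timestamps.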

File tree

6 files changed (+1285 -20 lines)

.gitignore (+4)

@@ -31,7 +31,11 @@ scala/target/stream/*
 
 # ignore virtual environments
 python/venv
+python/.venv
+python/.env
 venv
+.venv
+.env
 
 # other misc ignore
 .DS_Store

python/.flake8 (+7 -1)

@@ -9,4 +9,10 @@ extend-ignore =
     # Invalid escape sequence 'x' (W605)
     W605,
     # 'from module import *' used; unable to detect undefined names (F403)
-    F403
+    F403
+exclude =
+    __pycache__
+    env
+    .env
+    venv
+    .venv

python/tempo/tsdf.py (+81 -1)

@@ -1,12 +1,15 @@
+from __future__ import annotations
+
 import logging
 from functools import reduce
-from typing import List
+from typing import List, Collection, Union
 
 import numpy as np
 import pyspark.sql.functions as f
 from IPython.core.display import HTML
 from IPython.display import display as ipydisplay
 from pyspark.sql import SparkSession
+from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.window import Window, WindowSpec
 from scipy.fft import fft, fftfreq
@@ -1334,6 +1337,83 @@ def tempo_fourier_util(pdf):
 
         return TSDF(result, self.ts_col, self.partitionCols, self.sequence_col)
 
+    def extractStateIntervals(
+        self,
+        *metricCols: Collection[str],
+        state_definition: Union[str, Column[bool]] = "=",
+    ) -> TSDF:
+
+        data = self.df
+
+        w = self.__baseWindow()
+
+        if type(state_definition) is str:
+            if state_definition not in ("=", "<=>", "!=", "<>", ">", "<", ">=", "<="):
+                logger.warning(
+                    "A `state_definition` which has not been tested was"
+                    "provided to the `extractStateIntervals` method."
+                )
+            current_state = f.array(*metricCols)
+        else:
+            current_state = state_definition
+
+        data = data.withColumn("current_state", current_state).drop(*metricCols)
+
+        data = (
+            data.withColumn(
+                "previous_state",
+                f.lag(f.col("current_state"), offset=1).over(w),
+            )
+            .withColumn(
+                "previous_ts",
+                f.lag(f.col(self.ts_col), offset=1).over(w),
+            )
+            .filter(f.col("previous_state").isNotNull())
+        )
+
+        if type(state_definition) is str:
+            state_change_exp = f"""
+                !(current_state {state_definition} previous_state)
+            """
+        else:
+            state_change_exp = "!(current_state AND previous_state)"
+
+        data = data.withColumn(
+            "state_change",
+            f.expr(state_change_exp),
+        ).drop("current_state", "previous_state")
+
+        data = (
+            data.withColumn(
+                "state_incrementer",
+                f.sum(f.col("state_change").cast("int")).over(w),
+            )
+            .filter(~f.col("state_change"))
+            .drop("state_change")
+        )
+
+        data = (
+            data.groupBy(*self.partitionCols, "state_incrementer")
+            .agg(
+                f.struct(
+                    f.min("previous_ts").alias("start"),
+                    f.max(f"{self.ts_col}").alias("end"),
+                ).alias(self.ts_col),
+            )
+            .drop("state_incrementer")
+        )
+
+        result = data.select(
+            self.ts_col,
+            *self.partitionCols,
+        )
+
+        return TSDF(
+            result,
+            self.ts_col,
+            self.partitionCols,
+        )
+
 
 class _ResampledTSDF(TSDF):
     def __init__(
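
The method above is a variant of the usual gaps-and-islands pattern: lag the state over the partition/order window, flag rows where the state changes, turn a running sum of those flags into an interval id, and aggregate min/max timestamps per id. A self-contained sketch of that general pattern on a plain DataFrame (the device_id/event_ts/value columns and sample rows are illustrative, not the committed code):

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

# toy readings: the value changes between the second and third rows
df = spark.createDataFrame(
    [
        ("A", "2020-01-01 00:00:01", 5),
        ("A", "2020-01-01 00:00:02", 5),
        ("A", "2020-01-01 00:00:03", 7),
    ],
    ["device_id", "event_ts", "value"],
)

w = Window.partitionBy("device_id").orderBy("event_ts")

intervals = (
    df.withColumn("prev_value", F.lag("value").over(w))
    # a changed value (or the first row of a partition) starts a new island
    .withColumn(
        "is_change",
        (F.col("prev_value").isNull() | (F.col("value") != F.col("prev_value"))).cast("int"),
    )
    # a running sum of the change flags gives every island a distinct id
    .withColumn("island_id", F.sum("is_change").over(w))
    .groupBy("device_id", "island_id")
    .agg(F.min("event_ts").alias("start"), F.max("event_ts").alias("end"))
    .drop("island_id")
)

The committed implementation differs in the details: it drops the rows where the state did change and pairs previous_ts with the current timestamp, so a state must hold across at least two consecutive rows to produce an interval. The windowing idea is the same.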

python/tests/base.py (-1)

@@ -203,7 +203,6 @@ def assertDataFramesEqual(self, dfA, dfB):
         sortedB = dfB.select(colOrder)
         # must have identical data
         # that is all rows in A must be in B, and vice-versa
-
         self.assertEqual(sortedA.subtract(sortedB).count(), 0)
         self.assertEqual(sortedB.subtract(sortedA).count(), 0)
 

python/tests/tsdf_tests.py (+135)

@@ -3,6 +3,7 @@
 from dateutil import parser as dt_parser
 
 import pyspark.sql.functions as F
+from pyspark.sql.dataframe import DataFrame
 
 from tempo.tsdf import TSDF
 from tests.base import SparkTest
@@ -450,6 +451,140 @@ def test_upsample(self):
         self.assertDataFramesEqual(bars, barsExpected)
 
 
+class extractStateIntervalsTest(SparkTest):
+    """Test of finding time ranges for metrics with constant state."""
+
+    def create_expected_test_df(
+        self,
+        df,
+    ) -> DataFrame:
+        return (
+            # StringType not converting to TimeStamp type inside of struct so forcing
+            df.withColumn(
+                "event_ts",
+                F.struct(
+                    F.to_timestamp("event_ts.start").alias("start"),
+                    F.to_timestamp("event_ts.end").alias("end"),
+                ),
+            )
+        )
+
+    def test_eq_extractStateIntervals(self):
+
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_eq_1_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3"
+        ).df
+        extractStateIntervals_eq_2_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition="<=>"
+        ).df
+
+        # test extractStateIntervals_tsdf summary
+        self.assertDataFramesEqual(extractStateIntervals_eq_1_df, expected_df)
+        self.assertDataFramesEqual(extractStateIntervals_eq_2_df, expected_df)
+
+    def test_ne_extractStateIntervals(self):
+
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_ne_1_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition="!="
+        ).df
+        extractStateIntervals_ne_2_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition="<>"
+        ).df
+
+        # test extractStateIntervals_tsdf summary
+        self.assertDataFramesEqual(extractStateIntervals_ne_1_df, expected_df)
+        self.assertDataFramesEqual(extractStateIntervals_ne_2_df, expected_df)
+
+    def test_gt_extractStateIntervals(self):
+
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_gt_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition=">"
+        ).df
+
+        self.assertDataFramesEqual(extractStateIntervals_gt_df, expected_df)
+
+    def test_lt_extractStateIntervals(self):
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_lt_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition="<"
+        ).df
+
+        # test extractStateIntervals_tsdf summary
+        self.assertDataFramesEqual(extractStateIntervals_lt_df, expected_df)
+
+    def test_gte_extractStateIntervals(self):
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_gt_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition=">="
+        ).df
+
+        self.assertDataFramesEqual(extractStateIntervals_gt_df, expected_df)
+
+    def test_lte_extractStateIntervals(self):
+
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_lte_df = input_tsdf.extractStateIntervals(
+            "metric_1", "metric_2", "metric_3", state_definition="<="
+        ).df
+
+        # test extractStateIntervals_tsdf summary
+        self.assertDataFramesEqual(extractStateIntervals_lte_df, expected_df)
+
+    def test_bool_col_extractStateIntervals(self):
+
+        # construct dataframes
+        input_tsdf = self.get_data_as_tsdf("input")
+        expected_df = self.get_data_as_sdf("expected")
+        expected_df = self.create_expected_test_df(expected_df)
+
+        # call extractStateIntervals method
+        extractStateIntervals_bool_col_df = input_tsdf.extractStateIntervals(
+            "metric_1",
+            "metric_2",
+            "metric_3",
+            state_definition=F.abs(
+                F.col("metric_1") - F.col("metric_2") - F.col("metric_3")
+            )
+            < F.lit(10),
+        ).df
+
+        # test extractStateIntervals_tsdf summary
+        self.assertDataFramesEqual(extractStateIntervals_bool_col_df, expected_df)
+
+
 # MAIN
 if __name__ == "__main__":
     unittest.main()
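
The get_data_as_tsdf("input") and get_data_as_sdf("expected") fixtures are loaded from the shared unit-test data file, which is not part of this diff. The create_expected_test_df helper forces event_ts.start and event_ts.end to timestamps because extractStateIntervals returns event_ts as a struct of interval bounds rather than a single timestamp. Roughly, a result can be unpacked as below (an illustrative sketch, not part of the committed tests):

import pyspark.sql.functions as F

result_df = input_tsdf.extractStateIntervals("metric_1", "metric_2", "metric_3").df

# event_ts is a struct of (start, end); partition columns are carried through unchanged
result_df.select(
    F.col("event_ts.start").alias("interval_start"),
    F.col("event_ts.end").alias("interval_end"),
).show()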
