
Commit a51f1bf

Tristan Nixon and Tristan Nixon authored
Update TSDF.extractStateInterval() to perform state comparison per metric column. (#234)
* working updates to test cases for extractStateInterval
* refactored extractStateIntervals function:
  * fix for issue 232
  * allowing function of (Column, Column) -> Column as alternate state_definition arg
  * returning basic DF, not TSDF
* More verbose messages on assertion failures; fix bug when extracting non-TSDF compatible test dataframes
* updated extractStateIntervals test for threshold function
* no longer any need to convert expected dataframes from extractStateIntervals into special format
* mockup of new way to do column comparisons
* working changes
* extractStateInterval to operate per column instead of on array; test cases need to be updated to allow 5 return columns instead of 4
* remove negation on filter for state change in extractStateIntervals
* update type signature of extractStateIntervals
* new and fixed test cases for extractStateIntervals
* update type signature of extractStateIntervals
* black & flake8 on tests/base.py if needed for CI checks
* black & flake8
* new test cases for TestExtractStateInterval
* update typing for extractStateIntervals
* update typing for extractStateIntervals

Co-authored-by: Tristan Nixon <[email protected]>
1 parent 1b5a3bb commit a51f1bf

File tree: 4 files changed, +1509 -325 lines changed

python/tempo/tsdf.py (+99 -56)
```diff
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 import logging
+import operator
 from functools import reduce
-from typing import List, Collection, Union
+from typing import List, Union, Callable
 
 import numpy as np
 import pyspark.sql.functions as f
@@ -1339,80 +1340,122 @@ def tempo_fourier_util(pdf):
 
     def extractStateIntervals(
         self,
-        *metricCols: Collection[str],
-        state_definition: Union[str, Column[bool]] = "=",
-    ) -> TSDF:
-
-        data = self.df
+        *metric_cols: str,
+        state_definition: Union[str, Callable[[Column, Column], Column]] = "=",
+    ) -> DataFrame:
+        """
+        Extracts intervals from a :class:`~tsdf.TSDF` based on some notion of "state", as defined by the
+        `state_definition` parameter. The state definition consists of a comparison operation between the
+        current and previous values of a metric. If the comparison operation evaluates to true across all
+        metric columns, then we consider both points to be in the same "state". Changes of state occur when
+        the comparison operator returns false for any given metric column. So, the default state definition
+        ('=') extracts intervals of time wherein the metrics all remained constant. A state definition of
+        '>=' would extract intervals wherein the metrics were all monotonically increasing.
+
+        :param metric_cols: the set of metric columns to evaluate for state changes
+        :param state_definition: the comparison function used to evaluate individual metrics for state
+            changes. Either a string, giving a standard PySpark column comparison operation, or a binary
+            function with the signature `(x1: Column, x2: Column) -> Column`, where the returned column
+            expression evaluates to a :class:`~pyspark.sql.types.BooleanType`
+
+        :return: a :class:`~pyspark.sql.DataFrame` object containing the resulting intervals
+        """
 
-        w = self.__baseWindow()
+        # https://spark.apache.org/docs/latest/sql-ref-null-semantics.html#comparison-operators-
+        def null_safe_equals(col1: Column, col2: Column) -> Column:
+            return (
+                f.when(col1.isNull() & col2.isNull(), True)
+                .when(col1.isNull() | col2.isNull(), False)
+                .otherwise(operator.eq(col1, col2))
+            )
 
+        operator_dict = {
+            # https://spark.apache.org/docs/latest/api/sql/#_2
+            "!=": operator.ne,
+            # https://spark.apache.org/docs/latest/api/sql/#_11
+            "<>": operator.ne,
+            # https://spark.apache.org/docs/latest/api/sql/#_8
+            "<": operator.lt,
+            # https://spark.apache.org/docs/latest/api/sql/#_9
+            "<=": operator.le,
+            # https://spark.apache.org/docs/latest/api/sql/#_10
+            "<=>": null_safe_equals,
+            # https://spark.apache.org/docs/latest/api/sql/#_12
+            "=": operator.eq,
+            # https://spark.apache.org/docs/latest/api/sql/#_13
+            "==": operator.eq,
+            # https://spark.apache.org/docs/latest/api/sql/#_14
+            ">": operator.gt,
+            # https://spark.apache.org/docs/latest/api/sql/#_15
+            ">=": operator.ge,
+        }
+
+        # Validate state definition and construct state comparison function
         if type(state_definition) is str:
-            if state_definition not in ("=", "<=>", "!=", "<>", ">", "<", ">=", "<="):
-                logger.warning(
-                    "A `state_definition` which has not been tested was"
-                    "provided to the `extractStateIntervals` method."
+            if state_definition not in operator_dict.keys():
+                raise ValueError(
+                    f"Invalid comparison operator for `state_definition` argument: {state_definition}."
                 )
-            current_state = f.array(*metricCols)
-        else:
-            current_state = state_definition
 
-        data = data.withColumn("current_state", current_state).drop(*metricCols)
+            def state_comparison_fn(a, b):
+                return operator_dict[state_definition](a, b)
 
-        data = (
-            data.withColumn(
-                "previous_state",
-                f.lag(f.col("current_state"), offset=1).over(w),
-            )
-            .withColumn(
-                "previous_ts",
-                f.lag(f.col(self.ts_col), offset=1).over(w),
-            )
-            .filter(f.col("previous_state").isNotNull())
-        )
+        elif callable(state_definition):
+            state_comparison_fn = state_definition
 
-        if type(state_definition) is str:
-            state_change_exp = f"""
-                !(current_state {state_definition} previous_state)
-            """
         else:
-            state_change_exp = "!(current_state AND previous_state)"
+            raise TypeError(
+                f"The `state_definition` argument can be of type `str` or `callable`, "
+                f"but received value of type {type(state_definition)}"
+            )
 
+        w = self.__baseWindow()
+
+        data = self.df
+
+        # Get previous timestamp to identify start time of the interval
         data = data.withColumn(
-            "state_change",
-            f.expr(state_change_exp),
-        ).drop("current_state", "previous_state")
-
-        data = (
-            data.withColumn(
-                "state_incrementer",
-                f.sum(f.col("state_change").cast("int")).over(w),
+            "previous_ts",
+            f.lag(f.col(self.ts_col), offset=1).over(w),
+        )
+
+        # Determine state intervals using the user-provided state comparison function;
+        # the comparison occurs on the current and previous record per metric column
+        temp_metric_compare_cols = []
+        for mc in metric_cols:
+            temp_metric_compare_col = f"__{mc}_compare"
+            data = data.withColumn(
+                temp_metric_compare_col,
+                state_comparison_fn(f.col(mc), f.lag(f.col(mc), 1).over(w)),
             )
-            .filter(~f.col("state_change"))
-            .drop("state_change")
+            temp_metric_compare_cols.append(temp_metric_compare_col)
+
+        # Remove the first record, which has no previous record
+        # and produces `null` for all state comparisons
+        data = data.filter(f.col("previous_ts").isNotNull())
+
+        # Each state comparison returns True if the state remained constant;
+        # a state change occurs when any comparison returns False
+        data = data.withColumn(
+            "state_change", f.array_contains(f.array(*temp_metric_compare_cols), False)
        )
 
-        data = (
+        # Count the distinct state changes to get the unique intervals
+        data = data.withColumn(
+            "state_incrementer",
+            f.sum(f.col("state_change").cast("int")).over(w),
+        ).filter(~f.col("state_change"))
+
+        # Find the start and end timestamp of each interval
+        result = (
             data.groupBy(*self.partitionCols, "state_incrementer")
             .agg(
-                f.struct(
-                    f.min("previous_ts").alias("start"),
-                    f.max(f"{self.ts_col}").alias("end"),
-                ).alias(self.ts_col),
+                f.min("previous_ts").alias("start_ts"),
+                f.max(self.ts_col).alias("end_ts"),
             )
             .drop("state_incrementer")
         )
 
-        result = data.select(
-            self.ts_col,
-            *self.partitionCols,
-        )
-
-        return TSDF(
-            result,
-            self.ts_col,
-            self.partitionCols,
-        )
+        return result
 
 
 class _ResampledTSDF(TSDF):
```
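For reference, a minimal usage sketch of the reworked method. The input DataFrame `readings_df` and the column names `device_id`, `event_ts`, `temp`, and `pressure` are illustrative, not taken from this commit:

```python
from pyspark.sql import functions as f
from tempo import TSDF

# Hypothetical input: one row per device per timestamp
tsdf = TSDF(readings_df, ts_col="event_ts", partition_cols=["device_id"])

# String state definition: intervals where every metric stayed constant.
# The comparison is now applied per metric column rather than on an array.
constant_intervals = tsdf.extractStateIntervals("temp", "pressure")

# Callable state definition with signature (Column, Column) -> Column,
# e.g. treating consecutive values within 0.5 of each other as one state
threshold_intervals = tsdf.extractStateIntervals(
    "temp",
    state_definition=lambda x1, x2: f.abs(x1 - x2) < f.lit(0.5),
)

# Both calls return a plain DataFrame (not a TSDF) with the partition
# columns plus `start_ts` and `end_ts` for each interval
threshold_intervals.show()
```

Note that string operators are now validated against `operator_dict`, and `'<=>'` resolves to the null-safe equality defined above, so two `null` values compare as belonging to the same state.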

python/tests/base.py (+21 -5)
```diff
@@ -64,7 +64,7 @@ def get_data_as_sdf(self, name: str, convert_ts_col=True):
         td = self.test_data[name]
         ts_cols = []
         if convert_ts_col and (td.get("ts_col", None) or td.get("other_ts_cols", [])):
-            ts_cols = [td["ts_col"]]
+            ts_cols = [td["ts_col"]] if "ts_col" in td else []
             ts_cols.extend(td.get("other_ts_cols", []))
         return self.buildTestDF(td["schema"], td["data"], ts_cols)
 
@@ -160,8 +160,16 @@ def assertFieldsEqual(self, fieldA, fieldB):
         """
         Test that two fields are equivalent
         """
-        self.assertEqual(fieldA.name.lower(), fieldB.name.lower())
-        self.assertEqual(fieldA.dataType, fieldB.dataType)
+        self.assertEqual(
+            fieldA.name.lower(),
+            fieldB.name.lower(),
+            msg=f"Field {fieldA} has different name from {fieldB}",
+        )
+        self.assertEqual(
+            fieldA.dataType,
+            fieldB.dataType,
+            msg=f"Field {fieldA} has different type from {fieldB}",
+        )
         # self.assertEqual(fieldA.nullable, fieldB.nullable)
 
     def assertSchemaContainsField(self, schema, field):
@@ -203,8 +211,16 @@ def assertDataFramesEqual(self, dfA, dfB):
         sortedB = dfB.select(colOrder)
         # must have identical data
         # that is all rows in A must be in B, and vice-versa
-        self.assertEqual(sortedA.subtract(sortedB).count(), 0)
-        self.assertEqual(sortedB.subtract(sortedA).count(), 0)
+        self.assertEqual(
+            sortedA.subtract(sortedB).count(),
+            0,
+            msg="There are rows in DataFrame A that are not in DataFrame B",
+        )
+        self.assertEqual(
+            sortedB.subtract(sortedA).count(),
+            0,
+            msg="There are rows in DataFrame B that are not in DataFrame A",
+        )
 
     def assertTSDFsEqual(self, tsdfA, tsdfB):
         """
```
