|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import logging
|
| 4 | +import operator |
4 | 5 | from functools import reduce
|
5 |
| -from typing import List, Collection, Union |
| 6 | +from typing import List, Union, Callable |
6 | 7 |
|
7 | 8 | import numpy as np
|
8 | 9 | import pyspark.sql.functions as f
|
@@ -1339,80 +1340,122 @@ def tempo_fourier_util(pdf):
|
1339 | 1340 |
|
1340 | 1341 | def extractStateIntervals(
|
1341 | 1342 | self,
|
1342 |
| - *metricCols: Collection[str], |
1343 |
| - state_definition: Union[str, Column[bool]] = "=", |
1344 |
| - ) -> TSDF: |
1345 |
| - |
1346 |
| - data = self.df |
| 1343 | + *metric_cols: str, |
| 1344 | + state_definition: Union[str, Callable[[Column, Column], Column]] = "=", |
| 1345 | + ) -> DataFrame: |
| 1346 | + """ |
| 1347 | + Extracts intervals from a :class:`~tsdf.TSDF` based on some notion of "state", as defined by the |
| 1348 | + `state_definition` parameter. The state definition consists of a comparison operation between the current and |
| 1349 | + previous values of a metric. If the comparison operation evaluates to true across all metric columns, |
| 1350 | + then we consider both points to be in the same "state". Changes of state occur when the comparison operator |
| 1351 | + returns false for any given metric column. So, the default state definition ('=') extracts intervals of |
| 1352 | + time wherein the metrics all remained constant. A state definition of '>=' would extract intervals wherein |
| 1353 | + the metrics were all monotonically non-decreasing. |
| 1354 | +
|
| 1355 | + :param metric_cols: the set of metric columns to evaluate for state changes |
| 1356 | + :param state_definition: the comparison function used to evaluate individual metrics for state changes. |
| 1357 | + Either a string, giving a standard PySpark column comparison operation, or a binary function with the |
| 1358 | + signature: `(x1: Column, x2: Column) -> Column` where the returned column expression evaluates to a |
| 1359 | + :class:`~pyspark.sql.types.BooleanType` |
| 1360 | +
|
| 1361 | + :return: a :class:`~pyspark.sql.DataFrame` object containing the resulting intervals |
| 1362 | + """ |
1347 | 1363 |
|
1348 |
| - w = self.__baseWindow() |
| 1364 | + # https://spark.apache.org/docs/latest/sql-ref-null-semantics.html#comparison-operators- |
| 1365 | + def null_safe_equals(col1: Column, col2: Column) -> Column: |
| 1366 | + return ( |
| 1367 | + f.when(col1.isNull() & col2.isNull(), True) |
| 1368 | + .when(col1.isNull() | col2.isNull(), False) |
| 1369 | + .otherwise(operator.eq(col1, col2)) |
| 1370 | + ) |
1349 | 1371 |
|
| 1372 | + operator_dict = { |
| 1373 | + # https://spark.apache.org/docs/latest/api/sql/#_2 |
| 1374 | + "!=": operator.ne, |
| 1375 | + # https://spark.apache.org/docs/latest/api/sql/#_11 |
| 1376 | + "<>": operator.ne, |
| 1377 | + # https://spark.apache.org/docs/latest/api/sql/#_8 |
| 1378 | + "<": operator.lt, |
| 1379 | + # https://spark.apache.org/docs/latest/api/sql/#_9 |
| 1380 | + "<=": operator.le, |
| 1381 | + # https://spark.apache.org/docs/latest/api/sql/#_10 |
| 1382 | + "<=>": null_safe_equals, |
| 1383 | + # https://spark.apache.org/docs/latest/api/sql/#_12 |
| 1384 | + "=": operator.eq, |
| 1385 | + # https://spark.apache.org/docs/latest/api/sql/#_13 |
| 1386 | + "==": operator.eq, |
| 1387 | + # https://spark.apache.org/docs/latest/api/sql/#_14 |
| 1388 | + ">": operator.gt, |
| 1389 | + # https://spark.apache.org/docs/latest/api/sql/#_15 |
| 1390 | + ">=": operator.ge, |
| 1391 | + } |
| 1392 | + |
| 1393 | + # Validate state definition and construct state comparison function |
1350 | 1394 | if type(state_definition) is str:
|
1351 |
| - if state_definition not in ("=", "<=>", "!=", "<>", ">", "<", ">=", "<="): |
1352 |
| - logger.warning( |
1353 |
| - "A `state_definition` which has not been tested was" |
1354 |
| - "provided to the `extractStateIntervals` method." |
| 1395 | + if state_definition not in operator_dict.keys(): |
| 1396 | + raise ValueError( |
| 1397 | + f"Invalid comparison operator for `state_definition` argument: {state_definition}." |
1355 | 1398 | )
|
1356 |
| - current_state = f.array(*metricCols) |
1357 |
| - else: |
1358 |
| - current_state = state_definition |
1359 | 1399 |
|
1360 |
| - data = data.withColumn("current_state", current_state).drop(*metricCols) |
| 1400 | + def state_comparison_fn(a, b): |
| 1401 | + return operator_dict[state_definition](a, b) |
1361 | 1402 |
|
1362 |
| - data = ( |
1363 |
| - data.withColumn( |
1364 |
| - "previous_state", |
1365 |
| - f.lag(f.col("current_state"), offset=1).over(w), |
1366 |
| - ) |
1367 |
| - .withColumn( |
1368 |
| - "previous_ts", |
1369 |
| - f.lag(f.col(self.ts_col), offset=1).over(w), |
1370 |
| - ) |
1371 |
| - .filter(f.col("previous_state").isNotNull()) |
1372 |
| - ) |
| 1403 | + elif callable(state_definition): |
| 1404 | + state_comparison_fn = state_definition |
1373 | 1405 |
|
1374 |
| - if type(state_definition) is str: |
1375 |
| - state_change_exp = f""" |
1376 |
| - !(current_state {state_definition} previous_state) |
1377 |
| - """ |
1378 | 1406 | else:
|
1379 |
| - state_change_exp = "!(current_state AND previous_state)" |
| 1407 | + raise TypeError( |
| 1408 | + f"The `state_definition` argument can be of type `str` or `callable`, " |
| 1409 | + f"but received value of type {type(state_definition)}" |
| 1410 | + ) |
1380 | 1411 |
|
| 1412 | + w = self.__baseWindow() |
| 1413 | + |
| 1414 | + data = self.df |
| 1415 | + |
| 1416 | + # Get previous timestamp to identify start time of the interval |
1381 | 1417 | data = data.withColumn(
|
1382 |
| - "state_change", |
1383 |
| - f.expr(state_change_exp), |
1384 |
| - ).drop("current_state", "previous_state") |
1385 |
| - |
1386 |
| - data = ( |
1387 |
| - data.withColumn( |
1388 |
| - "state_incrementer", |
1389 |
| - f.sum(f.col("state_change").cast("int")).over(w), |
| 1418 | + "previous_ts", |
| 1419 | + f.lag(f.col(self.ts_col), offset=1).over(w), |
| 1420 | + ) |
| 1421 | + |
| 1422 | + # Determine state intervals using the user-provided state comparison function |
| 1423 | + # The comparison occurs on the current and previous record per metric column |
| 1424 | + temp_metric_compare_cols = [] |
| 1425 | + for mc in metric_cols: |
| 1426 | + temp_metric_compare_col = f"__{mc}_compare" |
| 1427 | + data = data.withColumn( |
| 1428 | + temp_metric_compare_col, |
| 1429 | + state_comparison_fn(f.col(mc), f.lag(f.col(mc), 1).over(w)), |
1390 | 1430 | )
|
1391 |
| - .filter(~f.col("state_change")) |
1392 |
| - .drop("state_change") |
| 1431 | + temp_metric_compare_cols.append(temp_metric_compare_col) |
| 1432 | + |
| 1433 | + # Remove first record which will have no state change |
| 1434 | + # and produces `null` for all state comparisons |
| 1435 | + data = data.filter(f.col("previous_ts").isNotNull()) |
| 1436 | + |
| 1437 | + # Each state comparison should return True if state remained constant |
| 1438 | + data = data.withColumn( |
| 1439 | + "state_change", f.array_contains(f.array(*temp_metric_compare_cols), False) |
1393 | 1440 | )
|
1394 | 1441 |
|
1395 |
| - data = ( |
| 1442 | + # Count the distinct state changes to get the unique intervals |
| 1443 | + data = data.withColumn( |
| 1444 | + "state_incrementer", |
| 1445 | + f.sum(f.col("state_change").cast("int")).over(w), |
| 1446 | + ).filter(~f.col("state_change")) |
| 1447 | + |
| 1448 | + # Find the start and end timestamp of the interval |
| 1449 | + result = ( |
1396 | 1450 | data.groupBy(*self.partitionCols, "state_incrementer")
|
1397 | 1451 | .agg(
|
1398 |
| - f.struct( |
1399 |
| - f.min("previous_ts").alias("start"), |
1400 |
| - f.max(f"{self.ts_col}").alias("end"), |
1401 |
| - ).alias(self.ts_col), |
| 1452 | + f.min("previous_ts").alias("start_ts"), |
| 1453 | + f.max(self.ts_col).alias("end_ts"), |
1402 | 1454 | )
|
1403 | 1455 | .drop("state_incrementer")
|
1404 | 1456 | )
|
1405 | 1457 |
|
1406 |
| - result = data.select( |
1407 |
| - self.ts_col, |
1408 |
| - *self.partitionCols, |
1409 |
| - ) |
1410 |
| - |
1411 |
| - return TSDF( |
1412 |
| - result, |
1413 |
| - self.ts_col, |
1414 |
| - self.partitionCols, |
1415 |
| - ) |
| 1458 | + return result |
1416 | 1459 |
|
1417 | 1460 |
|
1418 | 1461 | class _ResampledTSDF(TSDF):
|
|
0 commit comments