|
3 | 3 | from dateutil import parser as dt_parser
|
4 | 4 |
|
5 | 5 | import pyspark.sql.functions as F
|
| 6 | +from pyspark.sql.dataframe import DataFrame |
6 | 7 |
|
7 | 8 | from tempo.tsdf import TSDF
|
8 | 9 | from tests.base import SparkTest
|
@@ -450,6 +451,140 @@ def test_upsample(self):
|
450 | 451 | self.assertDataFramesEqual(bars, barsExpected)
|
451 | 452 |
|
452 | 453 |
|
class extractStateIntervalsTest(SparkTest):
    """Test of finding time ranges for metrics with constant state."""

    # Metric columns passed to every extractStateIntervals call in this suite.
    _METRIC_COLS = ("metric_1", "metric_2", "metric_3")

    def create_expected_test_df(
        self,
        df: DataFrame,
    ) -> DataFrame:
        """Return ``df`` with ``event_ts`` rebuilt as a struct of timestamps.

        StringType values are not converted to timestamps inside of a struct,
        so the ``start`` and ``end`` fields are cast explicitly here.

        :param df: DataFrame holding an ``event_ts`` struct with ``start`` and
            ``end`` fields
        :return: the same DataFrame with ``event_ts`` replaced by a struct
            whose ``start`` and ``end`` fields went through ``to_timestamp``
        """
        event_ts_struct = F.struct(
            F.to_timestamp("event_ts.start").alias("start"),
            F.to_timestamp("event_ts.end").alias("end"),
        )
        return df.withColumn("event_ts", event_ts_struct)

    def _assert_state_intervals(
        self,
        input_tsdf: TSDF,
        expected_df: DataFrame,
        *variants: dict,
    ) -> None:
        """Call extractStateIntervals once per variant and check the result.

        Each variant runs inside its own ``subTest`` so that a failure for one
        spelling of an operator does not hide the result for the next one.

        :param input_tsdf: TSDF on which ``extractStateIntervals`` is called
        :param expected_df: DataFrame every variant's result is compared to
        :param variants: keyword-argument dicts forwarded to
            ``extractStateIntervals``; an empty dict means "use the defaults"
        """
        for kwargs in variants:
            with self.subTest(**kwargs):
                actual_df = input_tsdf.extractStateIntervals(
                    *self._METRIC_COLS, **kwargs
                ).df
                self.assertDataFramesEqual(actual_df, expected_df)

    # NOTE(review): the fixtures are loaded inside each test method rather than
    # in a shared helper because SparkTest's loaders are not visible here;
    # confirm they do not depend on the calling test before consolidating.

    def test_eq_extractStateIntervals(self):
        """The default call and ``state_definition="<=>"`` match "expected"."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        self._assert_state_intervals(
            input_tsdf,
            expected_df,
            {},
            {"state_definition": "<=>"},
        )

    def test_ne_extractStateIntervals(self):
        """``"!="`` and ``"<>"`` both produce the "expected" intervals."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        self._assert_state_intervals(
            input_tsdf,
            expected_df,
            {"state_definition": "!="},
            {"state_definition": "<>"},
        )

    def test_gt_extractStateIntervals(self):
        """``state_definition=">"`` produces the "expected" intervals."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        self._assert_state_intervals(
            input_tsdf, expected_df, {"state_definition": ">"}
        )

    def test_lt_extractStateIntervals(self):
        """``state_definition="<"`` produces the "expected" intervals."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        self._assert_state_intervals(
            input_tsdf, expected_df, {"state_definition": "<"}
        )

    def test_gte_extractStateIntervals(self):
        """``state_definition=">="`` produces the "expected" intervals."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        self._assert_state_intervals(
            input_tsdf, expected_df, {"state_definition": ">="}
        )

    def test_lte_extractStateIntervals(self):
        """``state_definition="<="`` produces the "expected" intervals."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        self._assert_state_intervals(
            input_tsdf, expected_df, {"state_definition": "<="}
        )

    def test_bool_col_extractStateIntervals(self):
        """A boolean Column passed as ``state_definition`` is accepted."""
        input_tsdf = self.get_data_as_tsdf("input")
        expected_df = self.create_expected_test_df(
            self.get_data_as_sdf("expected")
        )

        # Boolean column: |metric_1 - metric_2 - metric_3| < 10
        within_tolerance = F.abs(
            F.col("metric_1") - F.col("metric_2") - F.col("metric_3")
        ) < F.lit(10)

        self._assert_state_intervals(
            input_tsdf, expected_df, {"state_definition": within_tolerance}
        )
| 586 | + |
| 587 | + |
453 | 588 | # MAIN: run unittest.main() when this file is executed as a script
|
454 | 589 | if __name__ == "__main__":
|
455 | 590 | unittest.main()
|
0 commit comments