Skip to content

Commit cca1458

Browse files
authored
Issue148 resample frequency (#199)
* added micro and millisecond support to window function * adding test for microsecond resampling * added correct abbrev for milliseconds and fixed parsing * added test for millisecond resample
1 parent 86657c1 commit cca1458

File tree

2 files changed

+61
-6
lines changed

2 files changed

+61
-6
lines changed

python/tempo/resample.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from pyspark.sql.window import Window
55

66
# define global frequency options
7-
7+
MUSEC = 'microsec'
8+
MS = 'ms'
89
SEC = 'sec'
910
MIN = 'min'
1011
HR = 'hr'
@@ -17,9 +18,9 @@
1718
average = "mean"
1819
ceiling = "ceil"
1920

20-
freq_dict = {'sec' : 'seconds', 'min' : 'minutes', 'hr' : 'hours', 'day' : 'days', 'hour' : 'hours'}
21+
freq_dict = {'microsec' : 'microseconds','ms' : 'milliseconds','sec' : 'seconds', 'min' : 'minutes', 'hr' : 'hours', 'day' : 'days', 'hour' : 'hours'}
2122

22-
allowableFreqs = [SEC, MIN, HR, DAY]
23+
allowableFreqs = [MUSEC, MS, SEC, MIN, HR, DAY]
2324
allowableFuncs = [floor, min, max, average, ceiling]
2425

2526
def __appendAggKey(tsdf, freq = None):
@@ -41,7 +42,7 @@ def aggregate(tsdf, freq, func, metricCols = None, prefix = None, fill = None):
4142
:param tsdf: input TSDF object
4243
:param func: aggregate function
4344
:param metricCols: columns used for aggregates
44-
:param prefix the metric columns with the aggregate named function
45+
:param prefix: the metric columns with the aggregate named function
4546
:param fill: upsample based on the time increment for 0s in numeric columns
4647
:return: TSDF object with newly aggregated timestamp as ts_col with aggregated values
4748
"""
@@ -118,13 +119,22 @@ def aggregate(tsdf, freq, func, metricCols = None, prefix = None, fill = None):
118119

119120

120121
def checkAllowableFreq(freq):
122+
"""
123+
Parses frequency and checks against allowable frequencies
124+
:param freq: frequency at which to upsample/downsample, declared in resample function
125+
:return: tuple of parsed frequency value and time suffix
126+
"""
121127
if freq not in allowableFreqs:
122128
try:
123129
periods = freq.lower().split(" ")[0].strip()
124130
units = freq.lower().split(" ")[1].strip()
125131
except:
126-
raise ValueError("Allowable grouping frequencies are sec (second), min (minute), hr (hour), day. Reformat your frequency as <integer> <day/hour/minute/second>")
127-
if units.startswith(SEC):
132+
raise ValueError("Allowable grouping frequencies are microsecond (musec), millisecond (ms), sec (second), min (minute), hr (hour), day. Reformat your frequency as <integer> <day/hour/minute/second>")
133+
if units.startswith(MUSEC):
134+
return (periods, MUSEC)
135+
elif units.startswith(MS) | units.startswith("millis"):
136+
return (periods, MS)
137+
elif units.startswith(SEC):
128138
return (periods, SEC)
129139
elif units.startswith(MIN):
130140
return (periods, MIN)

python/tests/tsdf_tests.py

+45
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,51 @@ def test_resample(self):
714714
# test bars summary
715715
self.assertDataFramesEqual(bars, barsExpected)
716716

717+
def test_resample_millis(self):
718+
"""Test of resampling for millisecond windows"""
719+
schema = StructType([StructField("symbol", StringType()),
720+
StructField("date", StringType()),
721+
StructField("event_ts", StringType()),
722+
StructField("trade_pr", FloatType()),
723+
StructField("trade_pr_2", FloatType())])
724+
725+
expectedSchema = StructType([StructField("symbol", StringType()),
726+
StructField("event_ts", StringType()),
727+
StructField("floor_trade_pr", FloatType()),
728+
StructField("floor_date", StringType()),
729+
StructField("floor_trade_pr_2", FloatType())])
730+
731+
expectedSchemaMS = StructType([StructField("symbol", StringType()),
732+
StructField("event_ts", StringType(), True),
733+
StructField("date", DoubleType()),
734+
StructField("trade_pr", DoubleType()),
735+
StructField("trade_pr_2", DoubleType())])
736+
737+
738+
data = [["S1", "SAME_DT", "2020-08-01 00:00:10.12345", 349.21, 10.0],
739+
["S1", "SAME_DT", "2020-08-01 00:00:10.123", 340.21, 9.0],
740+
["S1", "SAME_DT", "2020-08-01 00:00:10.124", 353.32, 8.0]]
741+
742+
expected_data_ms = [
743+
["S1", "2020-08-01 00:00:10.123", None, 344.71, 9.5],
744+
["S1", "2020-08-01 00:00:10.124", None, 353.32, 8.0]
745+
]
746+
747+
# construct dataframes
748+
df = self.buildTestDF(schema, data)
749+
dfExpected = self.buildTestDF(expectedSchemaMS, expected_data_ms)
750+
751+
# convert to TSDF
752+
tsdf_left = TSDF(df, partition_cols=["symbol"])
753+
754+
# millisecond mean aggregation
755+
resample_ms = tsdf_left.resample(freq="ms", func="mean").df.withColumn("trade_pr", F.round(F.col('trade_pr'), 2))
756+
757+
int_df = TSDF(tsdf_left.df.withColumn("event_ts", F.col("event_ts").cast("timestamp")), partition_cols = ['symbol'])
758+
interpolated = int_df.interpolate(freq='ms', func='floor', method='ffill')
759+
self.assertDataFramesEqual(resample_ms, dfExpected)
760+
761+
717762
def test_upsample(self):
718763
"""Test of range stats for 20 minute rolling window"""
719764
schema = StructType([StructField("symbol", StringType()),

0 commit comments

Comments
 (0)