diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index e52f015f..401a1cde 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -195,7 +195,7 @@ def _preprocess_partition(
train = part[train_mask]
valid_keep_cols = part.columns
if static_features is not None:
- valid_keep_cols.drop(static_features)
+ valid_keep_cols = valid_keep_cols.drop(static_features)
valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)
transformed = ts.fit_transform(
train,
@@ -456,6 +456,8 @@ def _predict(
) -> Iterable[pd.DataFrame]:
for serialized_ts, _, serialized_valid in items:
valid = cloudpickle.loads(serialized_valid)
+ if valid is not None:
+ X_df = valid
ts = cloudpickle.loads(serialized_ts)
res = ts.predict(
models=models,
@@ -649,7 +651,11 @@ def cross_validation(
engine=self.engine,
)
results.append(fa.get_native_as_df(preds))
- return fa.union(*results)
+ if len(results) == 1:
+ res = results[0]
+ else:
+ res = fa.union(*results)
+ return res
@staticmethod
def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index 4540ec20..1ccd9418 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -261,7 +261,7 @@
" train = part[train_mask]\n",
" valid_keep_cols = part.columns\n",
" if static_features is not None:\n",
- " valid_keep_cols.drop(static_features)\n",
+ " valid_keep_cols = valid_keep_cols.drop(static_features)\n",
" valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)\n",
" transformed = ts.fit_transform(\n",
" train,\n",
@@ -508,6 +508,8 @@
" ) -> Iterable[pd.DataFrame]:\n",
" for serialized_ts, _, serialized_valid in items:\n",
" valid = cloudpickle.loads(serialized_valid)\n",
+ " if valid is not None:\n",
+ " X_df = valid\n",
" ts = cloudpickle.loads(serialized_ts)\n",
" res = ts.predict(\n",
" models=models,\n",
@@ -695,7 +697,11 @@
" engine=self.engine,\n",
" )\n",
" results.append(fa.get_native_as_df(preds))\n",
- " return fa.union(*results)\n",
+ " if len(results) == 1:\n",
+ " res = results[0]\n",
+ " else:\n",
+ " res = fa.union(*results)\n",
+ " return res\n",
"\n",
" @staticmethod\n",
" def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:\n",
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 87e02a05..6f355810 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -53,6 +53,7 @@
"import pandas as pd\n",
"import s3fs\n",
"from sklearn.base import BaseEstimator\n",
+ "from utilsforecast.feature_engineering import fourier\n",
"\n",
"from mlforecast.distributed import DistributedMLForecast\n",
"from mlforecast.lag_transforms import ExpandingMean, ExponentiallyWeightedMean, RollingMean\n",
@@ -136,6 +137,10 @@
"
y | \n",
" static_0 | \n",
" static_1 | \n",
+ " sin1_7 | \n",
+ " sin2_7 | \n",
+ " cos1_7 | \n",
+ " cos2_7 | \n",
" \n",
" \n",
" npartitions=10 | \n",
@@ -144,6 +149,10 @@
" | \n",
" | \n",
" | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
"
\n",
" \n",
" \n",
@@ -154,6 +163,10 @@
" float64 | \n",
" int64 | \n",
" int64 | \n",
+ " float32 | \n",
+ " float32 | \n",
+ " float32 | \n",
+ " float32 | \n",
" \n",
" \n",
" id_10 | \n",
@@ -162,6 +175,10 @@
" ... | \n",
" ... | \n",
" ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
"
\n",
" \n",
" ... | \n",
@@ -170,6 +187,10 @@
" ... | \n",
" ... | \n",
" ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
"
\n",
" \n",
" id_90 | \n",
@@ -178,6 +199,10 @@
" ... | \n",
" ... | \n",
" ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
"
\n",
" \n",
" id_99 | \n",
@@ -186,6 +211,10 @@
" ... | \n",
" ... | \n",
" ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
"
\n",
" \n",
"\n",
@@ -193,13 +222,13 @@
],
"text/plain": [
"Dask DataFrame Structure:\n",
- " unique_id ds y static_0 static_1\n",
- "npartitions=10 \n",
- "id_00 object datetime64[ns] float64 int64 int64\n",
- "id_10 ... ... ... ... ...\n",
- "... ... ... ... ... ...\n",
- "id_90 ... ... ... ... ...\n",
- "id_99 ... ... ... ... ...\n",
+ " unique_id ds y static_0 static_1 sin1_7 sin2_7 cos1_7 cos2_7\n",
+ "npartitions=10 \n",
+ "id_00 object datetime64[ns] float64 int64 int64 float32 float32 float32 float32\n",
+ "id_10 ... ... ... ... ... ... ... ... ...\n",
+ "... ... ... ... ... ... ... ... ... ...\n",
+ "id_90 ... ... ... ... ... ... ... ... ...\n",
+ "id_99 ... ... ... ... ... ... ... ... ...\n",
"Dask Name: assign, 5 expressions\n",
"Expr=Assign(frame=MapPartitions(lambda))"
]
@@ -211,8 +240,9 @@
],
"source": [
"series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False, min_length=500, max_length=1_000)\n",
+ "train, future = fourier(series, freq='d', season_length=7, k=2, h=7)\n",
"npartitions = 10\n",
- "partitioned_series = dd.from_pandas(series.set_index('unique_id'), npartitions=npartitions) # make sure we split by the id_col\n",
+ "partitioned_series = dd.from_pandas(train.set_index('unique_id'), npartitions=npartitions) # make sure we split by the id_col\n",
"partitioned_series = partitioned_series.map_partitions(lambda df: df.reset_index())\n",
"partitioned_series['unique_id'] = partitioned_series['unique_id'].astype(str) # can't handle categoricals atm\n",
"partitioned_series"
@@ -245,7 +275,10 @@
"metadata": {},
"outputs": [],
"source": [
- "models = [DaskXGBForecast(random_state=0), DaskLGBMForecast(random_state=0)]"
+ "models = [\n",
+ " DaskXGBForecast(random_state=0),\n",
+ " DaskLGBMForecast(random_state=0, verbosity=-1),\n",
+ "]"
]
},
{
@@ -277,7 +310,7 @@
" num_threads=1,\n",
" engine=client,\n",
")\n",
- "fcst.fit(partitioned_series)"
+ "fcst.fit(partitioned_series, static_features=['static_0', 'static_1'])"
]
},
{
@@ -354,7 +387,7 @@
" fcst_np.fit(test_dd)\n",
" test_partition_results_size(fcst_np, num_partitions_test)\n",
" preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
- " preds = fcst.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ " preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
" pd.testing.assert_frame_equal(\n",
" preds[['unique_id', 'ds']], \n",
" preds_np[['unique_id', 'ds']], \n",
@@ -415,36 +448,36 @@
" 0 | \n",
" id_00 | \n",
" 2002-09-27 00:00:00 | \n",
- " 22.489947 | \n",
- " 21.679944 | \n",
+ " 21.609526 | \n",
+ " 22.114111 | \n",
" \n",
" \n",
" 1 | \n",
" id_00 | \n",
" 2002-09-28 00:00:00 | \n",
- " 81.806826 | \n",
- " 84.151205 | \n",
+ " 85.623013 | \n",
+ " 84.309696 | \n",
"
\n",
" \n",
" 2 | \n",
" id_00 | \n",
" 2002-09-29 00:00:00 | \n",
- " 162.705641 | \n",
- " 164.024508 | \n",
+ " 163.107685 | \n",
+ " 163.20679 | \n",
"
\n",
" \n",
" 3 | \n",
" id_00 | \n",
" 2002-09-30 00:00:00 | \n",
- " 246.990386 | \n",
- " 246.099977 | \n",
+ " 246.96872 | \n",
+ " 245.510858 | \n",
"
\n",
" \n",
" 4 | \n",
" id_00 | \n",
" 2002-10-01 00:00:00 | \n",
- " 314.741463 | \n",
- " 315.261537 | \n",
+ " 318.521367 | \n",
+ " 314.479718 | \n",
"
\n",
" \n",
"\n",
@@ -452,11 +485,11 @@
],
"text/plain": [
" unique_id ds DaskXGBForecast DaskLGBMForecast\n",
- "0 id_00 2002-09-27 00:00:00 22.489947 21.679944\n",
- "1 id_00 2002-09-28 00:00:00 81.806826 84.151205\n",
- "2 id_00 2002-09-29 00:00:00 162.705641 164.024508\n",
- "3 id_00 2002-09-30 00:00:00 246.990386 246.099977\n",
- "4 id_00 2002-10-01 00:00:00 314.741463 315.261537"
+ "0 id_00 2002-09-27 00:00:00 21.609526 22.114111\n",
+ "1 id_00 2002-09-28 00:00:00 85.623013 84.309696\n",
+ "2 id_00 2002-09-29 00:00:00 163.107685 163.20679\n",
+ "3 id_00 2002-09-30 00:00:00 246.96872 245.510858\n",
+ "4 id_00 2002-10-01 00:00:00 318.521367 314.479718"
]
},
"execution_count": null,
@@ -465,7 +498,7 @@
}
],
"source": [
- "preds = fcst.predict(7).compute()\n",
+ "preds = fcst.predict(7, X_df=future).compute()\n",
"preds.head()"
]
},
@@ -477,8 +510,8 @@
"outputs": [],
"source": [
"#|hide\n",
- "preds2 = fcst.predict(7).compute()\n",
- "preds3 = fcst.predict(7, new_df=partitioned_series).compute()\n",
+ "preds2 = fcst.predict(7, X_df=future).compute()\n",
+ "preds3 = fcst.predict(7, new_df=partitioned_series, X_df=future).compute()\n",
"pd.testing.assert_frame_equal(preds, preds2)\n",
"pd.testing.assert_frame_equal(preds, preds3)"
]
@@ -604,8 +637,8 @@
"metadata": {},
"outputs": [],
"source": [
- "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
- "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds = fa.as_pandas(fcst.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds2 = fa.as_pandas(fcst2.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"pd.testing.assert_frame_equal(preds, preds2)"
]
},
@@ -635,7 +668,7 @@
" num_threads=1,\n",
" engine=client,\n",
")\n",
- "upd_fcst.fit(partitioned_series)\n",
+ "upd_fcst.fit(partitioned_series, static_features=['static_0', 'static_1'])\n",
"\n",
"new_df = (series.groupby('unique_id', observed=True)['ds'].max() + pd.offsets.Day()).reset_index()\n",
"new_df['y'] = -1.0\n",
@@ -643,7 +676,7 @@
"expected = new_df.rename(columns={'y': 'Lag1Model'})\n",
"expected = expected.astype({'unique_id': str})\n",
"expected['ds'] += pd.offsets.Day()\n",
- "upd_preds = upd_fcst.predict(1).compute()\n",
+ "upd_preds = upd_fcst.predict(1, X_df=future).compute()\n",
"pd.testing.assert_frame_equal(\n",
" upd_preds.reset_index(drop=True),\n",
" expected.reset_index(drop=True),\n",
@@ -669,7 +702,7 @@
"outputs": [],
"source": [
"local_fcst = fcst.to_local()\n",
- "local_preds = local_fcst.predict(10)\n",
+ "local_preds = local_fcst.predict(7, X_df=future)\n",
"# we don't check the dtype because sometimes these are arrow dtypes\n",
"# or different precisions of float\n",
"pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
@@ -686,7 +719,7 @@
"# test to_local without target transforms\n",
"fcst_no_targ_tfms = DistributedMLForecast(\n",
" models=[DaskXGBForecast(n_estimators=5, random_state=0)],\n",
- " freq='D', \n",
+ " freq='D',\n",
" lags=[1],\n",
" lag_transforms={1: [ExpandingMean()]},\n",
" date_features=['dayofweek'],\n",
@@ -718,6 +751,7 @@
" partitioned_series,\n",
" n_windows=3,\n",
" h=14,\n",
+ " static_features=['static_0', 'static_1'],\n",
")"
]
},
@@ -758,68 +792,68 @@
" \n",
" \n",
" \n",
- " 17 | \n",
- " id_01 | \n",
- " 2002-08-19 00:00:00 | \n",
- " 224.458336 | \n",
- " 222.742605 | \n",
+ " 0 | \n",
+ " id_00 | \n",
+ " 2002-08-16 00:00:00 | \n",
+ " 23.192749 | \n",
+ " 21.986437 | \n",
" 2002-08-15 00:00:00 | \n",
- " 210.723139 | \n",
+ " 11.878591 | \n",
"
\n",
" \n",
- " 43 | \n",
- " id_03 | \n",
- " 2002-08-17 00:00:00 | \n",
- " 2.235601 | \n",
- " 2.210624 | \n",
+ " 30 | \n",
+ " id_02 | \n",
+ " 2002-08-18 00:00:00 | \n",
+ " 96.59974 | \n",
+ " 96.568057 | \n",
" 2002-08-15 00:00:00 | \n",
- " 2.416967 | \n",
+ " 94.706551 | \n",
"
\n",
" \n",
- " 44 | \n",
- " id_03 | \n",
- " 2002-08-18 00:00:00 | \n",
- " 3.276747 | \n",
- " 3.239702 | \n",
+ " 80 | \n",
+ " id_05 | \n",
+ " 2002-08-26 00:00:00 | \n",
+ " 257.210466 | \n",
+ " 255.908309 | \n",
" 2002-08-15 00:00:00 | \n",
- " 3.060194 | \n",
+ " 246.051086 | \n",
"
\n",
" \n",
- " 119 | \n",
- " id_08 | \n",
- " 2002-08-23 00:00:00 | \n",
- " 131.261689 | \n",
- " 131.180289 | \n",
+ " 36 | \n",
+ " id_12 | \n",
+ " 2002-08-24 00:00:00 | \n",
+ " 401.081335 | \n",
+ " 401.697836 | \n",
" 2002-08-15 00:00:00 | \n",
- " 138.668463 | \n",
+ " 424.296882 | \n",
"
\n",
" \n",
- " 131 | \n",
- " id_09 | \n",
- " 2002-08-21 00:00:00 | \n",
- " 27.716417 | \n",
- " 28.263963 | \n",
+ " 91 | \n",
+ " id_16 | \n",
+ " 2002-08-23 00:00:00 | \n",
+ " 315.036479 | \n",
+ " 315.368377 | \n",
" 2002-08-15 00:00:00 | \n",
- " 22.88374 | \n",
+ " 300.419406 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " unique_id ds DaskXGBForecast DaskLGBMForecast \\\n",
- "17 id_01 2002-08-19 00:00:00 224.458336 222.742605 \n",
- "43 id_03 2002-08-17 00:00:00 2.235601 2.210624 \n",
- "44 id_03 2002-08-18 00:00:00 3.276747 3.239702 \n",
- "119 id_08 2002-08-23 00:00:00 131.261689 131.180289 \n",
- "131 id_09 2002-08-21 00:00:00 27.716417 28.263963 \n",
+ " unique_id ds DaskXGBForecast DaskLGBMForecast \\\n",
+ "0 id_00 2002-08-16 00:00:00 23.192749 21.986437 \n",
+ "30 id_02 2002-08-18 00:00:00 96.59974 96.568057 \n",
+ "80 id_05 2002-08-26 00:00:00 257.210466 255.908309 \n",
+ "36 id_12 2002-08-24 00:00:00 401.081335 401.697836 \n",
+ "91 id_16 2002-08-23 00:00:00 315.036479 315.368377 \n",
"\n",
- " cutoff y \n",
- "17 2002-08-15 00:00:00 210.723139 \n",
- "43 2002-08-15 00:00:00 2.416967 \n",
- "44 2002-08-15 00:00:00 3.060194 \n",
- "119 2002-08-15 00:00:00 138.668463 \n",
- "131 2002-08-15 00:00:00 22.88374 "
+ " cutoff y \n",
+ "0 2002-08-15 00:00:00 11.878591 \n",
+ "30 2002-08-15 00:00:00 94.706551 \n",
+ "80 2002-08-15 00:00:00 246.051086 \n",
+ "36 2002-08-15 00:00:00 424.296882 \n",
+ "91 2002-08-15 00:00:00 300.419406 "
]
},
"execution_count": null,
@@ -831,6 +865,23 @@
"cv_res.compute().head()"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ebb0e756-522d-42cf-9b6e-73b1158c21bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# single window CV\n",
+ "assert fcst.cross_validation(\n",
+ " partitioned_series,\n",
+ " n_windows=1,\n",
+ " h=5,\n",
+ " static_features=['static_0', 'static_1'],\n",
+ ").compute()['cutoff'].nunique() == 1"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -865,6 +916,7 @@
" i_window=0,\n",
" input_size=input_size,\n",
" ),\n",
+ " static_features=['static_0', 'static_1'],\n",
")\n",
"assert reduced_train.groupby('unique_id').size().compute().max() == input_size"
]
@@ -881,7 +933,8 @@
" partitioned_series,\n",
" n_windows=3,\n",
" h=14,\n",
- " refit=False\n",
+ " refit=False,\n",
+ " static_features=['static_0', 'static_1'],\n",
")\n",
"cv_results_df = cv_res.compute().sort_values(['unique_id', 'ds'])\n",
"cv_results_no_refit_df = cv_res_no_refit.compute().sort_values(['unique_id', 'ds'])\n",
@@ -914,12 +967,18 @@
" num_threads=1,\n",
")\n",
"fcst = DistributedMLForecast(freq='D', **flow_params)\n",
- "fcst.fit(partitioned_series)\n",
- "preds = fcst.predict(7).compute()\n",
+ "fcst.fit(partitioned_series, static_features=['static_0', 'static_1'])\n",
+ "preds = fcst.predict(7, X_df=future).compute()\n",
"fcst2 = DistributedMLForecast(freq='D', **flow_params)\n",
- "fcst2.preprocess(non_std_series, id_col='some_id', time_col='time', target_col='value')\n",
+ "fcst2.preprocess(\n",
+ " non_std_series,\n",
+ " id_col='some_id',\n",
+ " time_col='time',\n",
+ " target_col='value',\n",
+ " static_features=['static_0', 'static_1'],\n",
+ ")\n",
"fcst2.models_ = fcst.models_ # distributed training can end up with different fits\n",
- "non_std_preds = fcst2.predict(7).compute()\n",
+ "non_std_preds = fcst2.predict(7, X_df=future.rename(columns={'ds': 'time', 'unique_id': 'some_id'})).compute()\n",
"pd.testing.assert_frame_equal(\n",
" preds.drop(columns='ds'),\n",
" non_std_preds.drop(columns='time').rename(columns={'some_id': 'unique_id'})\n",
@@ -996,9 +1055,11 @@
"metadata": {},
"outputs": [],
"source": [
+ "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False, min_length=500, max_length=1_000)\n",
+ "series['unique_id'] = series['unique_id'].astype(str) # can't handle categoricals atm\n",
+ "train, future = fourier(series, freq='d', season_length=7, k=2, h=7)\n",
"numPartitions = 4\n",
- "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False)\n",
- "spark_series = spark.createDataFrame(series).repartitionByRange(numPartitions, 'unique_id')"
+ "spark_series = spark.createDataFrame(train).repartitionByRange(numPartitions, 'unique_id')"
]
},
{
@@ -1028,7 +1089,10 @@
"metadata": {},
"outputs": [],
"source": [
- "models = [SparkLGBMForecast(seed=0), SparkXGBForecast(random_state=0)]"
+ "models = [\n",
+ " SparkLGBMForecast(seed=0, verbosity=-1),\n",
+ " SparkXGBForecast(random_state=0),\n",
+ "]"
]
},
{
@@ -1082,7 +1146,7 @@
"source": [
"#| hide\n",
"# test num_partitions works properly\n",
- "test_spark_df = spark.createDataFrame(series)\n",
+ "test_spark_df = spark.createDataFrame(train)\n",
"num_partitions_test = 10\n",
"fcst_np = DistributedMLForecast(\n",
" models=models,\n",
@@ -1096,10 +1160,10 @@
" num_threads=1,\n",
" num_partitions=num_partitions_test,\n",
")\n",
- "fcst_np.fit(test_spark_df)\n",
+ "fcst_np.fit(test_spark_df, static_features=['static_0', 'static_1'])\n",
"test_partition_results_size(fcst_np, num_partitions_test)\n",
- "preds_np = fcst_np.predict(7).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
- "preds = fcst.predict(7).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds_np = fcst_np.predict(7, X_df=future).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds = fcst.predict(7, X_df=future).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"pd.testing.assert_frame_equal(\n",
" preds[['unique_id', 'ds']], \n",
" preds_np[['unique_id', 'ds']], \n",
@@ -1121,7 +1185,7 @@
"metadata": {},
"outputs": [],
"source": [
- "preds = fcst.predict(14).toPandas()"
+ "preds = fcst.predict(7, X_df=future).toPandas()"
]
},
{
@@ -1161,37 +1225,37 @@
" \n",
" 0 | \n",
" id_00 | \n",
- " 2001-05-15 | \n",
- " 430.964632 | \n",
- " 431.202969 | \n",
+ " 2002-09-27 | \n",
+ " 15.102403 | \n",
+ " 18.631477 | \n",
"
\n",
" \n",
" 1 | \n",
" id_00 | \n",
- " 2001-05-16 | \n",
- " 505.411960 | \n",
- " 504.030227 | \n",
+ " 2002-09-28 | \n",
+ " 92.980261 | \n",
+ " 93.796269 | \n",
"
\n",
" \n",
" 2 | \n",
" id_00 | \n",
- " 2001-05-17 | \n",
- " 9.889056 | \n",
- " 9.706636 | \n",
+ " 2002-09-29 | \n",
+ " 160.090375 | \n",
+ " 159.582315 | \n",
"
\n",
" \n",
" 3 | \n",
" id_00 | \n",
- " 2001-05-18 | \n",
- " 99.359694 | \n",
- " 96.258271 | \n",
+ " 2002-09-30 | \n",
+ " 250.416113 | \n",
+ " 250.861651 | \n",
"
\n",
" \n",
" 4 | \n",
" id_00 | \n",
- " 2001-05-19 | \n",
- " 196.307731 | \n",
- " 197.443618 | \n",
+ " 2002-10-01 | \n",
+ " 323.306184 | \n",
+ " 321.564089 | \n",
"
\n",
" \n",
"\n",
@@ -1199,11 +1263,11 @@
],
"text/plain": [
" unique_id ds SparkLGBMForecast SparkXGBForecast\n",
- "0 id_00 2001-05-15 430.964632 431.202969\n",
- "1 id_00 2001-05-16 505.411960 504.030227\n",
- "2 id_00 2001-05-17 9.889056 9.706636\n",
- "3 id_00 2001-05-18 99.359694 96.258271\n",
- "4 id_00 2001-05-19 196.307731 197.443618"
+ "0 id_00 2002-09-27 15.102403 18.631477\n",
+ "1 id_00 2002-09-28 92.980261 93.796269\n",
+ "2 id_00 2002-09-29 160.090375 159.582315\n",
+ "3 id_00 2002-09-30 250.416113 250.861651\n",
+ "4 id_00 2002-10-01 323.306184 321.564089"
]
},
"execution_count": null,
@@ -1308,8 +1372,8 @@
}
],
"source": [
- "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
- "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds = fa.as_pandas(fcst.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds2 = fa.as_pandas(fcst2.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"pd.testing.assert_frame_equal(preds, preds2)"
]
},
@@ -1331,7 +1395,7 @@
"outputs": [],
"source": [
"local_fcst = fcst.to_local()\n",
- "local_preds = local_fcst.predict(10)\n",
+ "local_preds = local_fcst.predict(7, X_df=future)\n",
"# we don't check the dtype because sometimes these are arrow dtypes\n",
"# or different precisions of float\n",
"pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
@@ -1356,6 +1420,7 @@
" spark_series,\n",
" n_windows=3,\n",
" h=14,\n",
+ " static_features=['static_0', 'static_1'],\n",
").toPandas()"
]
},
@@ -1397,48 +1462,48 @@
" \n",
" \n",
" 0 | \n",
- " id_15 | \n",
- " 2001-04-04 | \n",
- " 88.438691 | \n",
- " 86.105463 | \n",
- " 2001-04-02 | \n",
- " 92.468763 | \n",
+ " id_03 | \n",
+ " 2002-08-18 | \n",
+ " 3.272922 | \n",
+ " 3.348874 | \n",
+ " 2002-08-15 | \n",
+ " 3.060194 | \n",
"
\n",
" \n",
" 1 | \n",
- " id_25 | \n",
- " 2001-04-12 | \n",
- " 355.712493 | \n",
- " 354.525400 | \n",
- " 2001-04-02 | \n",
- " 320.701359 | \n",
+ " id_09 | \n",
+ " 2002-08-20 | \n",
+ " 402.718091 | \n",
+ " 402.622501 | \n",
+ " 2002-08-15 | \n",
+ " 398.784459 | \n",
"
\n",
" \n",
" 2 | \n",
- " id_03 | \n",
- " 2001-04-08 | \n",
- " 257.243845 | \n",
- " 253.834157 | \n",
- " 2001-04-02 | \n",
- " 274.420045 | \n",
+ " id_25 | \n",
+ " 2002-08-22 | \n",
+ " 87.189811 | \n",
+ " 86.891632 | \n",
+ " 2002-08-15 | \n",
+ " 82.731377 | \n",
"
\n",
" \n",
" 3 | \n",
- " id_14 | \n",
- " 2001-04-07 | \n",
- " 24.925278 | \n",
- " 23.833504 | \n",
- " 2001-04-02 | \n",
- " 26.906679 | \n",
+ " id_06 | \n",
+ " 2002-08-21 | \n",
+ " 20.416790 | \n",
+ " 20.478502 | \n",
+ " 2002-08-15 | \n",
+ " 19.196394 | \n",
"
\n",
" \n",
" 4 | \n",
- " id_01 | \n",
- " 2001-04-16 | \n",
- " 89.180665 | \n",
- " 90.743194 | \n",
- " 2001-04-02 | \n",
- " 93.807725 | \n",
+ " id_22 | \n",
+ " 2002-08-23 | \n",
+ " 357.718513 | \n",
+ " 360.502024 | \n",
+ " 2002-08-15 | \n",
+ " 394.770699 | \n",
"
\n",
" \n",
"\n",
@@ -1446,18 +1511,18 @@
],
"text/plain": [
" unique_id ds SparkLGBMForecast SparkXGBForecast cutoff \\\n",
- "0 id_15 2001-04-04 88.438691 86.105463 2001-04-02 \n",
- "1 id_25 2001-04-12 355.712493 354.525400 2001-04-02 \n",
- "2 id_03 2001-04-08 257.243845 253.834157 2001-04-02 \n",
- "3 id_14 2001-04-07 24.925278 23.833504 2001-04-02 \n",
- "4 id_01 2001-04-16 89.180665 90.743194 2001-04-02 \n",
+ "0 id_03 2002-08-18 3.272922 3.348874 2002-08-15 \n",
+ "1 id_09 2002-08-20 402.718091 402.622501 2002-08-15 \n",
+ "2 id_25 2002-08-22 87.189811 86.891632 2002-08-15 \n",
+ "3 id_06 2002-08-21 20.416790 20.478502 2002-08-15 \n",
+ "4 id_22 2002-08-23 357.718513 360.502024 2002-08-15 \n",
"\n",
" y \n",
- "0 92.468763 \n",
- "1 320.701359 \n",
- "2 274.420045 \n",
- "3 26.906679 \n",
- "4 93.807725 "
+ "0 3.060194 \n",
+ "1 398.784459 \n",
+ "2 82.731377 \n",
+ "3 19.196394 \n",
+ "4 394.770699 "
]
},
"execution_count": null,
@@ -1540,10 +1605,10 @@
"metadata": {},
"outputs": [],
"source": [
- "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False)\n",
- "# we need noncategory unique_id\n",
- "series['unique_id'] = series['unique_id'].astype(str)\n",
- "ray_series = ray.data.from_pandas(series)"
+ "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False, min_length=500, max_length=1_000)\n",
+ "series['unique_id'] = series['unique_id'].astype(str) # can't handle categoricals atm\n",
+ "train, future = fourier(series, freq='d', season_length=7, k=2, h=7)\n",
+ "ray_series = ray.data.from_pandas(train)"
]
},
{
@@ -1573,7 +1638,10 @@
"metadata": {},
"outputs": [],
"source": [
- "models = [RayLGBMForecast(random_state=0), RayXGBForecast(random_state=0)]"
+ "models = [\n",
+ " RayLGBMForecast(random_state=0, verbosity=-1),\n",
+ " RayXGBForecast(random_state=0),\n",
+ "]"
]
},
{
@@ -1659,13 +1727,13 @@
" date_features=['dayofweek', 'month'],\n",
" num_threads=1,\n",
")\n",
- "fcst_np.fit(ray_series)\n",
+ "fcst_np.fit(ray_series, static_features=['static_0', 'static_1'])\n",
"# we dont use test_partition_results_size\n",
"# since the number of objects is different \n",
"# from the number of partitions\n",
"test_eq(fa.count(fcst_np._partition_results), 100) # number of series\n",
- "preds_np = fcst_np.predict(7).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
- "preds = fcst.predict(7).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds_np = fcst_np.predict(7, X_df=future).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds = fcst.predict(7, X_df=future).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"pd.testing.assert_frame_equal(\n",
" preds[['unique_id', 'ds']], \n",
" preds_np[['unique_id', 'ds']], \n",
@@ -1687,7 +1755,7 @@
"metadata": {},
"outputs": [],
"source": [
- "preds = fcst.predict(14).to_pandas()"
+ "preds = fcst.predict(7, X_df=future).to_pandas()"
]
},
{
@@ -1726,38 +1794,38 @@
" \n",
" \n",
" 0 | \n",
- " id_01 | \n",
- " 2001-05-15 | \n",
- " 118.505341 | \n",
- " 118.32222 | \n",
+ " id_00 | \n",
+ " 2002-09-27 | \n",
+ " 15.232455 | \n",
+ " 10.38301 | \n",
"
\n",
" \n",
" 1 | \n",
- " id_01 | \n",
- " 2001-05-16 | \n",
- " 152.321457 | \n",
- " 152.265915 | \n",
+ " id_00 | \n",
+ " 2002-09-28 | \n",
+ " 92.288994 | \n",
+ " 92.531502 | \n",
"
\n",
" \n",
" 2 | \n",
- " id_01 | \n",
- " 2001-05-17 | \n",
- " 181.979599 | \n",
- " 181.945618 | \n",
+ " id_00 | \n",
+ " 2002-09-29 | \n",
+ " 160.043472 | \n",
+ " 160.722885 | \n",
"
\n",
" \n",
" 3 | \n",
- " id_01 | \n",
- " 2001-05-18 | \n",
- " 9.530758 | \n",
- " 9.543224 | \n",
+ " id_00 | \n",
+ " 2002-09-30 | \n",
+ " 250.03212 | \n",
+ " 252.821899 | \n",
"
\n",
" \n",
" 4 | \n",
- " id_01 | \n",
- " 2001-05-19 | \n",
- " 40.503441 | \n",
- " 40.661186 | \n",
+ " id_00 | \n",
+ " 2002-10-01 | \n",
+ " 322.905182 | \n",
+ " 324.387695 | \n",
"
\n",
" \n",
"\n",
@@ -1765,11 +1833,11 @@
],
"text/plain": [
" unique_id ds RayLGBMForecast RayXGBForecast\n",
- "0 id_01 2001-05-15 118.505341 118.32222\n",
- "1 id_01 2001-05-16 152.321457 152.265915\n",
- "2 id_01 2001-05-17 181.979599 181.945618\n",
- "3 id_01 2001-05-18 9.530758 9.543224\n",
- "4 id_01 2001-05-19 40.503441 40.661186"
+ "0 id_00 2002-09-27 15.232455 10.38301\n",
+ "1 id_00 2002-09-28 92.288994 92.531502\n",
+ "2 id_00 2002-09-29 160.043472 160.722885\n",
+ "3 id_00 2002-09-30 250.03212 252.821899\n",
+ "4 id_00 2002-10-01 322.905182 324.387695"
]
},
"execution_count": null,
@@ -1850,8 +1918,8 @@
"metadata": {},
"outputs": [],
"source": [
- "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
- "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds = fa.as_pandas(fcst.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+ "preds2 = fa.as_pandas(fcst2.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
"pd.testing.assert_frame_equal(preds, preds2)"
]
},
@@ -1873,7 +1941,7 @@
"outputs": [],
"source": [
"local_fcst = fcst.to_local()\n",
- "local_preds = local_fcst.predict(10)\n",
+ "local_preds = local_fcst.predict(7, X_df=future)\n",
"# we don't check the dtype because sometimes these are arrow dtypes\n",
"# or different precisions of float\n",
"pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
@@ -1898,6 +1966,7 @@
" ray_series,\n",
" n_windows=3,\n",
" h=14,\n",
+ " static_features=['static_0', 'static_1'],\n",
").to_pandas()"
]
},
@@ -1939,48 +2008,48 @@
" \n",
" \n",
" 0 | \n",
- " id_10 | \n",
- " 2001-05-01 | \n",
- " 24.767561 | \n",
- " 24.528799 | \n",
- " 2001-04-30 | \n",
- " 31.878545 | \n",
+ " id_04 | \n",
+ " 2002-09-20 | \n",
+ " 118.982094 | \n",
+ " 117.577477 | \n",
+ " 2002-09-12 | \n",
+ " 118.603489 | \n",
"
\n",
" \n",
" 1 | \n",
- " id_10 | \n",
- " 2001-05-07 | \n",
- " 1.916985 | \n",
- " 2.323445 | \n",
- " 2001-04-30 | \n",
- " 7.365955 | \n",
+ " id_04 | \n",
+ " 2002-09-24 | \n",
+ " 51.461491 | \n",
+ " 50.120552 | \n",
+ " 2002-09-12 | \n",
+ " 52.668389 | \n",
"
\n",
" \n",
" 2 | \n",
- " id_13 | \n",
- " 2001-05-01 | \n",
- " 210.900330 | \n",
- " 212.959320 | \n",
- " 2001-04-30 | \n",
- " 190.485236 | \n",
+ " id_05 | \n",
+ " 2002-09-20 | \n",
+ " 27.594826 | \n",
+ " 24.421537 | \n",
+ " 2002-09-12 | \n",
+ " 20.120710 | \n",
"
\n",
" \n",
" 3 | \n",
- " id_14 | \n",
- " 2001-05-01 | \n",
- " 196.620819 | \n",
- " 196.253036 | \n",
- " 2001-04-30 | \n",
- " 213.631212 | \n",
+ " id_05 | \n",
+ " 2002-09-25 | \n",
+ " 411.615204 | \n",
+ " 412.093384 | \n",
+ " 2002-09-12 | \n",
+ " 419.621422 | \n",
"
\n",
" \n",
" 4 | \n",
- " id_14 | \n",
- " 2001-05-03 | \n",
- " 323.323334 | \n",
- " 322.372894 | \n",
- " 2001-04-30 | \n",
- " 338.234837 | \n",
+ " id_08 | \n",
+ " 2002-09-25 | \n",
+ " 83.210945 | \n",
+ " 83.842705 | \n",
+ " 2002-09-12 | \n",
+ " 86.344885 | \n",
"
\n",
" \n",
"\n",
@@ -1988,11 +2057,11 @@
],
"text/plain": [
" unique_id ds RayLGBMForecast RayXGBForecast cutoff y\n",
- "0 id_10 2001-05-01 24.767561 24.528799 2001-04-30 31.878545\n",
- "1 id_10 2001-05-07 1.916985 2.323445 2001-04-30 7.365955\n",
- "2 id_13 2001-05-01 210.900330 212.959320 2001-04-30 190.485236\n",
- "3 id_14 2001-05-01 196.620819 196.253036 2001-04-30 213.631212\n",
- "4 id_14 2001-05-03 323.323334 322.372894 2001-04-30 338.234837"
+ "0 id_04 2002-09-20 118.982094 117.577477 2002-09-12 118.603489\n",
+ "1 id_04 2002-09-24 51.461491 50.120552 2002-09-12 52.668389\n",
+ "2 id_05 2002-09-20 27.594826 24.421537 2002-09-12 20.120710\n",
+ "3 id_05 2002-09-25 411.615204 412.093384 2002-09-12 419.621422\n",
+ "4 id_08 2002-09-25 83.210945 83.842705 2002-09-12 86.344885"
]
},
"execution_count": null,