diff --git a/mlforecast/distributed/forecast.py b/mlforecast/distributed/forecast.py
index e52f015f..401a1cde 100644
--- a/mlforecast/distributed/forecast.py
+++ b/mlforecast/distributed/forecast.py
@@ -195,7 +195,7 @@ def _preprocess_partition(
         train = part[train_mask]
         valid_keep_cols = part.columns
         if static_features is not None:
-            valid_keep_cols.drop(static_features)
+            valid_keep_cols = valid_keep_cols.drop(static_features)
         valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)
         transformed = ts.fit_transform(
             train,
@@ -456,6 +456,8 @@
     ) -> Iterable[pd.DataFrame]:
         for serialized_ts, _, serialized_valid in items:
             valid = cloudpickle.loads(serialized_valid)
+            if valid is not None:
+                X_df = valid
             ts = cloudpickle.loads(serialized_ts)
             res = ts.predict(
                 models=models,
@@ -649,7 +651,11 @@ def cross_validation(
                 engine=self.engine,
             )
             results.append(fa.get_native_as_df(preds))
-        return fa.union(*results)
+        if len(results) == 1:
+            res = results[0]
+        else:
+            res = fa.union(*results)
+        return res
 
     @staticmethod
     def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:
diff --git a/nbs/distributed.forecast.ipynb b/nbs/distributed.forecast.ipynb
index 4540ec20..1ccd9418 100644
--- a/nbs/distributed.forecast.ipynb
+++ b/nbs/distributed.forecast.ipynb
@@ -261,7 +261,7 @@
     "        train = part[train_mask]\n",
     "        valid_keep_cols = part.columns\n",
     "        if static_features is not None:\n",
-    "            valid_keep_cols.drop(static_features)\n",
+    "            valid_keep_cols = valid_keep_cols.drop(static_features)\n",
     "        valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)\n",
     "        transformed = ts.fit_transform(\n",
     "            train,\n",
@@ -508,6 +508,8 @@
     "    ) -> Iterable[pd.DataFrame]:\n",
     "        for serialized_ts, _, serialized_valid in items:\n",
     "            valid = cloudpickle.loads(serialized_valid)\n",
+    "            if valid is not None:\n",
+    "                X_df = valid\n",
     "            ts = cloudpickle.loads(serialized_ts)\n",
     "            res = ts.predict(\n",
     "                models=models,\n",
@@ -695,7 +697,11 @@
     "                engine=self.engine,\n",
     "            )\n",
     "            results.append(fa.get_native_as_df(preds))\n",
-    "        return fa.union(*results)\n",
+    "        if len(results) == 1:\n",
+    "            res = results[0]\n",
+    "        else:\n",
+    "            res = fa.union(*results)\n",
+    "        return res\n",
     "\n",
     "    @staticmethod\n",
     "    def _save_ts(items: List[List[Any]], path: str) -> Iterable[pd.DataFrame]:\n",
diff --git a/nbs/docs/getting-started/quick_start_distributed.ipynb b/nbs/docs/getting-started/quick_start_distributed.ipynb
index 87e02a05..6f355810 100644
--- a/nbs/docs/getting-started/quick_start_distributed.ipynb
+++ b/nbs/docs/getting-started/quick_start_distributed.ipynb
@@ -53,6 +53,7 @@
     "import pandas as pd\n",
     "import s3fs\n",
     "from sklearn.base import BaseEstimator\n",
+    "from utilsforecast.feature_engineering import fourier\n",
     "\n",
     "from mlforecast.distributed import DistributedMLForecast\n",
     "from mlforecast.lag_transforms import ExpandingMean, ExponentiallyWeightedMean, RollingMean\n",
@@ -136,6 +137,10 @@
[several short hunks collapsed: HTML repr of the partitioned Dask DataFrame; its markup was stripped from this excerpt. They add the same sin1_7, sin2_7, cos1_7 and cos2_7 float32 columns shown in the text/plain repr below.]
@@ -193,13 +222,13 @@
     ],
     "text/plain": [
      "Dask DataFrame Structure:\n",
-     "               unique_id              ds        y static_0 static_1\n",
-     "npartitions=10                                                     \n",
-     "id_00             object  datetime64[ns]  float64    int64    int64\n",
-     "id_10                ...             ...      ...      ...      ...\n",
-     "...                  ...             ...      ...      ...      ...\n",
-     "id_90                ...             ...      ...      ...      ...\n",
-     "id_99                ...             ...      ...      ...      ...\n",
+     "               unique_id              ds        y static_0 static_1   sin1_7   sin2_7   cos1_7   cos2_7\n",
+     "npartitions=10                                                                                         \n",
+     "id_00             object  datetime64[ns]  float64    int64    int64  float32  float32  float32  float32\n",
+     "id_10                ...             ...      ...      ...      ...      ...      ...      ...      ...\n",
+     "...                  ...             ...      ...      ...      ...      ...      ...      ...      ...\n",
+     "id_90                ...             ...      ...      ...      ...      ...      ...      ...      ...\n",
+     "id_99                ...             ...      ...      ...      ...      ...      ...      ...      ...\n",
      "Dask Name: assign, 5 expressions\n",
      "Expr=Assign(frame=MapPartitions(lambda))"
     ]
@@ -211,8 +240,9 @@
   "source": [
    "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False, min_length=500, max_length=1_000)\n",
+   "train, future = fourier(series, freq='d', season_length=7, k=2, h=7)\n",
    "npartitions = 10\n",
-   "partitioned_series = dd.from_pandas(series.set_index('unique_id'), npartitions=npartitions) # make sure we split by the id_col\n",
+   "partitioned_series = dd.from_pandas(train.set_index('unique_id'), npartitions=npartitions) # make sure we split by the id_col\n",
    "partitioned_series = partitioned_series.map_partitions(lambda df: df.reset_index())\n",
    "partitioned_series['unique_id'] = partitioned_series['unique_id'].astype(str) # can't handle categoricals atm\n",
    "partitioned_series"
   ]
@@ -245,7 +275,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "models = [DaskXGBForecast(random_state=0), DaskLGBMForecast(random_state=0)]"
+   "models = [\n",
+   "    DaskXGBForecast(random_state=0),\n",
+   "    DaskLGBMForecast(random_state=0, verbosity=-1),\n",
+   "]"
   ]
  },
@@ -277,7 +310,7 @@
    "    num_threads=1,\n",
    "    engine=client,\n",
    ")\n",
-   "fcst.fit(partitioned_series)"
+   "fcst.fit(partitioned_series, static_features=['static_0', 'static_1'])"
   ]
  },
@@ -354,7 +387,7 @@
    "    fcst_np.fit(test_dd)\n",
    "    test_partition_results_size(fcst_np, num_partitions_test)\n",
    "    preds_np = fcst_np.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-   "    preds = fcst.predict(7).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "    preds = fcst.predict(7, X_df=future).compute().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "    pd.testing.assert_frame_equal(\n",
    "        preds[['unique_id', 'ds']], \n",
    "        preds_np[['unique_id', 'ds']], \n",
@@ -415,36 +448,36 @@
[hunk collapsed: HTML repr of these predictions; its markup was stripped from this excerpt. The new values match the text/plain repr below.]
@@ -452,11 +485,11 @@
     ],
     "text/plain": [
      "  unique_id                  ds  DaskXGBForecast  DaskLGBMForecast\n",
-     "0     id_00 2002-09-27 00:00:00        22.489947         21.679944\n",
-     "1     id_00 2002-09-28 00:00:00        81.806826         84.151205\n",
-     "2     id_00 2002-09-29 00:00:00       162.705641        164.024508\n",
-     "3     id_00 2002-09-30 00:00:00       246.990386        246.099977\n",
-     "4     id_00 2002-10-01 00:00:00       314.741463        315.261537"
+     "0     id_00 2002-09-27 00:00:00        21.609526         22.114111\n",
+     "1     id_00 2002-09-28 00:00:00        85.623013         84.309696\n",
+     "2     id_00 2002-09-29 00:00:00       163.107685         163.20679\n",
+     "3     id_00 2002-09-30 00:00:00        246.96872        245.510858\n",
+     "4     id_00 2002-10-01 00:00:00       318.521367        314.479718"
     ]
    },
    "execution_count": null,
@@ -465,7 +498,7 @@
   }
  ],
  "source": [
-   "preds = fcst.predict(7).compute()\n",
+   "preds = fcst.predict(7, X_df=future).compute()\n",
   "preds.head()"
  ]
 },
@@ -477,8 +510,8 @@
   "outputs": [],
   "source": [
    "#|hide\n",
-   "preds2 = fcst.predict(7).compute()\n",
-   "preds3 = fcst.predict(7, new_df=partitioned_series).compute()\n",
+   "preds2 = fcst.predict(7, X_df=future).compute()\n",
+   "preds3 = fcst.predict(7, new_df=partitioned_series, X_df=future).compute()\n",
    "pd.testing.assert_frame_equal(preds, preds2)\n",
    "pd.testing.assert_frame_equal(preds, preds3)"
   ]
@@ -604,8 +637,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-   "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds = fa.as_pandas(fcst.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds2 = fa.as_pandas(fcst2.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "pd.testing.assert_frame_equal(preds, preds2)"
   ]
  },
@@ -635,7 +668,7 @@
    "    num_threads=1,\n",
    "    engine=client,\n",
    ")\n",
-   "upd_fcst.fit(partitioned_series)\n",
+   "upd_fcst.fit(partitioned_series, static_features=['static_0', 'static_1'])\n",
    "\n",
    "new_df = (series.groupby('unique_id', observed=True)['ds'].max() + pd.offsets.Day()).reset_index()\n",
    "new_df['y'] = -1.0\n",
@@ -643,7 +676,7 @@
    "expected = new_df.rename(columns={'y': 'Lag1Model'})\n",
    "expected = expected.astype({'unique_id': str})\n",
    "expected['ds'] += pd.offsets.Day()\n",
-   "upd_preds = upd_fcst.predict(1).compute()\n",
+   "upd_preds = upd_fcst.predict(1, X_df=future).compute()\n",
    "pd.testing.assert_frame_equal(\n",
    "    upd_preds.reset_index(drop=True),\n",
    "    expected.reset_index(drop=True),\n",
@@ -669,7 +702,7 @@
   "outputs": [],
   "source": [
    "local_fcst = fcst.to_local()\n",
-   "local_preds = local_fcst.predict(10)\n",
+   "local_preds = local_fcst.predict(7, X_df=future)\n",
    "# we don't check the dtype because sometimes these are arrow dtypes\n",
    "# or different precisions of float\n",
    "pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
   ]
@@ -686,7 +719,7 @@
    "# test to_local without target transforms\n",
    "fcst_no_targ_tfms = DistributedMLForecast(\n",
    "    models=[DaskXGBForecast(n_estimators=5, random_state=0)],\n",
-   "    freq='D', \n",
+   "    freq='D',\n",
    "    lags=[1],\n",
    "    lag_transforms={1: [ExpandingMean()]},\n",
    "    date_features=['dayofweek'],\n",
@@ -718,6 +751,7 @@
    "    partitioned_series,\n",
    "    n_windows=3,\n",
    "    h=14,\n",
+   "    static_features=['static_0', 'static_1'],\n",
    ")"
   ]
  },
@@ -758,68 +792,68 @@
[hunk collapsed: HTML repr of these cross-validation results; its markup was stripped from this excerpt. The new values match the text/plain repr below.]
     ],
     "text/plain": [
-     "    unique_id                  ds  DaskXGBForecast  DaskLGBMForecast  \\\n",
-     "17      id_01 2002-08-19 00:00:00       224.458336        222.742605   \n",
-     "43      id_03 2002-08-17 00:00:00         2.235601          2.210624   \n",
-     "44      id_03 2002-08-18 00:00:00         3.276747          3.239702   \n",
-     "119     id_08 2002-08-23 00:00:00       131.261689        131.180289   \n",
-     "131     id_09 2002-08-21 00:00:00        27.716417         28.263963   \n",
-     "\n",
-     "                 cutoff           y  \n",
-     "17  2002-08-15 00:00:00  210.723139  \n",
-     "43  2002-08-15 00:00:00    2.416967  \n",
-     "44  2002-08-15 00:00:00    3.060194  \n",
-     "119 2002-08-15 00:00:00  138.668463  \n",
-     "131 2002-08-15 00:00:00    22.88374  "
+     "   unique_id                  ds  DaskXGBForecast  DaskLGBMForecast  \\\n",
+     "0      id_00 2002-08-16 00:00:00        23.192749         21.986437   \n",
+     "30     id_02 2002-08-18 00:00:00         96.59974         96.568057   \n",
+     "80     id_05 2002-08-26 00:00:00       257.210466        255.908309   \n",
+     "36     id_12 2002-08-24 00:00:00       401.081335        401.697836   \n",
+     "91     id_16 2002-08-23 00:00:00       315.036479        315.368377   \n",
+     "\n",
+     "                cutoff           y  \n",
+     "0  2002-08-15 00:00:00   11.878591  \n",
+     "30 2002-08-15 00:00:00   94.706551  \n",
+     "80 2002-08-15 00:00:00  246.051086  \n",
+     "36 2002-08-15 00:00:00  424.296882  \n",
+     "91 2002-08-15 00:00:00  300.419406  "
     ]
@@ -831,6 +865,23 @@
    "cv_res.compute().head()"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "id": "ebb0e756-522d-42cf-9b6e-73b1158c21bd",
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "#| hide\n",
+   "# single window CV\n",
+   "assert fcst.cross_validation(\n",
+   "    partitioned_series,\n",
+   "    n_windows=1,\n",
+   "    h=5,\n",
+   "    static_features=['static_0', 'static_1'],\n",
+   ").compute()['cutoff'].nunique() == 1"
+  ]
+ },
 {
  "cell_type": "code",
  "execution_count": null,
@@ -865,6 +916,7 @@
    "        i_window=0,\n",
    "        input_size=input_size,\n",
    "    ),\n",
+   "    static_features=['static_0', 'static_1'],\n",
    ")\n",
    "assert reduced_train.groupby('unique_id').size().compute().max() == input_size"
   ]
@@ -881,7 +933,8 @@
    "    partitioned_series,\n",
    "    n_windows=3,\n",
    "    h=14,\n",
-   "    refit=False\n",
+   "    refit=False,\n",
+   "    static_features=['static_0', 'static_1'],\n",
    ")\n",
    "cv_results_df = cv_res.compute().sort_values(['unique_id', 'ds'])\n",
    "cv_results_no_refit_df = cv_res_no_refit.compute().sort_values(['unique_id', 'ds'])\n",
@@ -914,12 +967,18 @@
    "    num_threads=1,\n",
    ")\n",
    "fcst = DistributedMLForecast(freq='D', **flow_params)\n",
-   "fcst.fit(partitioned_series)\n",
-   "preds = fcst.predict(7).compute()\n",
+   "fcst.fit(partitioned_series, static_features=['static_0', 'static_1'])\n",
+   "preds = fcst.predict(7, X_df=future).compute()\n",
    "fcst2 = DistributedMLForecast(freq='D', **flow_params)\n",
-   "fcst2.preprocess(non_std_series, id_col='some_id', time_col='time', target_col='value')\n",
+   "fcst2.preprocess(\n",
+   "    non_std_series,\n",
+   "    id_col='some_id',\n",
+   "    time_col='time',\n",
+   "    target_col='value',\n",
+   "    static_features=['static_0', 'static_1'],\n",
+   ")\n",
    "fcst2.models_ = fcst.models_ # distributed training can end up with different fits\n",
-   "non_std_preds = fcst2.predict(7).compute()\n",
+   "non_std_preds = fcst2.predict(7, X_df=future.rename(columns={'ds': 'time', 'unique_id': 'some_id'})).compute()\n",
    "pd.testing.assert_frame_equal(\n",
    "    preds.drop(columns='ds'),\n",
    "    non_std_preds.drop(columns='time').rename(columns={'some_id': 'unique_id'})\n",
@@ -996,9 +1055,11 @@
   "metadata": {},
   "outputs": [],
   "source": [
+   "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False, min_length=500, max_length=1_000)\n",
+   "series['unique_id'] = series['unique_id'].astype(str) # can't handle categoricals atm\n",
+   "train, future = fourier(series, freq='d', season_length=7, k=2, h=7)\n",
    "numPartitions = 4\n",
-   "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False)\n",
-   "spark_series = spark.createDataFrame(series).repartitionByRange(numPartitions, 'unique_id')"
+   "spark_series = spark.createDataFrame(train).repartitionByRange(numPartitions, 'unique_id')"
   ]
  },
@@ -1028,7 +1089,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "models = [SparkLGBMForecast(seed=0), SparkXGBForecast(random_state=0)]"
+   "models = [\n",
+   "    SparkLGBMForecast(seed=0, verbosity=-1),\n",
+   "    SparkXGBForecast(random_state=0),\n",
+   "]"
   ]
  },
@@ -1082,7 +1146,7 @@
   "source": [
    "#| hide\n",
    "# test num_partitions works properly\n",
-   "test_spark_df = spark.createDataFrame(series)\n",
+   "test_spark_df = spark.createDataFrame(train)\n",
    "num_partitions_test = 10\n",
    "fcst_np = DistributedMLForecast(\n",
    "    models=models,\n",
@@ -1096,10 +1160,10 @@
    "    num_threads=1,\n",
    "    num_partitions=num_partitions_test,\n",
    ")\n",
-   "fcst_np.fit(test_spark_df)\n",
+   "fcst_np.fit(test_spark_df, static_features=['static_0', 'static_1'])\n",
    "test_partition_results_size(fcst_np, num_partitions_test)\n",
-   "preds_np = fcst_np.predict(7).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-   "preds = fcst.predict(7).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds_np = fcst_np.predict(7, X_df=future).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds = fcst.predict(7, X_df=future).toPandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "pd.testing.assert_frame_equal(\n",
    "    preds[['unique_id', 'ds']], \n",
    "    preds_np[['unique_id', 'ds']], \n",
@@ -1121,7 +1185,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "preds = fcst.predict(14).toPandas()"
+   "preds = fcst.predict(7, X_df=future).toPandas()"
   ]
  },
@@ -1161,37 +1225,37 @@
[hunk collapsed: HTML repr of these predictions; its markup was stripped from this excerpt. The new values match the text/plain repr below.]
@@ -1199,11 +1263,11 @@
     ],
     "text/plain": [
      "  unique_id         ds  SparkLGBMForecast  SparkXGBForecast\n",
-     "0     id_00 2001-05-15         430.964632        431.202969\n",
-     "1     id_00 2001-05-16         505.411960        504.030227\n",
-     "2     id_00 2001-05-17           9.889056          9.706636\n",
-     "3     id_00 2001-05-18          99.359694         96.258271\n",
-     "4     id_00 2001-05-19         196.307731        197.443618"
+     "0     id_00 2002-09-27          15.102403         18.631477\n",
+     "1     id_00 2002-09-28          92.980261         93.796269\n",
+     "2     id_00 2002-09-29         160.090375        159.582315\n",
+     "3     id_00 2002-09-30         250.416113        250.861651\n",
+     "4     id_00 2002-10-01         323.306184        321.564089"
     ]
    },
    "execution_count": null,
@@ -1308,8 +1372,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-   "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds = fa.as_pandas(fcst.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds2 = fa.as_pandas(fcst2.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "pd.testing.assert_frame_equal(preds, preds2)"
   ]
  },
@@ -1331,7 +1395,7 @@
   "outputs": [],
   "source": [
    "local_fcst = fcst.to_local()\n",
-   "local_preds = local_fcst.predict(10)\n",
+   "local_preds = local_fcst.predict(7, X_df=future)\n",
    "# we don't check the dtype because sometimes these are arrow dtypes\n",
    "# or different precisions of float\n",
    "pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
   ]
@@ -1356,6 +1420,7 @@
    "    spark_series,\n",
    "    n_windows=3,\n",
    "    h=14,\n",
+   "    static_features=['static_0', 'static_1'],\n",
    ").toPandas()"
   ]
  },
@@ -1397,48 +1462,48 @@
[hunk collapsed: HTML repr of these cross-validation results; its markup was stripped from this excerpt. The new values match the text/plain repr below.]
@@ -1446,18 +1511,18 @@
     ],
     "text/plain": [
      "  unique_id         ds  SparkLGBMForecast  SparkXGBForecast     cutoff  \\\n",
-     "0     id_15 2001-04-04          88.438691         86.105463 2001-04-02   \n",
-     "1     id_25 2001-04-12         355.712493        354.525400 2001-04-02   \n",
-     "2     id_03 2001-04-08         257.243845        253.834157 2001-04-02   \n",
-     "3     id_14 2001-04-07          24.925278         23.833504 2001-04-02   \n",
-     "4     id_01 2001-04-16          89.180665         90.743194 2001-04-02   \n",
+     "0     id_03 2002-08-18           3.272922          3.348874 2002-08-15   \n",
+     "1     id_09 2002-08-20         402.718091        402.622501 2002-08-15   \n",
+     "2     id_25 2002-08-22          87.189811         86.891632 2002-08-15   \n",
+     "3     id_06 2002-08-21          20.416790         20.478502 2002-08-15   \n",
+     "4     id_22 2002-08-23         357.718513        360.502024 2002-08-15   \n",
      "\n",
      "            y  \n",
-     "0   92.468763  \n",
-     "1  320.701359  \n",
-     "2  274.420045  \n",
-     "3   26.906679  \n",
-     "4   93.807725  "
+     "0    3.060194  \n",
+     "1  398.784459  \n",
+     "2   82.731377  \n",
+     "3   19.196394  \n",
+     "4  394.770699  "
     ]
@@ -1540,10 +1605,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False)\n",
-   "# we need noncategory unique_id\n",
-   "series['unique_id'] = series['unique_id'].astype(str)\n",
-   "ray_series = ray.data.from_pandas(series)"
+   "series = generate_daily_series(100, n_static_features=2, equal_ends=True, static_as_categorical=False, min_length=500, max_length=1_000)\n",
+   "series['unique_id'] = series['unique_id'].astype(str) # can't handle categoricals atm\n",
+   "train, future = fourier(series, freq='d', season_length=7, k=2, h=7)\n",
+   "ray_series = ray.data.from_pandas(train)"
   ]
  },
@@ -1573,7 +1638,10 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "models = [RayLGBMForecast(random_state=0), RayXGBForecast(random_state=0)]"
+   "models = [\n",
+   "    RayLGBMForecast(random_state=0, verbosity=-1),\n",
+   "    RayXGBForecast(random_state=0),\n",
+   "]"
   ]
  },
@@ -1659,13 +1727,13 @@
    "    date_features=['dayofweek', 'month'],\n",
    "    num_threads=1,\n",
    ")\n",
-   "fcst_np.fit(ray_series)\n",
+   "fcst_np.fit(ray_series, static_features=['static_0', 'static_1'])\n",
    "# we dont use test_partition_results_size\n",
    "# since the number of objects is different \n",
    "# from the number of partitions\n",
    "test_eq(fa.count(fcst_np._partition_results), 100) # number of series\n",
-   "preds_np = fcst_np.predict(7).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-   "preds = fcst.predict(7).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds_np = fcst_np.predict(7, X_df=future).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds = fcst.predict(7, X_df=future).to_pandas().sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "pd.testing.assert_frame_equal(\n",
    "    preds[['unique_id', 'ds']], \n",
    "    preds_np[['unique_id', 'ds']], \n",
@@ -1687,7 +1755,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "preds = fcst.predict(14).to_pandas()"
+   "preds = fcst.predict(7, X_df=future).to_pandas()"
   ]
  },
@@ -1726,38 +1794,38 @@
[hunk collapsed: HTML repr of these predictions; its markup was stripped from this excerpt. The new values match the text/plain repr below.]
@@ -1765,11 +1833,11 @@
     ],
     "text/plain": [
      "  unique_id         ds  RayLGBMForecast  RayXGBForecast\n",
-     "0     id_01 2001-05-15       118.505341       118.32222\n",
-     "1     id_01 2001-05-16       152.321457      152.265915\n",
-     "2     id_01 2001-05-17       181.979599      181.945618\n",
-     "3     id_01 2001-05-18         9.530758        9.543224\n",
-     "4     id_01 2001-05-19        40.503441       40.661186"
+     "0     id_00 2002-09-27        15.232455        10.38301\n",
+     "1     id_00 2002-09-28        92.288994       92.531502\n",
+     "2     id_00 2002-09-29       160.043472      160.722885\n",
+     "3     id_00 2002-09-30        250.03212      252.821899\n",
+     "4     id_00 2002-10-01       322.905182      324.387695"
     ]
    },
    "execution_count": null,
@@ -1850,8 +1918,8 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "preds = fa.as_pandas(fcst.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
-   "preds2 = fa.as_pandas(fcst2.predict(10)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds = fa.as_pandas(fcst.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
+   "preds2 = fa.as_pandas(fcst2.predict(7, X_df=future)).sort_values(['unique_id', 'ds']).reset_index(drop=True)\n",
    "pd.testing.assert_frame_equal(preds, preds2)"
   ]
  },
@@ -1873,7 +1941,7 @@
   "outputs": [],
   "source": [
    "local_fcst = fcst.to_local()\n",
-   "local_preds = local_fcst.predict(10)\n",
+   "local_preds = local_fcst.predict(7, X_df=future)\n",
    "# we don't check the dtype because sometimes these are arrow dtypes\n",
    "# or different precisions of float\n",
    "pd.testing.assert_frame_equal(preds, local_preds, check_dtype=False)"
   ]
@@ -1898,6 +1966,7 @@
    "    ray_series,\n",
    "    n_windows=3,\n",
    "    h=14,\n",
+   "    static_features=['static_0', 'static_1'],\n",
    ").to_pandas()"
   ]
  },
@@ -1939,48 +2008,48 @@
[hunk collapsed: HTML repr of these cross-validation results; its markup was stripped from this excerpt. The new values match the text/plain repr below.]
@@ -1988,11 +2057,11 @@
     ],
     "text/plain": [
      "  unique_id         ds  RayLGBMForecast  RayXGBForecast     cutoff           y\n",
-     "0     id_10 2001-05-01        24.767561       24.528799 2001-04-30   31.878545\n",
-     "1     id_10 2001-05-07         1.916985        2.323445 2001-04-30    7.365955\n",
-     "2     id_13 2001-05-01       210.900330      212.959320 2001-04-30  190.485236\n",
-     "3     id_14 2001-05-01       196.620819      196.253036 2001-04-30  213.631212\n",
-     "4     id_14 2001-05-03       323.323334      322.372894 2001-04-30  338.234837"
+     "0     id_04 2002-09-20       118.982094      117.577477 2002-09-12  118.603489\n",
+     "1     id_04 2002-09-24        51.461491       50.120552 2002-09-12   52.668389\n",
+     "2     id_05 2002-09-20        27.594826       24.421537 2002-09-12   20.120710\n",
+     "3     id_05 2002-09-25       411.615204      412.093384 2002-09-12  419.621422\n",
+     "4     id_08 2002-09-25        83.210945       83.842705 2002-09-12   86.344885"
     ]
    },
    "execution_count": null,