Skip to content

Commit 6c4bbd9

Browse files
authored
Merge pull request #4 from st-tech/feature/fix_reg_model
fix some bugs in cross-fitting
2 parents efdcab2 + 2e089d8 commit 6c4bbd9

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

obp/ope/regression_model.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99
from sklearn.base import BaseEstimator, clone, is_classifier
10-
from sklearn.model_selection import StratifiedKFold
10+
from sklearn.model_selection import KFold
1111

1212
from ..utils import check_bandit_feedback_inputs
1313

@@ -132,7 +132,10 @@ def fit(
132132
if self.fitting_method in ["iw", "mrdr"]:
133133
assert (
134134
action_dist is not None
135-
), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then `action_dist` must be given"
135+
), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then action_dist must be given"
136+
assert (
137+
pscore is not None
138+
), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then pscore must be given"
136139
n_data = context.shape[0]
137140
for position_ in np.arange(self.len_list):
138141
idx = position == position_
@@ -252,9 +255,22 @@ def fit_predict(
252255
Estimated expected rewards for the given logged bandit feedback at each item and position by the regression model.
253256
254257
"""
255-
assert n_folds > 1 and isinstance(
258+
assert n_folds > 0 and isinstance(
256259
n_folds, int
257-
), f"n_folds must be an integer larger than 1, but {n_folds} is given"
260+
), f"n_folds must be a positive integer, but {n_folds} is given"
261+
if self.len_list == 1:
262+
position = np.zeros_like(action)
263+
else:
264+
assert (
265+
position is not None
266+
), "position has to be set when len_list is larger than 1"
267+
if self.fitting_method in ["iw", "mrdr"]:
268+
assert (
269+
action_dist is not None
270+
), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then action_dist must be given"
271+
assert (
272+
pscore is not None
273+
), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then pscore must be given"
258274

259275
if n_folds == 1:
260276
self.fit(
@@ -270,8 +286,8 @@ def fit_predict(
270286
estimated_rewards_by_reg_model = np.zeros(
271287
(context.shape[0], self.n_actions, self.len_list)
272288
)
273-
skf = StratifiedKFold(n_splits=n_folds)
274-
skf.get_n_splits(context, reward)
289+
skf = KFold(n_splits=n_folds, shuffle=True)
290+
skf.get_n_splits(context)
275291
for train_idx, test_idx in skf.split(context, reward):
276292
action_dist_tr = (
277293
action_dist[train_idx] if action_dist is not None else action_dist

obp/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.0"
1+
__version__ = "0.3.1"

0 commit comments

Comments
 (0)