7
7
8
8
import numpy as np
9
9
from sklearn .base import BaseEstimator , clone , is_classifier
10
- from sklearn .model_selection import StratifiedKFold
10
+ from sklearn .model_selection import KFold
11
11
12
12
from ..utils import check_bandit_feedback_inputs
13
13
@@ -132,7 +132,10 @@ def fit(
132
132
if self .fitting_method in ["iw" , "mrdr" ]:
133
133
assert (
134
134
action_dist is not None
135
- ), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then `action_dist` must be given"
135
+ ), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then action_dist must be given"
136
+ assert (
137
+ pscore is not None
138
+ ), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then pscore must be given"
136
139
n_data = context .shape [0 ]
137
140
for position_ in np .arange (self .len_list ):
138
141
idx = position == position_
@@ -252,9 +255,22 @@ def fit_predict(
252
255
Estimated expected rewards for the given logged bandit feedback at each item and position by the regression model.
253
256
254
257
"""
255
- assert n_folds > 1 and isinstance (
258
+ assert n_folds > 0 and isinstance (
256
259
n_folds , int
257
- ), f"n_folds must be an integer larger than 1, but { n_folds } is given"
260
+ ), f"n_folds must be a positive integer, but { n_folds } is given"
261
+ if self .len_list == 1 :
262
+ position = np .zeros_like (action )
263
+ else :
264
+ assert (
265
+ position is not None
266
+ ), "position has to be set when len_list is larger than 1"
267
+ if self .fitting_method in ["iw" , "mrdr" ]:
268
+ assert (
269
+ action_dist is not None
270
+ ), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then action_dist must be given"
271
+ assert (
272
+ pscore is not None
273
+ ), "When either 'iw' or 'mrdr' is used as the 'fitting_method' argument, then pscore must be given"
258
274
259
275
if n_folds == 1 :
260
276
self .fit (
@@ -270,8 +286,8 @@ def fit_predict(
270
286
estimated_rewards_by_reg_model = np .zeros (
271
287
(context .shape [0 ], self .n_actions , self .len_list )
272
288
)
273
- skf = StratifiedKFold (n_splits = n_folds )
274
- skf .get_n_splits (context , reward )
289
+ skf = KFold (n_splits = n_folds , shuffle = True )
290
+ skf .get_n_splits (context )
275
291
for train_idx , test_idx in skf .split (context , reward ):
276
292
action_dist_tr = (
277
293
action_dist [train_idx ] if action_dist is not None else action_dist
0 commit comments