Commit 9d62615

Merge pull request #172 from st-tech/feature/mips-slope-with-true-iw
Allowing slope to use the true marginal importance weight for mips

2 parents 122743e + 0e94113
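SLOPE here is the data-driven procedure the MIPS estimator uses to select which action-embedding dimensions to keep; this merge lets it consume the true conditional embedding distribution p(e|a) (the new p_e_a argument) instead of always relying on estimated marginal importance weights. Below is a minimal usage sketch under stated assumptions: that obp.ope exports MarginalizedInverseProbabilityWeighting with an embedding_selection_method option, and that estimate_policy_value accepts the p_e_a keyword this commit threads through. All array shapes and random inputs are illustrative placeholders, not taken from the diff.

import numpy as np
from obp.ope import MarginalizedInverseProbabilityWeighting  # assumed export

rng = np.random.default_rng(0)
n_rounds, n_actions, n_cat_per_dim, n_cat_dim = 1000, 10, 5, 4

# Illustrative logged bandit feedback (shapes assumed, not specified by this diff).
context = rng.normal(size=(n_rounds, 3))
action = rng.integers(n_actions, size=n_rounds)
reward = rng.binomial(1, 0.5, size=n_rounds)
action_embed = rng.integers(n_cat_per_dim, size=(n_rounds, n_cat_dim))
pi_b = np.full((n_rounds, n_actions, 1), 1 / n_actions)         # behavior policy
action_dist = np.full((n_rounds, n_actions, 1), 1 / n_actions)  # evaluation policy

# True conditional embedding distribution p(e|a): for each action and each
# embedding dimension, a distribution over that dimension's categories.
p_e_a = rng.dirichlet(np.ones(n_cat_per_dim), size=(n_actions, n_cat_dim))
p_e_a = np.transpose(p_e_a, (0, 2, 1))  # -> (n_actions, n_cat_per_dim, n_cat_dim)

mips = MarginalizedInverseProbabilityWeighting(
    n_actions=n_actions, embedding_selection_method="exact"  # or "greedy"
)
value = mips.estimate_policy_value(
    context=context, reward=reward, action=action, action_embed=action_embed,
    pi_b=pi_b, action_dist=action_dist,
    p_e_a=p_e_a,  # new in this PR: true marginal IW used during SLOPE pruning
)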

File tree: 1 file changed, +73 −30 lines


obp/ope/estimators_embed.py

Lines changed: 73 additions & 30 deletions
@@ -303,6 +303,7 @@ def estimate_policy_value(
                 position=position,
                 pi_b=pi_b,
                 action_dist=action_dist,
+                p_e_a=p_e_a,
             )
         elif self.embedding_selection_method == "greedy":
             return self._estimate_with_greedy_pruning(
@@ -313,6 +314,7 @@ def estimate_policy_value(
                 position=position,
                 pi_b=pi_b,
                 action_dist=action_dist,
+                p_e_a=p_e_a,
             )
         else:
             return self._estimate_round_rewards(
@@ -335,6 +337,7 @@ def _estimate_with_exact_pruning(
         pi_b: np.ndarray,
         action_dist: np.ndarray,
         position: np.ndarray,
+        p_e_a: Optional[np.ndarray] = None,
     ) -> float:
         """Apply an exact version of data-driven action embedding selection."""
         n_emb_dim = action_embed.shape[1]
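The new optional p_e_a argument carries the conditional embedding distribution p(e|a). The p_e_a[:, :, comb] slices below imply that its last axis indexes embedding dimensions, so pruning a dimension amounts to dropping a slice. A small numpy illustration under that assumption; the shape (n_actions, n_cat_per_dim, n_cat_dim) is inferred from the slicing, not stated in the diff.

import numpy as np

rng = np.random.default_rng(0)
n_actions, n_cat_per_dim, n_cat_dim = 10, 5, 4

# p_e_a[a, :, d] is assumed to be a distribution over the categories of
# embedding dimension d for action a.
p_e_a = rng.dirichlet(np.ones(n_cat_per_dim), size=(n_actions, n_cat_dim))
p_e_a = np.transpose(p_e_a, (0, 2, 1))      # (n_actions, n_cat_per_dim, n_cat_dim)
assert np.allclose(p_e_a.sum(axis=1), 1.0)  # each slice is a valid distribution

comb = (0, 2)                     # keep embedding dimensions 0 and 2
p_e_a_pruned = p_e_a[:, :, comb]  # same slicing as in the diff below
assert p_e_a_pruned.shape == (n_actions, n_cat_per_dim, len(comb))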
@@ -344,16 +347,29 @@ def _estimate_with_exact_pruning(
             comb_list = list(itertools.combinations(feat_list, i))
             theta_list_, cnf_list_ = [], []
             for comb in comb_list:
-                theta, cnf = self._estimate_round_rewards(
-                    context=context,
-                    reward=reward,
-                    action=action,
-                    action_embed=action_embed[:, comb],
-                    pi_b=pi_b,
-                    action_dist=action_dist,
-                    position=position,
-                    with_dev=True,
-                )
+                if p_e_a is None:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, comb],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        with_dev=True,
+                    )
+                else:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, comb],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        p_e_a=p_e_a[:, :, comb],
+                        with_dev=True,
+                    )
                 if len(theta_list) > 0:
                     theta_list_.append(theta), cnf_list_.append(cnf)
                 else:
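Exact pruning scores every subset of embedding dimensions, largest subsets first, calling _estimate_round_rewards once per subset. A standalone sketch of just that enumeration; the scoring call is stubbed with a print, standing in for the estimator's actual computation.

import itertools
import numpy as np

n_emb_dim = 4
feat_list = np.arange(n_emb_dim)

# Mirror of the loop structure in _estimate_with_exact_pruning: all subsets
# of each size, from the full set of dimensions down to a single dimension.
for i in np.arange(n_emb_dim, 0, -1):
    for comb in itertools.combinations(feat_list, i):
        # Here the estimator scores action_embed[:, comb] and, when the
        # true distribution is available, p_e_a[:, :, comb].
        print(comb)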
@@ -380,23 +396,37 @@ def _estimate_with_greedy_pruning(
         pi_b: np.ndarray,
         action_dist: np.ndarray,
         position: np.ndarray,
+        p_e_a: Optional[np.ndarray] = None,
     ) -> float:
         """Apply a greedy version of data-driven action embedding selection."""
         n_emb_dim = action_embed.shape[1]
         theta_list, cnf_list = [], []
         current_feat, C = np.arange(n_emb_dim), np.sqrt(6) - 1

         # init
-        theta, cnf = self._estimate_round_rewards(
-            context=context,
-            reward=reward,
-            action=action,
-            action_embed=action_embed[:, current_feat],
-            pi_b=pi_b,
-            action_dist=action_dist,
-            position=position,
-            with_dev=True,
-        )
+        if p_e_a is None:
+            theta, cnf = self._estimate_round_rewards(
+                context=context,
+                reward=reward,
+                action=action,
+                action_embed=action_embed[:, current_feat],
+                pi_b=pi_b,
+                action_dist=action_dist,
+                position=position,
+                with_dev=True,
+            )
+        else:
+            theta, cnf = self._estimate_round_rewards(
+                context=context,
+                reward=reward,
+                action=action,
+                action_embed=action_embed[:, current_feat],
+                pi_b=pi_b,
+                action_dist=action_dist,
+                position=position,
+                p_e_a=p_e_a[:, :, current_feat],
+                with_dev=True,
+            )
         theta_list.append(theta), cnf_list.append(cnf)

         # iterate
@@ -405,16 +435,29 @@ def _estimate_with_greedy_pruning(
             for d in current_feat:
                 idx_without_d = np.where(current_feat != d, True, False)
                 candidate_feat = current_feat[idx_without_d]
-                theta, cnf = self._estimate_round_rewards(
-                    context=context,
-                    reward=reward,
-                    action=action,
-                    action_embed=action_embed[:, candidate_feat],
-                    pi_b=pi_b,
-                    action_dist=action_dist,
-                    position=position,
-                    with_dev=True,
-                )
+                if p_e_a is None:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, candidate_feat],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        with_dev=True,
+                    )
+                else:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, candidate_feat],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        p_e_a=p_e_a[:, :, candidate_feat],
+                        with_dev=True,
+                    )
                 d_list_.append(d)
                 theta_list_.append(theta), cnf_list_.append(cnf)

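The greedy variant starts from the full set of dimensions and repeatedly evaluates dropping each one in turn. A standalone sketch of that candidate generation; the selection rule at the end is a placeholder, since in obp the kept candidate is chosen via the SLOPE confidence-bound comparison, which this sketch does not reproduce.

import numpy as np

n_emb_dim = 4
current_feat = np.arange(n_emb_dim)

while current_feat.shape[0] > 1:
    candidates = []
    for d in current_feat:
        # Same masking as in the diff: drop dimension d from the current set.
        idx_without_d = np.where(current_feat != d, True, False)
        candidate_feat = current_feat[idx_without_d]
        # The estimator would score action_embed[:, candidate_feat] and,
        # when the true distribution is given, p_e_a[:, :, candidate_feat].
        candidates.append(candidate_feat)
    # Placeholder rule: keep the first candidate; obp instead keeps the one
    # whose confidence bound justifies further pruning.
    current_feat = candidates[0]
    print(current_feat)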