@@ -303,6 +303,7 @@ def estimate_policy_value(
                 position=position,
                 pi_b=pi_b,
                 action_dist=action_dist,
+                p_e_a=p_e_a,
             )
         elif self.embedding_selection_method == "greedy":
             return self._estimate_with_greedy_pruning(
@@ -313,6 +314,7 @@ def estimate_policy_value(
                 position=position,
                 pi_b=pi_b,
                 action_dist=action_dist,
+                p_e_a=p_e_a,
             )
         else:
             return self._estimate_round_rewards(
@@ -335,6 +337,7 @@ def _estimate_with_exact_pruning(
         pi_b: np.ndarray,
         action_dist: np.ndarray,
         position: np.ndarray,
+        p_e_a: Optional[np.ndarray] = None,
     ) -> float:
         """Apply an exact version of data-driven action embedding selection."""
         n_emb_dim = action_embed.shape[1]
@@ -344,16 +347,29 @@ def _estimate_with_exact_pruning(
             comb_list = list(itertools.combinations(feat_list, i))
             theta_list_, cnf_list_ = [], []
             for comb in comb_list:
-                theta, cnf = self._estimate_round_rewards(
-                    context=context,
-                    reward=reward,
-                    action=action,
-                    action_embed=action_embed[:, comb],
-                    pi_b=pi_b,
-                    action_dist=action_dist,
-                    position=position,
-                    with_dev=True,
-                )
+                if p_e_a is None:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, comb],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        with_dev=True,
+                    )
+                else:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, comb],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        p_e_a=p_e_a[:, :, comb],
+                        with_dev=True,
+                    )
                 if len(theta_list) > 0:
                     theta_list_.append(theta), cnf_list_.append(cnf)
                 else:
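
A note on the exact-pruning hunk above: the only new behavior is the branch on whether `p_e_a` was supplied, with the known distribution restricted to each candidate dimension subset via `p_e_a[:, :, comb]`, mirroring `action_embed[:, comb]`. The sketch below is a minimal standalone illustration of that slicing, assuming `p_e_a` has shape `(n_actions, n_cat_per_dim, n_emb_dim)` with `p_e_a[a, :, d]` being the categorical distribution of embedding dimension `d` under action `a` (the shape and variable names are assumptions for illustration; they are not shown in this diff).

import itertools

import numpy as np

# toy sizes; all names here are illustrative, not taken from the library
n_actions, n_cat_per_dim, n_emb_dim = 4, 3, 5
rng = np.random.default_rng(0)

# assumed layout: p_e_a[a, :, d] = p(e_d | a), normalized over the category axis
p_e_a = rng.random((n_actions, n_cat_per_dim, n_emb_dim))
p_e_a /= p_e_a.sum(axis=1, keepdims=True)

# the same subsetting used in the exact-pruning loop above
for comb in itertools.combinations(range(n_emb_dim), 2):
    sub = p_e_a[:, :, comb]  # keep only the selected embedding dimensions
    assert sub.shape == (n_actions, n_cat_per_dim, len(comb))
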
@@ -380,23 +396,37 @@ def _estimate_with_greedy_pruning(
         pi_b: np.ndarray,
         action_dist: np.ndarray,
         position: np.ndarray,
+        p_e_a: Optional[np.ndarray] = None,
     ) -> float:
         """Apply a greedy version of data-driven action embedding selection."""
         n_emb_dim = action_embed.shape[1]
         theta_list, cnf_list = [], []
         current_feat, C = np.arange(n_emb_dim), np.sqrt(6) - 1

         # init
-        theta, cnf = self._estimate_round_rewards(
-            context=context,
-            reward=reward,
-            action=action,
-            action_embed=action_embed[:, current_feat],
-            pi_b=pi_b,
-            action_dist=action_dist,
-            position=position,
-            with_dev=True,
-        )
+        if p_e_a is None:
+            theta, cnf = self._estimate_round_rewards(
+                context=context,
+                reward=reward,
+                action=action,
+                action_embed=action_embed[:, current_feat],
+                pi_b=pi_b,
+                action_dist=action_dist,
+                position=position,
+                with_dev=True,
+            )
+        else:
+            theta, cnf = self._estimate_round_rewards(
+                context=context,
+                reward=reward,
+                action=action,
+                action_embed=action_embed[:, current_feat],
+                pi_b=pi_b,
+                action_dist=action_dist,
+                position=position,
+                p_e_a=p_e_a[:, :, current_feat],
+                with_dev=True,
+            )
         theta_list.append(theta), cnf_list.append(cnf)

         # iterate
@@ -405,16 +435,29 @@ def _estimate_with_greedy_pruning(
             for d in current_feat:
                 idx_without_d = np.where(current_feat != d, True, False)
                 candidate_feat = current_feat[idx_without_d]
-                theta, cnf = self._estimate_round_rewards(
-                    context=context,
-                    reward=reward,
-                    action=action,
-                    action_embed=action_embed[:, candidate_feat],
-                    pi_b=pi_b,
-                    action_dist=action_dist,
-                    position=position,
-                    with_dev=True,
-                )
+                if p_e_a is None:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, candidate_feat],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        with_dev=True,
+                    )
+                else:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, candidate_feat],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        p_e_a=p_e_a[:, :, candidate_feat],
+                        with_dev=True,
+                    )
                 d_list_.append(d)
                 theta_list_.append(theta), cnf_list_.append(cnf)

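
And for the greedy path: the iteration drops one embedding dimension at a time, and when `p_e_a` is supplied the candidate slice becomes `p_e_a[:, :, candidate_feat]`. Below is a small standalone illustration of how the candidate subsets are generated, using the same `np.where` idiom as the context lines above; the surrounding estimator call is omitted and the variable names are only illustrative.

import numpy as np

current_feat = np.arange(5)  # start from all embedding dimensions, as in the greedy init
for d in current_feat:
    idx_without_d = np.where(current_feat != d, True, False)
    candidate_feat = current_feat[idx_without_d]  # all dimensions except d
    print(d, candidate_feat)
    # with a known distribution, the matching slice would be p_e_a[:, :, candidate_feat]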