|
11 | 11 | import numpy as np
|
12 | 12 | from sklearn.base import BaseEstimator
|
13 | 13 |
|
14 |
| -from alibi.api.defaults import (DEFAULT_DATA_PD, DEFAULT_DATA_PDVARIANCE, |
15 |
| - DEFAULT_META_PD, DEFAULT_META_PDVARIANCE) |
| 14 | +from alibi.api.defaults import (DEFAULT_DATA_PD, DEFAULT_DATA_PDVARIANCE, DEFAULT_META_PDVARIANCE) |
16 | 15 | from alibi.api.interfaces import Explainer, Explanation
|
17 | 16 | from alibi.explainers import plot_pd
|
18 | 17 | from alibi.explainers.partial_dependence import (Kind, PartialDependence,
|
@@ -559,10 +558,9 @@ def _plot_feature_importance(exp: Explanation,
|
559 | 558 | feature_importance = feature_importance[sorted_indices][:top_k]
|
560 | 559 |
|
561 | 560 | # construct pd explanation object to reuse `plot_pd` function
|
562 |
| - meta = copy.deepcopy(DEFAULT_META_PD) |
563 |
| - data = copy.deepcopy(DEFAULT_DATA_PD) |
564 |
| - meta.update(exp.meta) |
| 561 | + meta = copy.deepcopy(exp.meta) |
565 | 562 | meta['params']['kind'] = 'average'
|
| 563 | + data = copy.deepcopy(DEFAULT_DATA_PD) |
566 | 564 | data.update(feature_names=feature_names,
|
567 | 565 | feature_values=feature_values,
|
568 | 566 | pd_values=pd_values,
|
@@ -673,10 +671,9 @@ def _plot_feature_interaction(exp: Explanation,
|
673 | 671 | merged_features = list(itertools.chain.from_iterable(merged_features)) # type: ignore[arg-type]
|
674 | 672 |
|
675 | 673 | # construct pd explanation object to reuse `plot_pd` function
|
676 |
| - meta = copy.deepcopy(DEFAULT_META_PD) |
677 |
| - data = copy.deepcopy(DEFAULT_DATA_PD) |
678 |
| - meta.update(exp.meta) |
| 674 | + meta = copy.deepcopy(exp.meta) |
679 | 675 | meta['params']['kind'] = 'average'
|
| 676 | + data = copy.deepcopy(DEFAULT_DATA_PD) |
680 | 677 | data.update(feature_names=merged_feature_names,
|
681 | 678 | feature_values=merged_feature_values,
|
682 | 679 | pd_values=merged_pd_values,
|
@@ -704,11 +701,11 @@ def _plot_feature_interaction(exp: Explanation,
|
704 | 701 |
|
705 | 702 | # set title for the first conditional importance plot
|
706 | 703 | ax = axes.flatten()[step * i + 1]
|
707 |
| - ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0].item())) |
| 704 | + ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0][target_idx])) |
708 | 705 |
|
709 | 706 | # set title for the second conditional importance plot
|
710 | 707 | ax = axes.flatten()[step * i + 2]
|
711 |
| - ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1].item())) |
| 708 | + ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1][target_idx])) |
712 | 709 |
|
713 | 710 | return axes
|
714 | 711 |
|
|
0 commit comments