Skip to content

Commit 65e65f3

Browse files
Fix/pd variance plot test coverage (#820)
* Fixed minor bug for conditional plots. * Included importance explanation fixture. * Included interaction explanation fixture. * Included test for hbar ncols. Solved flake8 issues. * Included value error ax for hbar plot. * Included interaction plot tests. * Included plot_pd_variance tests. * Updated the explanation objects with random values. Fixed test that failed before because of simple explanation values. * More random values. * Removed lazy fixture.
1 parent 536b7be commit 65e65f3

File tree

2 files changed

+378
-21
lines changed

2 files changed

+378
-21
lines changed

alibi/explainers/pd_variance.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
import numpy as np
1212
from sklearn.base import BaseEstimator
1313

14-
from alibi.api.defaults import (DEFAULT_DATA_PD, DEFAULT_DATA_PDVARIANCE,
15-
DEFAULT_META_PD, DEFAULT_META_PDVARIANCE)
14+
from alibi.api.defaults import (DEFAULT_DATA_PD, DEFAULT_DATA_PDVARIANCE, DEFAULT_META_PDVARIANCE)
1615
from alibi.api.interfaces import Explainer, Explanation
1716
from alibi.explainers import plot_pd
1817
from alibi.explainers.partial_dependence import (Kind, PartialDependence,
@@ -559,10 +558,9 @@ def _plot_feature_importance(exp: Explanation,
559558
feature_importance = feature_importance[sorted_indices][:top_k]
560559

561560
# construct pd explanation object to reuse `plot_pd` function
562-
meta = copy.deepcopy(DEFAULT_META_PD)
563-
data = copy.deepcopy(DEFAULT_DATA_PD)
564-
meta.update(exp.meta)
561+
meta = copy.deepcopy(exp.meta)
565562
meta['params']['kind'] = 'average'
563+
data = copy.deepcopy(DEFAULT_DATA_PD)
566564
data.update(feature_names=feature_names,
567565
feature_values=feature_values,
568566
pd_values=pd_values,
@@ -673,10 +671,9 @@ def _plot_feature_interaction(exp: Explanation,
673671
merged_features = list(itertools.chain.from_iterable(merged_features)) # type: ignore[arg-type]
674672

675673
# construct pd explanation object to reuse `plot_pd` function
676-
meta = copy.deepcopy(DEFAULT_META_PD)
677-
data = copy.deepcopy(DEFAULT_DATA_PD)
678-
meta.update(exp.meta)
674+
meta = copy.deepcopy(exp.meta)
679675
meta['params']['kind'] = 'average'
676+
data = copy.deepcopy(DEFAULT_DATA_PD)
680677
data.update(feature_names=merged_feature_names,
681678
feature_values=merged_feature_values,
682679
pd_values=merged_pd_values,
@@ -704,11 +701,11 @@ def _plot_feature_interaction(exp: Explanation,
704701

705702
# set title for the first conditional importance plot
706703
ax = axes.flatten()[step * i + 1]
707-
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0].item()))
704+
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0][target_idx]))
708705

709706
# set title for the second conditional importance plot
710707
ax = axes.flatten()[step * i + 2]
711-
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1].item()))
708+
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1][target_idx]))
712709

713710
return axes
714711

0 commit comments

Comments
 (0)