Skip to content

Commit 507d262

Browse files
authored
FIX: Add fixture to set pyspark python driver (#3044)
* Add fixture to set pyspark python driver * Typo * Changelog
1 parent 8f9f7d1 commit 507d262

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6060
- Fixed failing unit tests
6161
([dsgibbons#29](https://github.com/dsgibbons/shap/pull/29) by @dsgibbons,
6262
[dsgibbons#20](https://github.com/dsgibbons/shap/pull/20) by @simonangerbauer,
63+
[#3044](https://github.com/slundberg/shap/pull/3044) by @connortann,
6364
[dsgibbons#24](https://github.com/dsgibbons/shap/pull/24) by @connortann).
6465
- Include CUDA GPU C extension files in the source distribution
6566
([#3009](https://github.com/slundberg/shap/pull/3009) by @jklaise).

tests/explainers/test_tree.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import math
44
import pickle
5+
import sys
56

67
import numpy as np
78
import pandas as pd
@@ -203,7 +204,13 @@ def test_ngboost():
203204
explainer.shap_values(X).sum(1) + explainer.expected_value - model.predict(X))) < 1e-5
204205

205206

206-
def test_pyspark_classifier_decision_tree():
207+
@pytest.fixture
208+
def configure_pyspark_python(monkeypatch):
209+
monkeypatch.setenv("PYSPARK_PYTHON", sys.executable)
210+
monkeypatch.setenv("PYSPARK_DRIVER_PYTHON", sys.executable)
211+
212+
213+
def test_pyspark_classifier_decision_tree(configure_pyspark_python):
207214
# pylint: disable=bare-except
208215
pyspark = pytest.importorskip("pyspark")
209216
pytest.importorskip("pyspark.ml")
@@ -258,7 +265,7 @@ def test_pyspark_classifier_decision_tree():
258265
spark.stop()
259266

260267

261-
def test_pyspark_regression_decision_tree():
268+
def test_pyspark_regression_decision_tree(configure_pyspark_python):
262269
# pylint: disable=bare-except
263270
pyspark = pytest.importorskip("pyspark")
264271
pytest.importorskip("pyspark.ml")

0 commit comments

Comments
 (0)