Skip to content

Qualx unit tests #1599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion user_tools/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ dependencies = [
# used for spinner animation
"progress==1.6",
# used for model estimations python [3.9-3.11]
"xgboost==2.1.3",
"xgboost==2.1.4",
# used for model interpretability. python [3.9, 3.12]
"shap==0.46.0",
# dependency of shap, python [3.9, 3.12]
Expand Down
2 changes: 1 addition & 1 deletion user_tools/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -60,6 +60,13 @@ def after_scenario(context, scenario) -> None:
if hasattr(context, 'after_scenario_fn'):
context.after_scenario_fn()

# Restore original QUALX_LABEL if it existed
if hasattr(context, 'original_qualx_label'):
if context.original_qualx_label is not None:
os.environ['QUALX_LABEL'] = context.original_qualx_label
else:
os.environ.pop('QUALX_LABEL', None)


def _set_verbose_mode(context) -> None:
verbose_enabled = getattr(context.config, 'verbose', False)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Feature: Testing preprocessing functionality
    As a user of the preprocessing module
    I want to ensure the preprocessing functions work correctly
    So that I can reliably process Spark event logs

    # All env-var Given steps are defined in the preprocessing steps module;
    # QUALX_DATA_DIR / QUALX_CACHE_DIR point at bundled test resources.
    Background:
        Given SPARK_RAPIDS_TOOLS_JAR environment variable is set
        And SPARK_HOME environment variable is set
        And QUALX_DATA_DIR environment variable is set
        And QUALX_CACHE_DIR environment variable is set
        And sample event logs in the QUALX_DATA_DIR
        And dataset JSON files in the datasets directory

    # Runs once per row of the Examples table: each run preprocesses the same
    # sample event logs under a different QUALX_LABEL setting.
    Scenario Outline: Test preprocessing with different QUALX_LABEL settings
        Given QUALX_LABEL environment variable is set to "<label>"
        When preprocessing the event logs
        Then preprocessing should complete successfully
        And preprocessed data should contain the expected features for label "<label>"

    Examples:
        | label        |
        | Duration     |
        | duration_sum |
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@
This module defines utility functions used by the end-to-end tests using behave.
"""

import glob
import logging
import os
import subprocess
Expand Down Expand Up @@ -111,6 +112,14 @@ def get_local_event_logs_dir(cls) -> str:
def get_spark_rapids_cli() -> str:
    """Return the path to the ``spark_rapids`` CLI executable inside the e2e test venv."""
    venv_dir = os.environ['E2E_TEST_VENV_DIR']
    return os.path.join(venv_dir, 'bin', 'spark_rapids')

@staticmethod
def get_spark_home() -> str:
    """Return the pyspark installation directory bundled in the e2e test venv.

    Searches ``<venv>/lib/*/site-packages/pyspark`` — the ``*`` matches the
    python version directory (e.g. ``python3.10``).

    :return: absolute path to the pyspark package directory.
    :raises RuntimeError: if no pyspark installation is found in the venv.
    """
    venv_path = os.environ['E2E_TEST_VENV_DIR']
    # Sort so the choice is deterministic if multiple python version dirs exist.
    matches = sorted(glob.glob(os.path.join(venv_path, 'lib', '*', 'site-packages', 'pyspark')))
    if matches:
        return matches[0]
    # Include the searched venv path so a misconfigured env var is obvious.
    raise RuntimeError(f'Spark home (pyspark) not found under venv: {venv_path}')

@staticmethod
def get_tools_jar_file() -> str:
    """Return the tools jar path, taken from the E2E_TEST_TOOLS_JAR_PATH env var."""
    jar_path = os.environ['E2E_TEST_TOOLS_JAR_PATH']
    return jar_path
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
from behave import given, when, then
from spark_rapids_tools.tools.qualx.preprocess import (
load_datasets,
expected_raw_features
)
from e2e_utils import E2ETestUtils

# Get logger from E2ETestUtils
logger = E2ETestUtils.get_logger()


@given('SPARK_HOME environment variable is set')
def set_spark_home_env(context):
    """Point SPARK_HOME at the pyspark installation bundled in the test venv."""
    os.environ['SPARK_HOME'] = E2ETestUtils.get_spark_home()
    assert 'SPARK_HOME' in os.environ


@given('SPARK_RAPIDS_TOOLS_JAR environment variable is set')
def set_tools_jar_env(context):
    """Point SPARK_RAPIDS_TOOLS_JAR at the tools jar and verify the file exists."""
    os.environ['SPARK_RAPIDS_TOOLS_JAR'] = E2ETestUtils.get_tools_jar_file()
    assert 'SPARK_RAPIDS_TOOLS_JAR' in os.environ
    assert os.path.exists(os.environ['SPARK_RAPIDS_TOOLS_JAR'])


@given('QUALX_DATA_DIR environment variable is set')
def set_qualx_data_dir_env(context):
    """Point QUALX_DATA_DIR at the local event-logs directory from test resources."""
    data_dir = E2ETestUtils.get_local_event_logs_dir()
    context.qualx_data_dir = data_dir
    os.environ['QUALX_DATA_DIR'] = data_dir
    assert 'QUALX_DATA_DIR' in os.environ


@given('QUALX_CACHE_DIR environment variable is set')
def set_qualx_cache_dir_env(context):
    """Point QUALX_CACHE_DIR at a qualx_cache directory under the test resources."""
    cache_dir = os.path.join(E2ETestUtils.get_e2e_tests_resource_path(), 'qualx_cache')
    context.qualx_cache_dir = cache_dir
    os.environ['QUALX_CACHE_DIR'] = cache_dir
    assert 'QUALX_CACHE_DIR' in os.environ


@given('QUALX_LABEL environment variable is set to "{label}"')
def step_impl(context, label):
    """Set the QUALX_LABEL environment variable to the specified value.

    Also records the pre-existing value (or None if unset) on the context so
    the after_scenario hook in environment.py can restore it; without this
    the hook's ``original_qualx_label`` attribute is never set and the label
    leaks into subsequent scenarios.
    """
    # Only capture once per scenario, before the first overwrite.
    if not hasattr(context, 'original_qualx_label'):
        context.original_qualx_label = os.environ.get('QUALX_LABEL')
    context.qualx_label = label
    os.environ['QUALX_LABEL'] = label


@given('sample event logs in the QUALX_DATA_DIR')
def check_sample_event_logs(context):
    """Verify sample event logs exist in test resources."""
    data_dir = context.qualx_data_dir
    assert os.path.exists(data_dir)
    found_logs = glob.glob(os.path.join(data_dir, '**', '*.zstd'), recursive=True)
    assert len(found_logs) > 0, "No event logs found in the QUALX_DATA_DIR"


@given('dataset JSON files in the datasets directory')
def check_dataset_json(context):
    """Verify dataset JSON file exists in test resources."""
    datasets_dir = os.path.join(E2ETestUtils.get_e2e_tests_resource_path(), 'datasets')
    context.dataset_path = datasets_dir
    json_files = glob.glob(os.path.join(datasets_dir, '**', '*.json'), recursive=True)
    assert len(json_files) > 0, "No dataset JSON files found in the datasets directory"


@when('preprocessing the event logs')
def load_and_preprocess_logs(context):
    """Run qualx preprocessing over the configured datasets directory.

    Records success/failure on the context instead of raising, so the
    following Then steps can assert on the outcome.
    """
    try:
        datasets, profile_df = load_datasets(context.dataset_path)
    except Exception as err:  # deliberate broad catch: failure is asserted later
        context.preprocessing_success = False
        context.preprocessing_error = str(err)
    else:
        context.datasets = datasets
        context.profile_df = profile_df
        context.preprocessing_success = True


@then('preprocessing should complete successfully')
def verify_preprocessing_success(context):
    """Verify that preprocessing completed without errors and produced data."""
    assert context.preprocessing_success, \
        f"Preprocessing failed with error: {getattr(context, 'preprocessing_error', 'Unknown error')}"
    assert context.datasets is not None, "Datasets dictionary should not be None"
    assert not context.profile_df.empty, "Profile DataFrame should not be empty"
    # 194 is the known row count for the bundled sample event logs; report the
    # actual count in the failure message so fixture drift is easy to diagnose.
    expected_rows = 194
    actual_rows = len(context.profile_df)
    assert actual_rows == expected_rows, \
        f"Profile DataFrame should have {expected_rows} rows, found {actual_rows}"


@then('preprocessed data should contain the expected features for label "{label}"')
def verify_expected_features(context, label):
    """Verify that the preprocessed data contains all expected features for the given label."""
    actual_features = set(context.profile_df.columns)
    missing_features = expected_raw_features.difference(actual_features)
    extra_features = actual_features.difference(expected_raw_features)

    assert label in actual_features, f"Label {label} should be in the expected features"
    assert not missing_features, f"Missing expected features: {missing_features}"
    assert not extra_features, f"Found unexpected features: {extra_features}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"nds_local": {
"eventlogs": [
"${QUALX_DATA_DIR}/onprem/nds/power/eventlogs"
],
"app_meta": {
"app-20231122005806-0064": {"runType": "CPU", "scaleFactor": 10},
"app-20231122010741-0065": {"runType": "CPU", "scaleFactor": 30},
"app-20231031225845-0001": {"runType": "CPU", "scaleFactor": 100},
"app-20231122011610-0066": {"runType": "CPU", "scaleFactor": 300},
"app-20231122020415-0067": {"runType": "CPU", "scaleFactor": 1000},
"app-20231114200842-0001": {"runType": "GPU", "scaleFactor": 10},
"app-20231114202719-0004": {"runType": "GPU", "scaleFactor": 30},
"app-20231031225101-0000": {"runType": "GPU", "scaleFactor": 100},
"app-20231108165915-0039": {"runType": "GPU", "scaleFactor": 300},
"app-20231115192922-0016": {"runType": "GPU", "scaleFactor": 1000}
},
"split_function": "${QUALX_DIR}/plugins/nds.py"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Sample eventlogs

These eventlogs were produced by running the [NDS benchmarks](https://github.com/NVIDIA/spark-rapids-benchmarks) on both CPU and GPU versions of a Spark local cluster set up in an onprem environment. For simplicity, these are just copies of eventlogs used for training the qualx `onprem` model.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion user_tools/tests/spark_rapids_tools_ut/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
15 changes: 15 additions & 0 deletions user_tools/tests/spark_rapids_tools_ut/qualx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""init file of the unit-tests package"""
47 changes: 47 additions & 0 deletions user_tools/tests/spark_rapids_tools_ut/qualx/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test qualx_config module"""
import pytest # pylint: disable=import-error
from spark_rapids_tools.tools.qualx.config import (
get_cache_dir,
get_label,
)
from ..conftest import SparkRapidsToolsUT


class TestConfig(SparkRapidsToolsUT):
    """Unit tests for the qualx config helpers (get_cache_dir, get_label)."""

    def test_get_cache_dir(self, monkeypatch):
        """get_cache_dir honors QUALX_CACHE_DIR and falls back to 'qualx_cache'."""
        # Test with mock environment variable
        monkeypatch.setenv('QUALX_CACHE_DIR', 'test_cache')
        assert get_cache_dir() == 'test_cache'

        # Test without environment variable (should use default)
        monkeypatch.delenv('QUALX_CACHE_DIR')
        assert get_cache_dir() == 'qualx_cache'

    def test_get_label(self, monkeypatch):
        """get_label accepts supported labels, rejects others, defaults to 'Duration'."""
        # Test with duration_sum
        monkeypatch.setenv('QUALX_LABEL', 'duration_sum')
        assert get_label() == 'duration_sum'

        # Test with unsupported label. Set the env var OUTSIDE the raises block:
        # only get_label() is expected to raise; keeping setup outside ensures a
        # setup failure cannot be mistaken for the expected AssertionError.
        monkeypatch.setenv('QUALX_LABEL', 'duration')
        with pytest.raises(AssertionError):
            get_label()

        # Test without environment variable (should use default)
        monkeypatch.delenv('QUALX_LABEL')
        assert get_label() == 'Duration'
Loading