Skip to content

Commit 2d99f59

Browse files
committed
test: added test for generate_sbom function
Signed-off-by: Meet Soni <[email protected]>
1 parent 8e1b878 commit 2d99f59

File tree

2 files changed

+73
-6
lines changed

2 files changed

+73
-6
lines changed

cve_bin_tool/output_engine/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,7 @@ def __init__(
712712
self.sbom_format = sbom_format
713713
self.sbom_root = sbom_root
714714
self.offline = offline
715+
self.sbom_packages = {}
715716

716717
def output_cves(self, outfile, output_type="console"):
717718
"""Output a list of CVEs
@@ -913,7 +914,6 @@ def generate_sbom(
913914
):
914915
"""Create SBOM package and generate SBOM file."""
915916
# Create SBOM
916-
sbom_packages = {}
917917
sbom_relationships = []
918918
my_package = SBOMPackage()
919919
sbom_relationship = SBOMRelationship()
@@ -930,7 +930,7 @@ def generate_sbom(
930930
my_package.set_licenseconcluded(license)
931931
my_package.set_supplier("UNKNOWN", "NOASSERTION")
932932
# Store package data
933-
sbom_packages[(my_package.get_name(), my_package.get_value("version"))] = (
933+
self.sbom_packages[(my_package.get_name(), my_package.get_value("version"))] = (
934934
my_package.get_package()
935935
)
936936
sbom_relationship.initialise()
@@ -941,16 +941,16 @@ def generate_sbom(
941941
my_package.initialise()
942942
my_package.set_name(product_data.product)
943943
my_package.set_version(product_data.version)
944-
if product_data.vendor != "UNKNOWN":
944+
if product_data.vendor.casefold() != "UNKNOWN".casefold():
945945
my_package.set_supplier("Organization", product_data.vendor)
946946
my_package.set_licensedeclared(license)
947947
my_package.set_licenseconcluded(license)
948948
if not (
949949
(my_package.get_name(), my_package.get_value("version"))
950-
in sbom_packages
950+
in self.sbom_packages
951951
and product_data.vendor == "unknown"
952952
):
953-
sbom_packages[
953+
self.sbom_packages[
954954
(my_package.get_name(), my_package.get_value("version"))
955955
] = my_package.get_package()
956956
sbom_relationship.initialise()
@@ -961,7 +961,7 @@ def generate_sbom(
961961

962962
# Generate SBOM
963963
my_sbom = SBOM()
964-
my_sbom.add_packages(sbom_packages)
964+
my_sbom.add_packages(self.sbom_packages)
965965
my_sbom.add_relationships(sbom_relationships)
966966
my_generator = SBOMGenerator(
967967
sbom_type=sbom_type,

test/test_output_engine.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import unittest
1313
from datetime import datetime
1414
from pathlib import Path
15+
from unittest.mock import MagicMock, call, patch
1516

1617
from jsonschema import validate
1718
from rich.console import Console
@@ -1101,6 +1102,10 @@ class TestOutputEngine(unittest.TestCase):
11011102
]
11021103

11031104
def setUp(self) -> None:
1105+
self.all_product_data = [
1106+
ProductInfo(product="product1", version="1.0", vendor="VendorA"),
1107+
ProductInfo(product="product2", version="2.0", vendor="unknown"),
1108+
]
11041109
self.output_engine = OutputEngine(
11051110
all_cve_data=self.MOCK_OUTPUT,
11061111
scanned_dir="",
@@ -1111,6 +1116,68 @@ def setUp(self) -> None:
11111116
)
11121117
self.mock_file = tempfile.NamedTemporaryFile("w+", encoding="utf-8")
11131118

1119+
def test_generate_sbom(self):
1120+
with patch(
1121+
"cve_bin_tool.output_engine.SBOMPackage"
1122+
) as mock_sbom_package, patch("cve_bin_tool.output_engine.SBOMRelationship"):
1123+
mock_package_instance = MagicMock()
1124+
mock_sbom_package.return_value = mock_package_instance
1125+
1126+
self.output_engine.generate_sbom(
1127+
all_product_data=self.all_product_data,
1128+
filename="test.sbom",
1129+
sbom_type="spdx",
1130+
sbom_format="tag",
1131+
sbom_root="CVE-SCAN",
1132+
)
1133+
1134+
# Assertions
1135+
mock_package_instance.set_name.assert_any_call("CVEBINTOOL-CVE-SCAN")
1136+
1137+
# Check if set_name is called for each product
1138+
expected_calls = [
1139+
call(product.product) for product in self.all_product_data
1140+
]
1141+
mock_package_instance.set_name.assert_has_calls(
1142+
expected_calls, any_order=True
1143+
)
1144+
1145+
# Check if set_version is called for each product
1146+
expected_calls = [
1147+
call(product.version) for product in self.all_product_data
1148+
]
1149+
mock_package_instance.set_version.assert_has_calls(
1150+
expected_calls, any_order=True
1151+
)
1152+
1153+
# Check if set_supplier is called for VendorA
1154+
mock_package_instance.set_supplier.assert_any_call(
1155+
"Organization", "VendorA"
1156+
)
1157+
1158+
for call_args in mock_package_instance.set_supplier.call_args_list:
1159+
args, _ = call_args
1160+
self.assertNotEqual(args, ("Organization", "unknown"))
1161+
1162+
# Check if set_licensedeclared and set_licenseconcluded are called for each product
1163+
expected_calls = [call("NOASSERTION")] * len(self.all_product_data)
1164+
mock_package_instance.set_licensedeclared.assert_has_calls(
1165+
expected_calls, any_order=True
1166+
)
1167+
mock_package_instance.set_licenseconcluded.assert_has_calls(
1168+
expected_calls, any_order=True
1169+
)
1170+
1171+
# Ensure packages are added to sbom_packages correctly
1172+
expected_packages = {
1173+
mock_package_instance.get_package.return_value,
1174+
mock_package_instance.get_package.return_value,
1175+
}
1176+
actual_packages = [
1177+
package for package in self.output_engine.sbom_packages.values()
1178+
]
1179+
self.assertEqual(actual_packages, list(expected_packages))
1180+
11141181
def tearDown(self) -> None:
11151182
self.mock_file.close()
11161183

0 commit comments

Comments
 (0)