Skip to content

Commit c61786e

Browse files
authored
Merge pull request #920 from ScrapeGraphAI/pre/beta
Pre/beta
2 parents 4b4efe0 + 948164f commit c61786e

File tree

3 files changed

+117
-2
lines changed

3 files changed

+117
-2
lines changed

tests/test_depth_search_graph.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from unittest.mock import patch, MagicMock
2+
from scrapegraphai.graphs.depth_search_graph import DepthSearchGraph
3+
from scrapegraphai.graphs.abstract_graph import AbstractGraph
4+
import pytest
5+
6+
7+
class TestDepthSearchGraph:
8+
"""Test suite for DepthSearchGraph class"""
9+
10+
@pytest.mark.parametrize(
11+
"source, expected_input_key",
12+
[
13+
("https://example.com", "url"),
14+
("/path/to/local/directory", "local_dir"),
15+
],
16+
)
17+
def test_depth_search_graph_initialization(self, source, expected_input_key):
18+
"""
19+
Test that DepthSearchGraph initializes correctly with different source types.
20+
This test verifies that the input_key is set to 'url' for web sources and
21+
'local_dir' for local directory sources.
22+
"""
23+
prompt = "Test prompt"
24+
config = {"llm": {"model": "mock_model"}}
25+
26+
# Mock both BaseGraph and _create_llm method
27+
with patch("scrapegraphai.graphs.depth_search_graph.BaseGraph"), \
28+
patch.object(AbstractGraph, '_create_llm', return_value=MagicMock()):
29+
graph = DepthSearchGraph(prompt, source, config)
30+
31+
assert graph.prompt == prompt
32+
assert graph.source == source
33+
assert graph.config == config
34+
assert graph.input_key == expected_input_key

tests/test_json_scraper_graph.py

+57-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from pydantic import BaseModel
3+
from pydantic import BaseModel, Field
44
from scrapegraphai.graphs.json_scraper_graph import JSONScraperGraph
55
from unittest.mock import Mock, patch
66

@@ -133,4 +133,60 @@ def test_json_scraper_graph_no_answer_found(self, mock_create_llm, mock_generate
133133
mock_execute.assert_called_once_with({"user_prompt": "Query that produces no answer", "json": "path/to/empty/file.json"})
134134
mock_fetch_node.assert_called_once()
135135
mock_generate_answer_node.assert_called_once()
136+
mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})
137+
138+
@pytest.fixture
139+
def mock_llm_model(self):
140+
return Mock()
141+
142+
@pytest.fixture
143+
def mock_embedder_model(self):
144+
return Mock()
145+
146+
@patch('scrapegraphai.graphs.json_scraper_graph.FetchNode')
147+
@patch('scrapegraphai.graphs.json_scraper_graph.GenerateAnswerNode')
148+
@patch.object(JSONScraperGraph, '_create_llm')
149+
def test_json_scraper_graph_with_custom_schema(self, mock_create_llm, mock_generate_answer_node, mock_fetch_node, mock_llm_model, mock_embedder_model):
150+
"""
151+
Test JSONScraperGraph with a custom schema.
152+
This test checks if the graph correctly handles a custom schema input
153+
and passes it to the GenerateAnswerNode.
154+
"""
155+
# Define a custom schema
156+
class CustomSchema(BaseModel):
157+
name: str = Field(..., description="Name of the attraction")
158+
description: str = Field(..., description="Description of the attraction")
159+
160+
# Mock the _create_llm method to return a mock LLM model
161+
mock_create_llm.return_value = mock_llm_model
162+
163+
# Mock the execute method of BaseGraph
164+
with patch('scrapegraphai.graphs.json_scraper_graph.BaseGraph.execute') as mock_execute:
165+
mock_execute.return_value = ({"answer": "Mocked answer with custom schema"}, {})
166+
167+
# Create a JSONScraperGraph instance with a custom schema
168+
graph = JSONScraperGraph(
169+
prompt="List attractions in Chioggia",
170+
source="path/to/chioggia.json",
171+
config={"llm": {"model": "test-model", "temperature": 0}},
172+
schema=CustomSchema
173+
)
174+
175+
# Set mocked embedder model
176+
graph.embedder_model = mock_embedder_model
177+
178+
# Run the graph
179+
result = graph.run()
180+
181+
# Assertions
182+
assert result == "Mocked answer with custom schema"
183+
assert graph.input_key == "json"
184+
mock_execute.assert_called_once_with({"user_prompt": "List attractions in Chioggia", "json": "path/to/chioggia.json"})
185+
mock_fetch_node.assert_called_once()
186+
mock_generate_answer_node.assert_called_once()
187+
188+
# Check if the custom schema was passed to GenerateAnswerNode
189+
generate_answer_node_call = mock_generate_answer_node.call_args[1]
190+
assert generate_answer_node_call['node_config']['schema'] == CustomSchema
191+
136192
mock_create_llm.assert_called_once_with({"model": "test-model", "temperature": 0})

tests/test_search_graph.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,29 @@ def test_max_results_config(self, mock_create_llm, mock_base_graph, mock_merge_a
7979
# Assert
8080
mock_search_internet.assert_called_once()
8181
call_args = mock_search_internet.call_args
82-
assert call_args.kwargs['node_config']['max_results'] == max_results
82+
assert call_args.kwargs['node_config']['max_results'] == max_results
83+
84+
@patch('scrapegraphai.graphs.search_graph.SearchInternetNode')
85+
@patch('scrapegraphai.graphs.search_graph.GraphIteratorNode')
86+
@patch('scrapegraphai.graphs.search_graph.MergeAnswersNode')
87+
@patch('scrapegraphai.graphs.search_graph.BaseGraph')
88+
@patch('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_llm')
89+
def test_custom_search_engine_config(self, mock_create_llm, mock_base_graph, mock_merge_answers, mock_graph_iterator, mock_search_internet):
90+
"""
91+
Test that the custom search_engine parameter from the config is correctly passed to the SearchInternetNode.
92+
"""
93+
# Arrange
94+
prompt = "Test prompt"
95+
custom_search_engine = "custom_engine"
96+
config = {
97+
"llm": {"model": "test-model"},
98+
"search_engine": custom_search_engine
99+
}
100+
101+
# Act
102+
search_graph = SearchGraph(prompt, config)
103+
104+
# Assert
105+
mock_search_internet.assert_called_once()
106+
call_args = mock_search_internet.call_args
107+
assert call_args.kwargs['node_config']['search_engine'] == custom_search_engine

0 commit comments

Comments
 (0)