Skip to content

Commit 71270ec

Browse files
Add WikipediaSearchTool to default tools (#514)
Co-authored-by: Aymeric Roucher <[email protected]>
1 parent 13a6293 commit 71270ec

File tree

3 files changed

+134
-1
lines changed

3 files changed

+134
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ test = [
8888
"python-dotenv>=1.0.1", # For test_all_docs
8989
"smolagents[all]",
9090
"rank-bm25", # For test_all_docs
91+
"Wikipedia-API>=0.8.1",
9192
]
9293
dev = [
9394
"smolagents[quality,test]",

src/smolagents/default_tools.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,102 @@ def forward(self, url: str) -> str:
261261
return f"An unexpected error occurred: {str(e)}"
262262

263263

264+
class WikipediaSearchTool(Tool):
265+
"""
266+
WikipediaSearchTool searches Wikipedia and returns a summary or full text of the given topic, along with the page URL.
267+
268+
Attributes:
269+
user_agent (str): A custom user-agent string to identify the project. This is required as per Wikipedia API policies, read more here: http://github.com/martin-majlis/Wikipedia-API/blob/master/README.rst
270+
language (str): The language in which to retrieve Wikipedia articles.
271+
http://meta.wikimedia.org/wiki/List_of_Wikipedias
272+
content_type (str): Defines the content to fetch. Can be "summary" for a short summary or "text" for the full article.
273+
extract_format (str): Defines the output format. Can be `"WIKI"` or `"HTML"`.
274+
275+
Example:
276+
>>> from smolagents import CodeAgent, HfApiModel, WikipediaSearchTool
277+
>>> agent = CodeAgent(
278+
>>> tools=[
279+
>>> WikipediaSearchTool(
280+
>>> user_agent="MyResearchBot ([email protected])",
281+
>>> language="en",
282+
>>> content_type="summary", # or "text"
283+
>>> extract_format="WIKI",
284+
>>> )
285+
>>> ],
286+
>>> model=HfApiModel(),
287+
>>> )
288+
>>> agent.run("Python_(programming_language)")
289+
"""
290+
291+
name = "wikipedia_search"
292+
description = "Searches Wikipedia and returns a summary or full text of the given topic, along with the page URL."
293+
inputs = {
294+
"query": {
295+
"type": "string",
296+
"description": "The topic to search on Wikipedia.",
297+
}
298+
}
299+
output_type = "string"
300+
301+
def __init__(
302+
self,
303+
user_agent: str = "Smolagents ([email protected])",
304+
language: str = "en",
305+
content_type: str = "text",
306+
extract_format: str = "WIKI",
307+
):
308+
super().__init__()
309+
try:
310+
import wikipediaapi
311+
except ImportError as e:
312+
raise ImportError(
313+
"You must install `wikipedia-api` to run this tool: for instance run `pip install wikipedia-api`"
314+
) from e
315+
if not user_agent:
316+
raise ValueError("User-agent is required. Provide a meaningful identifier for your project.")
317+
318+
self.user_agent = user_agent
319+
self.language = language
320+
self.content_type = content_type
321+
322+
# Map string format to wikipediaapi.ExtractFormat
323+
extract_format_map = {
324+
"WIKI": wikipediaapi.ExtractFormat.WIKI,
325+
"HTML": wikipediaapi.ExtractFormat.HTML,
326+
}
327+
328+
if extract_format not in extract_format_map:
329+
raise ValueError("Invalid extract_format. Choose between 'WIKI' or 'HTML'.")
330+
331+
self.extract_format = extract_format_map[extract_format]
332+
333+
self.wiki = wikipediaapi.Wikipedia(
334+
user_agent=self.user_agent, language=self.language, extract_format=self.extract_format
335+
)
336+
337+
def forward(self, query: str) -> str:
338+
try:
339+
page = self.wiki.page(query)
340+
341+
if not page.exists():
342+
return f"No Wikipedia page found for '{query}'. Try a different query."
343+
344+
title = page.title
345+
url = page.fullurl
346+
347+
if self.content_type == "summary":
348+
text = page.summary
349+
elif self.content_type == "text":
350+
text = page.text
351+
else:
352+
return "⚠️ Invalid `content_type`. Use either 'summary' or 'text'."
353+
354+
return f"✅ **Wikipedia Page:** {title}\n\n**Content:** {text}\n\n🔗 **Read more:** {url}"
355+
356+
except Exception as e:
357+
return f"Error fetching Wikipedia summary: {str(e)}"
358+
359+
264360
class SpeechToTextTool(PipelineTool):
265361
default_checkpoint = "openai/whisper-large-v3-turbo"
266362
description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
@@ -312,5 +408,6 @@ def decode(self, outputs):
312408
"DuckDuckGoSearchTool",
313409
"GoogleSearchTool",
314410
"VisitWebpageTool",
411+
"WikipediaSearchTool",
315412
"SpeechToTextTool",
316413
]

tests/test_default_tools.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
import pytest
1818

1919
from smolagents.agent_types import _AGENT_TYPE_MAPPING
20-
from smolagents.default_tools import DuckDuckGoSearchTool, PythonInterpreterTool, SpeechToTextTool, VisitWebpageTool
20+
from smolagents.default_tools import (
21+
DuckDuckGoSearchTool,
22+
PythonInterpreterTool,
23+
SpeechToTextTool,
24+
VisitWebpageTool,
25+
WikipediaSearchTool,
26+
)
2127

2228
from .test_tools import ToolTesterMixin
2329

@@ -87,3 +93,32 @@ def test_new_instance(self):
8793
assert tool is not None
8894
assert tool.pre_processor_class == WhisperProcessor
8995
assert tool.model_class == WhisperForConditionalGeneration
96+
97+
98+
@pytest.mark.parametrize(
99+
"language, content_type, extract_format, query",
100+
[
101+
("en", "summary", "HTML", "Python_(programming_language)"), # English, Summary Mode, HTML format
102+
("en", "text", "WIKI", "Python_(programming_language)"), # English, Full Text Mode, WIKI format
103+
("es", "summary", "HTML", "Python_(lenguaje_de_programación)"), # Spanish, Summary Mode, HTML format
104+
("es", "text", "WIKI", "Python_(lenguaje_de_programación)"), # Spanish, Full Text Mode, WIKI format
105+
],
106+
)
107+
def test_wikipedia_search(language, content_type, extract_format, query):
108+
tool = WikipediaSearchTool(
109+
user_agent="TestAgent ([email protected])",
110+
language=language,
111+
content_type=content_type,
112+
extract_format=extract_format,
113+
)
114+
115+
result = tool.forward(query)
116+
117+
assert isinstance(result, str), "Output should be a string"
118+
assert "✅ **Wikipedia Page:**" in result, "Response should contain Wikipedia page title"
119+
assert "🔗 **Read more:**" in result, "Response should contain Wikipedia page URL"
120+
121+
if content_type == "summary":
122+
assert len(result.split()) < 1000, "Summary mode should return a shorter text"
123+
if content_type == "text":
124+
assert len(result.split()) > 1000, "Full text mode should return a longer text"

0 commit comments

Comments
 (0)