Skip to content

Add asyncio support for LLM (OpenAI), Chain (LLMChain, LLMMathChain), and Agent #841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
396 changes: 396 additions & 0 deletions docs/async/async_primitives.ipynb

Large diffs are not rendered by default.

133 changes: 129 additions & 4 deletions langchain/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Chain that takes in an input and produces an action and action input."""
from __future__ import annotations

import asyncio
import json
import logging
from abc import abstractmethod
Expand Down Expand Up @@ -71,6 +72,19 @@ def _get_next_action(self, full_inputs: Dict[str, str]) -> AgentAction:
tool=parsed_output[0], tool_input=parsed_output[1], log=full_output
)

async def _aget_next_action(self, full_inputs: Dict[str, str]) -> AgentAction:
    """Asynchronously query the LLM chain until the output parses into an action."""
    full_output = await self.llm_chain.apredict(**full_inputs)
    parsed_output = self._extract_tool_and_input(full_output)
    # Keep appending fix-up text and re-querying until the combined
    # output can be parsed into a (tool, tool_input) pair.
    while parsed_output is None:
        full_output = self._fix_text(full_output)
        full_inputs["agent_scratchpad"] += full_output
        retry_output = await self.llm_chain.apredict(**full_inputs)
        full_output += retry_output
        parsed_output = self._extract_tool_and_input(full_output)
    tool_name, tool_input = parsed_output
    return AgentAction(tool=tool_name, tool_input=tool_input, log=full_output)

def plan(
    self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentAction, AgentFinish]:
    """Given input, decide what to do.

    Args:
        intermediate_steps: Steps the LLM has taken to date,
            along with observations.
        **kwargs: User inputs.

    Returns:
        Action specifying what tool to use, or an AgentFinish when the
        agent selected its finishing tool.
    """
    # Input assembly is delegated to get_full_inputs; the pre-refactor
    # inline scratchpad construction was removed as dead duplicate code.
    full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
    action = self._get_next_action(full_inputs)
    # The finishing tool signals that the agent is done.
    if action.tool == self.finish_tool_name:
        return AgentFinish({"output": action.tool_input}, action.log)
    return action

async def aplan(
    self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Union[AgentAction, AgentFinish]:
    """Given input, decide what to do (async counterpart of ``plan``).

    Args:
        intermediate_steps: Steps the LLM has taken to date,
            along with observations.
        **kwargs: User inputs.

    Returns:
        Action specifying what tool to use.
    """
    full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
    action = await self._aget_next_action(full_inputs)
    # Anything other than the finishing tool is returned for execution.
    if action.tool != self.finish_tool_name:
        return action
    return AgentFinish({"output": action.tool_input}, action.log)

def get_full_inputs(
    self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
) -> Dict[str, Any]:
    """Assemble the complete LLMChain inputs from user kwargs plus the
    agent scratchpad and stop sequence."""
    scratchpad = self._construct_scratchpad(intermediate_steps)
    # Scratchpad/stop entries take precedence over same-named kwargs.
    return {**kwargs, "agent_scratchpad": scratchpad, "stop": self._stop}

def prepare_for_new_call(self) -> None:
    """Prepare the agent for a new call, if needed.

    Default implementation does nothing; the redundant ``pass`` after the
    docstring was removed (a docstring alone is a valid function body).
    """
Expand Down Expand Up @@ -338,6 +377,14 @@ def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, An

def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Make sure that every tool is synchronous (not a coroutine)
for tool in self.tools:
if asyncio.iscoroutinefunction(tool.func):
raise ValueError(
"Tools cannot be asynchronous for `run` method. "
"Please use `arun` instead."
)

# Do any preparation necessary when receiving a new input.
self.agent.prepare_for_new_call()
# Construct a mapping of tool name to tool for easy lookup
Expand Down Expand Up @@ -399,3 +446,81 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
self.early_stopping_method, intermediate_steps, **inputs
)
return self._return(output, intermediate_steps)

async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
    """Asynchronously run text through the agent and return its response.

    Args:
        inputs: Mapping of input key to input text for the agent.

    Returns:
        The agent's final output mapping (return annotation fixed to
        ``Dict[str, Any]`` for consistency with the synchronous ``_call``).

    Raises:
        ValueError: If a tool's ``coroutine`` attribute is set but is not
            a coroutine function.
    """
    # Make sure that every tool's coroutine, if provided, really is one.
    for tool in self.tools:
        if tool.coroutine and not asyncio.iscoroutinefunction(tool.coroutine):
            raise ValueError(
                "The coroutine for the tool must be a coroutine function."
            )

    # Do any preparation necessary when receiving a new input.
    self.agent.prepare_for_new_call()
    # Construct a mapping of tool name to tool for easy lookup
    name_to_tool_map = {tool.name: tool for tool in self.tools}
    # We construct a mapping from each tool to a color, used for logging.
    color_mapping = get_color_mapping(
        [tool.name for tool in self.tools], excluded_colors=["green"]
    )
    intermediate_steps: List[Tuple[AgentAction, str]] = []
    # Let's start tracking the iterations the agent has gone through
    iterations = 0
    # We now enter the agent loop (until it returns something).
    while self._should_continue(iterations):
        # Call the LLM to see what to do.
        output = await self.agent.aplan(intermediate_steps, **inputs)
        # If the tool chosen is the finishing tool, then we end and return.
        if isinstance(output, AgentFinish):
            return self._return(output, intermediate_steps)

        # Otherwise we lookup the tool
        if output.tool in name_to_tool_map:
            tool = name_to_tool_map[output.tool]
            self.callback_manager.on_tool_start(
                {"name": str(tool.func)[:60] + "..."},
                output,
                verbose=self.verbose,
            )
            try:
                # We then call the tool on the tool input to get an observation
                observation = (
                    await tool.coroutine(output.tool_input)
                    if tool.coroutine
                    # If the tool is not a coroutine, we run it in the executor
                    # to avoid blocking the event loop.
                    else await asyncio.get_event_loop().run_in_executor(
                        None, tool.func, output.tool_input
                    )
                )
                color = color_mapping[output.tool]
                return_direct = tool.return_direct
            except (KeyboardInterrupt, Exception) as e:
                # NOTE(review): callbacks are synchronous for now; async
                # callback handlers are planned but not yet supported.
                self.callback_manager.on_tool_error(e, verbose=self.verbose)
                raise e
        else:
            self.callback_manager.on_tool_start(
                {"name": "N/A"}, output, verbose=self.verbose
            )
            observation = f"{output.tool} is not a valid tool, try another one."
            color = None
            return_direct = False
        llm_prefix = "" if return_direct else self.agent.llm_prefix
        self.callback_manager.on_tool_end(
            observation,
            color=color,
            observation_prefix=self.agent.observation_prefix,
            llm_prefix=llm_prefix,
            verbose=self.verbose,
        )
        intermediate_steps.append((output, observation))
        if return_direct:
            # Set the log to "" because we do not want to log it.
            output = AgentFinish({self.agent.return_values[0]: observation}, "")
            return self._return(output, intermediate_steps)
        iterations += 1
    output = self.agent.return_stopped_response(
        self.early_stopping_method, intermediate_steps, **inputs
    )
    return self._return(output, intermediate_steps)
16 changes: 9 additions & 7 deletions langchain/agents/load_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ def _get_pal_colored_objects(llm: BaseLLM) -> Tool:

def _get_llm_math(llm: BaseLLM) -> Tool:
    """Build a calculator Tool backed by an LLMMathChain.

    The chain is constructed once and shared between the synchronous
    (``func``) and asynchronous (``coroutine``) entry points; the original
    built two independent chains, doubling construction cost and splitting
    any per-chain state between the two call paths.
    """
    chain = LLMMathChain(llm=llm, callback_manager=llm.callback_manager)
    return Tool(
        name="Calculator",
        description="Useful for when you need to answer questions about math.",
        func=chain.run,
        coroutine=chain.arun,
    )


Expand Down Expand Up @@ -132,9 +133,10 @@ def _get_google_search(**kwargs: Any) -> Tool:

def _get_serpapi(**kwargs: Any) -> Tool:
    """Build a SerpAPI search Tool.

    One wrapper instance is shared between the sync ``func`` and async
    ``coroutine`` entry points; the original constructed two separate
    SerpAPIWrapper objects for no benefit.
    """
    wrapper = SerpAPIWrapper(**kwargs)
    return Tool(
        name="Search",
        description="A search engine. Useful for when you need to answer questions about current events. Input should be a search query.",
        func=wrapper.run,
        coroutine=wrapper.arun,
    )


Expand All @@ -145,7 +147,7 @@ def _get_serpapi(**kwargs: Any) -> Tool:
_EXTRA_OPTIONAL_TOOLS = {
"wolfram-alpha": (_get_wolfram_alpha, ["wolfram_alpha_appid"]),
"google-search": (_get_google_search, ["google_api_key", "google_cse_id"]),
"serpapi": (_get_serpapi, ["serpapi_api_key"]),
"serpapi": (_get_serpapi, ["serpapi_api_key", "aiosession"]),
}


Expand Down
7 changes: 6 additions & 1 deletion langchain/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Interface for tools."""
import asyncio
from dataclasses import dataclass
from inspect import signature
from typing import Any, Callable, Optional, Union
from typing import Any, Awaitable, Callable, Optional, Union


@dataclass
Expand All @@ -12,9 +13,13 @@ class Tool:
func: Callable[[str], str]
description: Optional[str] = None
return_direct: bool = False
# If the tool has a coroutine, then we can use this to run it asynchronously
coroutine: Optional[Callable[[str], Awaitable[str]]] = None

def __call__(self, *args: Any, **kwargs: Any) -> str:
    """Call the tool synchronously by delegating to ``func``."""
    fn = self.func
    if asyncio.iscoroutinefunction(fn):
        # Async tools must be awaited via `coroutine`, not invoked directly.
        raise TypeError("Coroutine cannot be called directly")
    return fn(*args, **kwargs)


Expand Down
12 changes: 8 additions & 4 deletions langchain/callbacks/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""

def __init__(self, color: str = "green") -> None:
    """Initialize callback handler.

    Args:
        color: Default text color used when a callback is invoked without
            an explicit ``color`` argument.
    """
    self.color = color

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
Expand Down Expand Up @@ -50,7 +54,7 @@ def on_tool_start(
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
print_text(action.log, color=color)
print_text(action.log, color=color if color else self.color)

def on_tool_end(
self,
Expand All @@ -62,7 +66,7 @@ def on_tool_end(
) -> None:
"""If not the final action, print out observation."""
print_text(f"\n{observation_prefix}")
print_text(output, color=color)
print_text(output, color=color if color else self.color)
print_text(f"\n{llm_prefix}")

def on_tool_error(
Expand All @@ -79,10 +83,10 @@ def on_text(
**kwargs: Optional[str],
) -> None:
"""Run when agent ends."""
print_text(text, color=color, end=end)
print_text(text, color=color if color else self.color, end=end)

def on_agent_finish(
    self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
    """Run on agent end, printing the finish log.

    Falls back to the handler's default color when no color is supplied.
    The previous expression ``color if self.color else color`` evaluated to
    ``color`` on both branches, so the ``self.color`` fallback was dead;
    this matches the ``color if color else self.color`` pattern used by
    the other handlers in this class.
    """
    print_text(finish.log, color=color if color else self.color, end="\n")
Loading