Skip to content

Commit bf7372f

Browse files
authored
Adding Autocomplete to OSS (crewAIInc#1198)
* Cleaned up model_config * Fix pydantic issues * 99% done with autocomplete * fixed test issues * Fix type checking issues
1 parent 3451b6f commit bf7372f

File tree

14 files changed

+109
-121
lines changed

14 files changed

+109
-121
lines changed

src/crewai/agent.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ class Agent(BaseAgent):
113113
description="Maximum number of retries for an agent to execute a task when an error occurs.",
114114
)
115115

116-
def __init__(__pydantic_self__, **data):
117-
config = data.pop("config", {})
118-
super().__init__(**config, **data)
119-
__pydantic_self__.agent_ops_agent_name = __pydantic_self__.role
116+
@model_validator(mode="after")
117+
def set_agent_ops_agent_name(self) -> "Agent":
118+
"""Set agent ops agent name."""
119+
self.agent_ops_agent_name = self.role
120+
return self
120121

121122
@model_validator(mode="after")
122123
def set_agent_executor(self) -> "Agent":
@@ -213,7 +214,7 @@ def execute_task(
213214
raise e
214215
result = self.execute_task(task, context, tools)
215216

216-
if self.max_rpm:
217+
if self.max_rpm and self._rpm_controller:
217218
self._rpm_controller.stop_rpm_counter()
218219

219220
# If there was any tool in self.tools_results that had result_as_answer

src/crewai/agents/agent_builder/base_agent.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pydantic import (
88
UUID4,
99
BaseModel,
10-
ConfigDict,
1110
Field,
1211
InstanceOf,
1312
PrivateAttr,
@@ -74,12 +73,17 @@ class BaseAgent(ABC, BaseModel):
7473
"""
7574

7675
__hash__ = object.__hash__ # type: ignore
77-
_logger: Logger = PrivateAttr()
78-
_rpm_controller: RPMController = PrivateAttr(default=None)
76+
_logger: Logger = PrivateAttr(default_factory=lambda: Logger(verbose=False))
77+
_rpm_controller: Optional[RPMController] = PrivateAttr(default=None)
7978
_request_within_rpm_limit: Any = PrivateAttr(default=None)
80-
formatting_errors: int = 0
81-
model_config = ConfigDict(arbitrary_types_allowed=True)
79+
_original_role: Optional[str] = PrivateAttr(default=None)
80+
_original_goal: Optional[str] = PrivateAttr(default=None)
81+
_original_backstory: Optional[str] = PrivateAttr(default=None)
82+
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
8283
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
84+
formatting_errors: int = Field(
85+
default=0, description="Number of formatting errors."
86+
)
8387
role: str = Field(description="Role of the agent")
8488
goal: str = Field(description="Objective of the agent")
8589
backstory: str = Field(description="Backstory of the agent")
@@ -123,15 +127,6 @@ class BaseAgent(ABC, BaseModel):
123127
default=None, description="Maximum number of tokens for the agent's execution."
124128
)
125129

126-
_original_role: str | None = None
127-
_original_goal: str | None = None
128-
_original_backstory: str | None = None
129-
_token_process: TokenProcess = TokenProcess()
130-
131-
def __init__(__pydantic_self__, **data):
132-
config = data.pop("config", {})
133-
super().__init__(**config, **data)
134-
135130
@model_validator(mode="after")
136131
def set_config_attributes(self):
137132
if self.config:

src/crewai/agents/cache/cache_handler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
from typing import Optional
1+
from typing import Any, Dict, Optional
22

3+
from pydantic import BaseModel, PrivateAttr
34

4-
class CacheHandler:
5-
"""Callback handler for tool usage."""
65

7-
_cache: dict = {}
6+
class CacheHandler(BaseModel):
7+
"""Callback handler for tool usage."""
88

9-
def __init__(self):
10-
self._cache = {}
9+
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
1110

1211
def add(self, tool, input, output):
1312
self._cache[f"{tool}-{input}"] = output

src/crewai/crew.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from pydantic import (
1111
UUID4,
1212
BaseModel,
13-
ConfigDict,
1413
Field,
1514
InstanceOf,
1615
Json,
@@ -105,7 +104,6 @@ class Crew(BaseModel):
105104

106105
name: Optional[str] = Field(default=None)
107106
cache: bool = Field(default=True)
108-
model_config = ConfigDict(arbitrary_types_allowed=True)
109107
tasks: List[Task] = Field(default_factory=list)
110108
agents: List[BaseAgent] = Field(default_factory=list)
111109
process: Process = Field(default=Process.sequential)

src/crewai/project/crew_base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44

55
import yaml
66
from dotenv import load_dotenv
7-
from pydantic import ConfigDict
87

98
load_dotenv()
109

1110

1211
def CrewBase(cls):
1312
class WrappedClass(cls):
14-
model_config = ConfigDict(arbitrary_types_allowed=True)
1513
is_crew_class: bool = True # type: ignore
1614

1715
# Get the directory of the class being decorated

src/crewai/project/pipeline_base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
1-
from typing import Callable, Dict
2-
3-
from pydantic import ConfigDict
1+
from typing import Any, Callable, Dict, List, Type, Union
42

53
from crewai.crew import Crew
64
from crewai.pipeline.pipeline import Pipeline
75
from crewai.routers.router import Router
86

7+
PipelineStage = Union[Crew, List[Crew], Router]
8+
99

1010
# TODO: Could potentially remove. Need to check with @joao and @gui if this is needed for CrewAI+
11-
def PipelineBase(cls):
11+
def PipelineBase(cls: Type[Any]) -> Type[Any]:
1212
class WrappedClass(cls):
13-
model_config = ConfigDict(arbitrary_types_allowed=True)
1413
is_pipeline_class: bool = True # type: ignore
14+
stages: List[PipelineStage]
1515

16-
def __init__(self, *args, **kwargs):
16+
def __init__(self, *args: Any, **kwargs: Any) -> None:
1717
super().__init__(*args, **kwargs)
1818
self.stages = []
1919
self._map_pipeline_components()
2020

21-
def _get_all_functions(self):
21+
def _get_all_functions(self) -> Dict[str, Callable[..., Any]]:
2222
return {
2323
name: getattr(self, name)
2424
for name in dir(self)
2525
if callable(getattr(self, name))
2626
}
2727

2828
def _filter_functions(
29-
self, functions: Dict[str, Callable], attribute: str
30-
) -> Dict[str, Callable]:
29+
self, functions: Dict[str, Callable[..., Any]], attribute: str
30+
) -> Dict[str, Callable[..., Any]]:
3131
return {
3232
name: func
3333
for name, func in functions.items()
3434
if hasattr(func, attribute)
3535
}
3636

37-
def _map_pipeline_components(self):
37+
def _map_pipeline_components(self) -> None:
3838
all_functions = self._get_all_functions()
3939
crew_functions = self._filter_functions(all_functions, "is_crew")
4040
router_functions = self._filter_functions(all_functions, "is_router")

src/crewai/routers/router.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,43 @@
11
from copy import deepcopy
2-
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar
2+
from typing import Any, Callable, Dict, Tuple
33

44
from pydantic import BaseModel, Field, PrivateAttr
55

6-
T = TypeVar("T", bound=Dict[str, Any])
7-
U = TypeVar("U")
86

7+
class Route(BaseModel):
8+
condition: Callable[[Dict[str, Any]], bool]
9+
pipeline: Any
910

10-
class Route(Generic[T, U]):
11-
condition: Callable[[T], bool]
12-
pipeline: U
1311

14-
def __init__(self, condition: Callable[[T], bool], pipeline: U):
15-
self.condition = condition
16-
self.pipeline = pipeline
17-
18-
19-
class Router(BaseModel, Generic[T, U]):
20-
routes: Dict[str, Route[T, U]] = Field(
12+
class Router(BaseModel):
13+
routes: Dict[str, Route] = Field(
2114
default_factory=dict,
2215
description="Dictionary of route names to (condition, pipeline) tuples",
2316
)
24-
default: U = Field(..., description="Default pipeline if no conditions are met")
17+
default: Any = Field(..., description="Default pipeline if no conditions are met")
2518
_route_types: Dict[str, type] = PrivateAttr(default_factory=dict)
2619

27-
model_config = {"arbitrary_types_allowed": True}
20+
class Config:
21+
arbitrary_types_allowed = True
2822

29-
def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data):
23+
def __init__(self, routes: Dict[str, Route], default: Any, **data):
3024
super().__init__(routes=routes, default=default, **data)
3125
self._check_copyable(default)
3226
for name, route in routes.items():
3327
self._check_copyable(route.pipeline)
3428
self._route_types[name] = type(route.pipeline)
3529

3630
@staticmethod
37-
def _check_copyable(obj):
31+
def _check_copyable(obj: Any) -> None:
3832
if not hasattr(obj, "copy") or not callable(getattr(obj, "copy")):
3933
raise ValueError(f"Object of type {type(obj)} must have a 'copy' method")
4034

4135
def add_route(
4236
self,
4337
name: str,
44-
condition: Callable[[T], bool],
45-
pipeline: U,
46-
) -> "Router[T, U]":
38+
condition: Callable[[Dict[str, Any]], bool],
39+
pipeline: Any,
40+
) -> "Router":
4741
"""
4842
Add a named route with its condition and corresponding pipeline to the router.
4943
@@ -60,7 +54,7 @@ def add_route(
6054
self._route_types[name] = type(pipeline)
6155
return self
6256

63-
def route(self, input_data: T) -> Tuple[U, str]:
57+
def route(self, input_data: Dict[str, Any]) -> Tuple[Any, str]:
6458
"""
6559
Evaluate the input against the conditions and return the appropriate pipeline.
6660
@@ -76,15 +70,15 @@ def route(self, input_data: T) -> Tuple[U, str]:
7670

7771
return self.default, "default"
7872

79-
def copy(self) -> "Router[T, U]":
73+
def copy(self) -> "Router":
8074
"""Create a deep copy of the Router."""
8175
new_routes = {
8276
name: Route(
8377
condition=deepcopy(route.condition),
84-
pipeline=route.pipeline.copy(), # type: ignore
78+
pipeline=route.pipeline.copy(),
8579
)
8680
for name, route in self.routes.items()
8781
}
88-
new_default = self.default.copy() # type: ignore
82+
new_default = self.default.copy()
8983

9084
return Router(routes=new_routes, default=new_default)

src/crewai/task.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,14 @@
99
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1010

1111
from opentelemetry.trace import Span
12-
from pydantic import UUID4, BaseModel, Field, field_validator, model_validator
12+
from pydantic import (
13+
UUID4,
14+
BaseModel,
15+
Field,
16+
PrivateAttr,
17+
field_validator,
18+
model_validator,
19+
)
1320
from pydantic_core import PydanticCustomError
1421

1522
from crewai.agents.agent_builder.base_agent import BaseAgent
@@ -39,9 +46,6 @@ class Task(BaseModel):
3946
tools: List of tools/resources limited for task execution.
4047
"""
4148

42-
class Config:
43-
arbitrary_types_allowed = True
44-
4549
__hash__ = object.__hash__ # type: ignore
4650
used_tools: int = 0
4751
tools_errors: int = 0
@@ -104,16 +108,12 @@ class Config:
104108
default=None,
105109
)
106110

107-
_telemetry: Telemetry
108-
_execution_span: Span | None = None
109-
_original_description: str | None = None
110-
_original_expected_output: str | None = None
111-
_thread: threading.Thread | None = None
112-
_execution_time: float | None = None
113-
114-
def __init__(__pydantic_self__, **data):
115-
config = data.pop("config", {})
116-
super().__init__(**config, **data)
111+
_telemetry: Telemetry = PrivateAttr(default_factory=Telemetry)
112+
_execution_span: Optional[Span] = PrivateAttr(default=None)
113+
_original_description: Optional[str] = PrivateAttr(default=None)
114+
_original_expected_output: Optional[str] = PrivateAttr(default=None)
115+
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
116+
_execution_time: Optional[float] = PrivateAttr(default=None)
117117

118118
@field_validator("id", mode="before")
119119
@classmethod
@@ -137,12 +137,6 @@ def output_file_validation(cls, value: str) -> str:
137137
return value[1:]
138138
return value
139139

140-
@model_validator(mode="after")
141-
def set_private_attrs(self) -> "Task":
142-
"""Set private attributes."""
143-
self._telemetry = Telemetry()
144-
return self
145-
146140
@model_validator(mode="after")
147141
def set_attributes_based_on_config(self) -> "Task":
148142
"""Set attributes based on the agent configuration."""
@@ -263,9 +257,7 @@ def _execute_core(
263257
content = (
264258
json_output
265259
if json_output
266-
else pydantic_output.model_dump_json()
267-
if pydantic_output
268-
else result
260+
else pydantic_output.model_dump_json() if pydantic_output else result
269261
)
270262
self._save_file(content)
271263

src/crewai/tools/cache_tools.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from langchain.tools import StructuredTool
2-
from pydantic import BaseModel, ConfigDict, Field
2+
from pydantic import BaseModel, Field
33

44
from crewai.agents.cache import CacheHandler
55

66

77
class CacheTools(BaseModel):
88
"""Default tools to hit the cache."""
99

10-
model_config = ConfigDict(arbitrary_types_allowed=True)
1110
name: str = "Hit Cache"
1211
cache_handler: CacheHandler = Field(
1312
description="Cache Handler for the crew",
14-
default=CacheHandler(),
13+
default_factory=CacheHandler,
1514
)
1615

1716
def tool(self):

src/crewai/utilities/logger.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from datetime import datetime
22

3-
from crewai.utilities.printer import Printer
3+
from pydantic import BaseModel, Field, PrivateAttr
44

5+
from crewai.utilities.printer import Printer
56

6-
class Logger:
7-
_printer = Printer()
87

9-
def __init__(self, verbose=False):
10-
self.verbose = verbose
8+
class Logger(BaseModel):
9+
verbose: bool = Field(default=False)
10+
_printer: Printer = PrivateAttr(default_factory=Printer)
1111

1212
def log(self, level, message, color="bold_green"):
1313
if self.verbose:

0 commit comments

Comments
 (0)