
Commit 2004427

tools fix and formatting (#2441)
1 parent 2517ccd commit 2004427

18 files changed (+536, -1151 lines)

mem0/client/main.py (+9, -13)

@@ -618,18 +618,16 @@ def delete_webhook(self, webhook_id: int) -> Dict[str, str]:
         return response.json()
 
     @api_error_handler
-    def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
+    def feedback(
+        self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
+    ) -> Dict[str, str]:
         VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
 
         feedback = feedback.upper() if feedback else None
         if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
             raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
 
-        data = {
-            "memory_id": memory_id,
-            "feedback": feedback,
-            "feedback_reason": feedback_reason
-        }
+        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
 
         response = self.client.post("/v1/feedback/", json=data)
         response.raise_for_status()

@@ -1019,20 +1017,18 @@ async def delete_webhook(self, webhook_id: int) -> Dict[str, str]:
         return response.json()
 
     @api_error_handler
-    async def feedback(self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None) -> Dict[str, str]:
+    async def feedback(
+        self, memory_id: str, feedback: Optional[str] = None, feedback_reason: Optional[str] = None
+    ) -> Dict[str, str]:
         VALID_FEEDBACK_VALUES = {"POSITIVE", "NEGATIVE", "VERY_NEGATIVE"}
 
         feedback = feedback.upper() if feedback else None
         if feedback is not None and feedback not in VALID_FEEDBACK_VALUES:
             raise ValueError(f'feedback must be one of {", ".join(VALID_FEEDBACK_VALUES)} or None')
 
-        data = {
-            "memory_id": memory_id,
-            "feedback": feedback,
-            "feedback_reason": feedback_reason
-        }
+        data = {"memory_id": memory_id, "feedback": feedback, "feedback_reason": feedback_reason}
 
         response = await self.async_client.post("/v1/feedback/", json=data)
         response.raise_for_status()
         capture_client_event("async_client.feedback", self.sync_client, data)
-        return response.json()
+        return response.json()
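
For context, a minimal usage sketch of the reformatted feedback API; the client construction, API key, and memory id below are illustrative assumptions, not part of this diff:

from mem0 import MemoryClient  # assumes the package's public client export

client = MemoryClient(api_key="your-api-key")  # hypothetical key

# feedback is upper-cased before validation, so "positive" passes; any value
# outside {POSITIVE, NEGATIVE, VERY_NEGATIVE} (other than None) raises ValueError.
client.feedback(
    memory_id="mem-123",  # hypothetical id
    feedback="positive",
    feedback_reason="Recalled the right preference.",
)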

mem0/configs/prompts.py (+3, -2)

@@ -208,11 +208,12 @@
 }
 """
 
+
 def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None):
     if custom_update_memory_prompt is None:
         global DEFAULT_UPDATE_MEMORY_PROMPT
         custom_update_memory_prompt = DEFAULT_UPDATE_MEMORY_PROMPT
-
+
     return f"""{custom_update_memory_prompt}
 
     Below is the current content of my memory which I have collected till now. You have to update it in the following format only:

@@ -250,4 +251,4 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust
     - If there is an update, the ID key should remain the same and only the value needs to be updated.
 
     Do not return anything except the JSON format.
-    """
+    """

mem0/embeddings/configs.py (+10, -1)

@@ -13,7 +13,16 @@ class EmbedderConfig(BaseModel):
     @field_validator("config")
     def validate_config(cls, v, values):
         provider = values.data.get("provider")
-        if provider in ["openai", "ollama", "huggingface", "azure_openai", "gemini", "vertexai", "together", "lmstudio"]:
+        if provider in [
+            "openai",
+            "ollama",
+            "huggingface",
+            "azure_openai",
+            "gemini",
+            "vertexai",
+            "together",
+            "lmstudio",
+        ]:
             return v
         else:
             raise ValueError(f"Unsupported embedding provider: {provider}")

mem0/embeddings/gemini.py (+3, -1)

@@ -28,5 +28,7 @@ def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims)
+        response = genai.embed_content(
+            model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
+        )
         return response["embedding"]

mem0/embeddings/lmstudio.py (+1, -5)

@@ -26,8 +26,4 @@ def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        return (
-            self.client.embeddings.create(input=[text], model=self.config.model)
-            .data[0]
-            .embedding
-        )
+        return self.client.embeddings.create(input=[text], model=self.config.model).data[0].embedding

mem0/embeddings/openai.py (+10, -6)

@@ -17,16 +17,16 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
 
         api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")
         base_url = (
-                self.config.openai_base_url
-                or os.getenv("OPENAI_API_BASE")
-                or os.getenv("OPENAI_BASE_URL")
-                or "https://api.openai.com/v1"
+            self.config.openai_base_url
+            or os.getenv("OPENAI_API_BASE")
+            or os.getenv("OPENAI_BASE_URL")
+            or "https://api.openai.com/v1"
         )
         if os.environ.get("OPENAI_API_BASE"):
             warnings.warn(
                 "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                 "Please use 'OPENAI_BASE_URL' instead.",
-                DeprecationWarning
+                DeprecationWarning,
             )
 
         self.client = OpenAI(api_key=api_key, base_url=base_url)

@@ -42,4 +42,8 @@ def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        return self.client.embeddings.create(input=[text], model=self.config.model, dimensions = self.config.embedding_dims).data[0].embedding
+        return (
+            self.client.embeddings.create(input=[text], model=self.config.model, dimensions=self.config.embedding_dims)
+            .data[0]
+            .embedding
+        )
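
The reindented expression encodes a fallback chain for the base URL; a small sketch of the same resolution order (the config stand-in is hypothetical):

import os

# Mirrors the constructor: explicit config value, then the deprecated
# OPENAI_API_BASE, then OPENAI_BASE_URL, then the public default endpoint.
config_base_url = None  # stand-in for self.config.openai_base_url
base_url = (
    config_base_url
    or os.getenv("OPENAI_API_BASE")
    or os.getenv("OPENAI_BASE_URL")
    or "https://api.openai.com/v1"
)
print(base_url)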

mem0/llms/azure_openai_structured.py (+2, -1)

@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Dict, List, Optional
 

@@ -36,6 +35,8 @@ def generate_response(
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
     ) -> str:
         """
         Generate a response based on the given messages using Azure OpenAI.

mem0/llms/base.py (+4, -2)

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Dict, List, Optional
 
 from mem0.configs.llms.base import BaseLlmConfig
 

@@ -17,12 +17,14 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         self.config = config
 
     @abstractmethod
-    def generate_response(self, messages):
+    def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
         """
         Generate a response based on the given messages.
 
         Args:
             messages (list): List of message dicts containing 'role' and 'content'.
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
 
         Returns:
             str: The generated response.
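
A minimal sketch of a subclass satisfying the widened abstract signature; EchoLLM is hypothetical and exists only to show the contract:

from typing import Dict, List, Optional

from mem0.llms.base import LLMBase


class EchoLLM(LLMBase):
    # Hypothetical provider: real implementations in this commit (groq.py,
    # lmstudio.py) forward tools/tool_choice to their completions API; this
    # stub just echoes the last user message.
    def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
        return messages[-1]["content"]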

mem0/llms/groq.py (+1, -1)

@@ -84,4 +84,4 @@ def generate_response(
             params["tool_choice"] = tool_choice
 
         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
+        return self._parse_response(response, tools)

mem0/llms/lmstudio.py (+7, -4)

@@ -10,7 +10,10 @@ class LMStudioLLM(LLMBase):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config)
 
-        self.config.model = self.config.model or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        self.config.model = (
+            self.config.model
+            or "lmstudio-community/Meta-Llama-3.1-70B-Instruct-GGUF/Meta-Llama-3.1-70B-Instruct-IQ2_M.gguf"
+        )
         self.config.api_key = self.config.api_key or "lm-studio"
 
         self.client = OpenAI(base_url=self.config.lmstudio_base_url, api_key=self.config.api_key)

@@ -20,7 +23,7 @@ def generate_response(
         messages: List[Dict[str, str]],
         response_format: dict = {"type": "json_object"},
         tools: Optional[List[Dict]] = None,
-        tool_choice: str = "auto"
+        tool_choice: str = "auto",
     ):
         """
         Generate a response based on the given messages using LM Studio.

@@ -39,10 +42,10 @@ def generate_response(
             "messages": messages,
             "temperature": self.config.temperature,
             "max_tokens": self.config.max_tokens,
-            "top_p": self.config.top_p
+            "top_p": self.config.top_p,
         }
         if response_format:
             params["response_format"] = response_format
 
         response = self.client.chat.completions.create(**params)
-        return response.choices[0].message.content
+        return response.choices[0].message.content

mem0/llms/openai.py (+2, -1)

@@ -1,3 +1,4 @@
+import json
 import os
 import warnings
 from typing import Dict, List, Optional

@@ -34,7 +35,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
             warnings.warn(
                 "The environment variable 'OPENAI_API_BASE' is deprecated and will be removed in the 0.1.80. "
                 "Please use 'OPENAI_BASE_URL' instead.",
-                DeprecationWarning
+                DeprecationWarning,
             )
 
         self.client = OpenAI(api_key=api_key, base_url=base_url)

mem0/llms/openai_structured.py (+2, -1)

@@ -1,4 +1,3 @@
-import json
 import os
 from typing import Dict, List, Optional
 

@@ -23,6 +22,8 @@ def generate_response(
         self,
         messages: List[Dict[str, str]],
         response_format: Optional[str] = None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
     ) -> str:
         """
         Generate a response based on the given messages using OpenAI.

mem0/llms/xai.py (+9, -1)

@@ -18,13 +18,21 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         base_url = self.config.xai_base_url or os.getenv("XAI_API_BASE") or "https://api.x.ai/v1"
         self.client = OpenAI(api_key=api_key, base_url=base_url)
 
-    def generate_response(self, messages: List[Dict[str, str]], response_format=None):
+    def generate_response(
+        self,
+        messages: List[Dict[str, str]],
+        response_format=None,
+        tools: Optional[List[Dict]] = None,
+        tool_choice: str = "auto",
+    ):
         """
         Generate a response based on the given messages using XAI.
 
         Args:
             messages (list): List of message dicts containing 'role' and 'content'.
             response_format (str or object, optional): Format of the response. Defaults to "text".
+            tools (list, optional): List of tools that the model can call. Defaults to None.
+            tool_choice (str, optional): Tool choice method. Defaults to "auto".
 
         Returns:
             str: The generated response.

mem0/memory/graph_memory.py (+8, -5)

@@ -12,11 +12,14 @@
 except ImportError:
     raise ImportError("rank_bm25 is not installed. Please install it using pip install rank-bm25")
 
-from mem0.graphs.tools import (DELETE_MEMORY_STRUCT_TOOL_GRAPH,
-                               DELETE_MEMORY_TOOL_GRAPH,
-                               EXTRACT_ENTITIES_STRUCT_TOOL,
-                               EXTRACT_ENTITIES_TOOL, RELATIONS_STRUCT_TOOL,
-                               RELATIONS_TOOL)
+from mem0.graphs.tools import (
+    DELETE_MEMORY_STRUCT_TOOL_GRAPH,
+    DELETE_MEMORY_TOOL_GRAPH,
+    EXTRACT_ENTITIES_STRUCT_TOOL,
+    EXTRACT_ENTITIES_TOOL,
+    RELATIONS_STRUCT_TOOL,
+    RELATIONS_TOOL,
+)
 from mem0.graphs.utils import EXTRACT_RELATIONS_PROMPT, get_delete_messages
 from mem0.utils.factory import EmbedderFactory, LlmFactory

mem0/vector_stores/pinecone.py (+7, -3)

@@ -8,7 +8,9 @@
     from pinecone import Pinecone, PodSpec, ServerlessSpec
     from pinecone.data.dataclasses.vector import Vector
 except ImportError:
-    raise ImportError("Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`") from None
+    raise ImportError(
+        "Pinecone requires extra dependencies. Install with `pip install pinecone pinecone-text`"
+    ) from None
 
 from mem0.vector_stores.base import VectorStoreBase
 

@@ -34,7 +36,7 @@ def __init__(
         hybrid_search: bool,
         metric: str,
         batch_size: int,
-        extra_params: Optional[Dict[str, Any]]
+        extra_params: Optional[Dict[str, Any]],
     ):
         """
         Initialize the Pinecone vector store.

@@ -199,7 +201,9 @@ def _create_filter(self, filters: Optional[Dict]) -> Dict:
 
         return pinecone_filter
 
-    def search(self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
+    def search(
+        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
+    ) -> List[OutputData]:
         """
         Search for similar vectors.
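
And a hedged sketch of calling the rewrapped search signature; the store instance, vector length, filter, and the OutputData fields read at the end are placeholders and assumptions rather than anything this diff establishes:

# `store` is assumed to be an already-initialized instance of this module's
# Pinecone vector store class; 1536 matches common OpenAI embedding dims.
results = store.search(
    query="user food preferences",
    vectors=[0.0] * 1536,
    limit=5,
    filters={"user_id": "u-42"},  # hypothetical metadata filter
)
for item in results:
    print(item.id, item.score)  # assumed OutputData fields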
