1
- from sqlmodel import SQLModel , Field , select
1
+ from sqlmodel import SQLModel , Field , select , update
2
2
from sqlalchemy .ext .asyncio import AsyncSession
3
3
import logging
4
+ from datetime import datetime
4
5
5
6
from app .database .models import *
6
7
from app .database .engine import db_engine
@@ -31,6 +32,11 @@ class UserQueryError(UserDatabaseError):
31
32
"""Raised when user query fails"""
32
33
33
34
pass
35
+
36
+ class UserUpdateError (UserDatabaseError ):
37
+ """Raised when user update fails"""
38
+
39
+ pass
34
40
35
41
36
42
async def get_or_create_user (wa_id : str , name : Optional [str ] = None ) -> User :
@@ -99,7 +105,76 @@ async def get_user_by_waid(wa_id: str) -> Optional[User]:
99
105
logger .error (f"Failed to query user { wa_id } : { str (e )} " )
100
106
raise UserQueryError (f"Failed to query user: { str (e )} " )
101
107
108
+ # TODO: rename this function
109
+ async def update_user_by_waid (user : User ) -> User :
110
+ """
111
+ Update any information about an existing user and return the updated user.
112
+ """
113
+ if user is None :
114
+ logger .error ("Cannot update user: user object is None" )
115
+ raise UserUpdateError ("Cannot update user: user object is None" )
116
+
117
+ # Convert the User object to a dictionary
118
+ user_data = user .__dict__ .copy ()
119
+
120
+ logger .debug (f"Updating user { user_data } " )
121
+
122
+ # Remove the _sa_instance_state attribute
123
+ user_data .pop ("_sa_instance_state" , None )
124
+
125
+ # Extract the wa_id
126
+ wa_id = user_data .pop ("wa_id" )
127
+
128
+ # Remove the id attribute
129
+ user_data .pop ("wa_id" , None )
130
+ user_data .pop ("id" , None )
131
+
132
+ # Handle the birthday field if necessary
133
+ if "birthday" in user_data and isinstance (user_data ["birthday" ], str ):
134
+ user_data ["birthday" ] = datetime .strptime (
135
+ user_data ["birthday" ], "%Y-%m-%d"
136
+ ).date ()
137
+
138
+ async with AsyncSession (db_engine ) as session :
139
+ try :
140
+ statement = update (User ).where (User .wa_id == wa_id ).values (** user_data )
141
+ await session .execute (statement )
142
+ await session .commit ()
143
+
144
+ # Fetch the updated user
145
+ result = await session .execute (select (User ).filter_by (wa_id = wa_id ))
146
+ updated_user = result .scalar_one_or_none ()
147
+
148
+ logger .info (f"Updated user { wa_id } with { user_data } " )
149
+ return updated_user
150
+ except Exception as e :
151
+ await session .rollback ()
152
+ logger .error (f"Failed to update user { wa_id } : { str (e )} " )
153
+ raise UserUpdateError (f"Failed to update user: { str (e )} " )
154
+
155
+
156
+ # TODO: This should be replaced with get_user_by_waid or the get_or_create_user function
157
+ async def get_user_data (wa_id : str ) -> dict :
158
+ """
159
+ Retrieve user data based on wa_id.
160
+ """
161
+ async with AsyncSession (db_engine ) as session :
162
+ try :
163
+ statement = select (User ).where (User .wa_id == wa_id )
164
+ result = await session .execute (statement )
165
+ user = result .scalar_one_or_none ()
166
+ if user :
167
+ user_data = user .model_dump ()
168
+ logger .info (f"Retrieved user data for { wa_id } : { user_data } " )
169
+ return user_data
170
+ else :
171
+ logger .warning (f"No user found with wa_id { wa_id } " )
172
+ return None
173
+ except Exception as e :
174
+ logger .error (f"Failed to query user { wa_id } : { str (e )} " )
175
+ raise UserQueryError (f"Failed to query user: { str (e )} " )
102
176
177
+
103
178
async def get_user_message_history (
104
179
user_id : int , limit : int = 10
105
180
) -> Optional [List [Message ]]:
@@ -202,4 +277,4 @@ async def vector_search(query: str, n_results: int, where: dict) -> List[Chunk]:
202
277
return result .scalars ().all ()
203
278
except Exception as e :
204
279
logger .error (f"Failed to search for knowledge: { str (e )} " )
205
- raise Exception (f"Failed to search for knowledge: { str (e )} " )
280
+ raise Exception (f"Failed to search for knowledge: { str (e )} " )
0 commit comments