Skip to content

Commit 59944e8

Browse files
authored
Merge pull request #44 from Tanzania-AI-Community/fredy/development/onboarding-flow
Fredy/development/onboarding flow
2 parents b449158 + e9f6e39 commit 59944e8

19 files changed

+1581
-155
lines changed

.env.template

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ WHATSAPP_VERIFY_TOKEN=
1818
# 60 day access token to use META's API for the Twiga bot
1919
WHATSAPP_API_TOKEN=
2020

21+
# Whatsapp number for testing purposes
22+
RECIPIENT_WAID=
23+
24+
DAILY_MESSAGE_LIMIT=100
25+
2126
"""OpenAI API credentials and other environment variables"""
2227

2328
# OpenAI

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,6 @@ cython_debug/
173173
.DS_Store
174174

175175
# Ignore vector store
176-
twiga_vector_store/
176+
twiga_vector_store/
177+
private.pem
178+
public.pem

app/config.py

+8
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,17 @@ class Settings(BaseSettings):
2121
meta_api_version: str
2222
meta_app_id: str
2323
meta_app_secret: SecretStr
24+
# WhatsApp settings
2425
whatsapp_cloud_number_id: str
2526
whatsapp_verify_token: SecretStr
2627
whatsapp_api_token: SecretStr
28+
whatsapp_business_public_key: SecretStr
29+
whatsapp_business_private_key: SecretStr
30+
whatsapp_business_private_key_password: SecretStr
31+
# Flows settings
32+
personal_and_school_info_flow_id: str
33+
subject_class_info_flow_id: str
34+
flow_token_encryption_key: SecretStr
2735
# Rate limit settings
2836
daily_message_limit: int
2937
# Database settings

app/database/db.py

+77-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from sqlmodel import SQLModel, Field, select
1+
from sqlmodel import SQLModel, Field, select, update
22
from sqlalchemy.ext.asyncio import AsyncSession
33
import logging
4+
from datetime import datetime
45

56
from app.database.models import *
67
from app.database.engine import db_engine
@@ -31,6 +32,11 @@ class UserQueryError(UserDatabaseError):
3132
"""Raised when user query fails"""
3233

3334
pass
35+
36+
class UserUpdateError(UserDatabaseError):
37+
"""Raised when user update fails"""
38+
39+
pass
3440

3541

3642
async def get_or_create_user(wa_id: str, name: Optional[str] = None) -> User:
@@ -99,7 +105,76 @@ async def get_user_by_waid(wa_id: str) -> Optional[User]:
99105
logger.error(f"Failed to query user {wa_id}: {str(e)}")
100106
raise UserQueryError(f"Failed to query user: {str(e)}")
101107

108+
# TODO: rename this function
109+
async def update_user_by_waid(user: User) -> User:
110+
"""
111+
Update any information about an existing user and return the updated user.
112+
"""
113+
if user is None:
114+
logger.error("Cannot update user: user object is None")
115+
raise UserUpdateError("Cannot update user: user object is None")
116+
117+
# Convert the User object to a dictionary
118+
user_data = user.__dict__.copy()
119+
120+
logger.debug(f"Updating user {user_data}")
121+
122+
# Remove the _sa_instance_state attribute
123+
user_data.pop("_sa_instance_state", None)
124+
125+
# Extract the wa_id
126+
wa_id = user_data.pop("wa_id")
127+
128+
# Remove the id attribute
129+
user_data.pop("wa_id", None)
130+
user_data.pop("id", None)
131+
132+
# Handle the birthday field if necessary
133+
if "birthday" in user_data and isinstance(user_data["birthday"], str):
134+
user_data["birthday"] = datetime.strptime(
135+
user_data["birthday"], "%Y-%m-%d"
136+
).date()
137+
138+
async with AsyncSession(db_engine) as session:
139+
try:
140+
statement = update(User).where(User.wa_id == wa_id).values(**user_data)
141+
await session.execute(statement)
142+
await session.commit()
143+
144+
# Fetch the updated user
145+
result = await session.execute(select(User).filter_by(wa_id=wa_id))
146+
updated_user = result.scalar_one_or_none()
147+
148+
logger.info(f"Updated user {wa_id} with {user_data}")
149+
return updated_user
150+
except Exception as e:
151+
await session.rollback()
152+
logger.error(f"Failed to update user {wa_id}: {str(e)}")
153+
raise UserUpdateError(f"Failed to update user: {str(e)}")
154+
155+
156+
# TODO: This should be replaced with get_user_by_waid or the get_or_create_user function
157+
async def get_user_data(wa_id: str) -> dict:
158+
"""
159+
Retrieve user data based on wa_id.
160+
"""
161+
async with AsyncSession(db_engine) as session:
162+
try:
163+
statement = select(User).where(User.wa_id == wa_id)
164+
result = await session.execute(statement)
165+
user = result.scalar_one_or_none()
166+
if user:
167+
user_data = user.model_dump()
168+
logger.info(f"Retrieved user data for {wa_id}: {user_data}")
169+
return user_data
170+
else:
171+
logger.warning(f"No user found with wa_id {wa_id}")
172+
return None
173+
except Exception as e:
174+
logger.error(f"Failed to query user {wa_id}: {str(e)}")
175+
raise UserQueryError(f"Failed to query user: {str(e)}")
102176

177+
103178
async def get_user_message_history(
104179
user_id: int, limit: int = 10
105180
) -> Optional[List[Message]]:
@@ -202,4 +277,4 @@ async def vector_search(query: str, n_results: int, where: dict) -> List[Chunk]:
202277
return result.scalars().all()
203278
except Exception as e:
204279
logger.error(f"Failed to search for knowledge: {str(e)}")
205-
raise Exception(f"Failed to search for knowledge: {str(e)}")
280+
raise Exception(f"Failed to search for knowledge: {str(e)}")

app/database/models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class GradeLevel(str, Enum):
5757
class OnboardingState(str, Enum):
5858
new = "new"
5959
personal_info_submitted = "personal_info_submitted"
60-
class_subject_info_submitted = "class_subject_info_submitted"
6160
completed = "completed"
6261

6362

@@ -106,8 +105,10 @@ class User(SQLModel, table=True):
106105
role: str = Field(default=Role.teacher, max_length=20)
107106
class_info: Optional[dict] = Field(default=None, sa_type=JSON)
108107
school_name: Optional[str] = Field(default=None, max_length=100)
108+
school_location: Optional[str] = Field(default=None, max_length=100)
109109
birthday: Optional[date] = Field(default=None, sa_type=Date)
110110
region: Optional[str] = Field(default=None, max_length=50)
111+
location: Optional[str] = Field(default=None, max_length=100)
111112
last_message_at: Optional[datetime] = Field(
112113
sa_type=DateTime(timezone=True)
113114
) # user.last_message_at = datetime.now(timezone.utc) (this is how to set it when updating later)

app/main.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from fastapi import FastAPI, Request, Depends
1+
from fastapi import FastAPI, HTTPException, Request, Depends
22
from fastapi.responses import JSONResponse
33
import logging
44
from contextlib import asynccontextmanager
55

66

77
from app.security import signature_required
8+
from app.security import flows_signature_required
89
from app.services.whatsapp_service import whatsapp_client
910
from app.services.messaging_service import handle_request
11+
from app.services.flow_service import flow_client
1012
from app.database.engine import db_engine, init_db
1113

1214
logger = logging.getLogger(__name__)
@@ -40,3 +42,45 @@ async def webhook_get(request: Request) -> JSONResponse:
4042
async def webhook_post(request: Request) -> JSONResponse:
4143
logger.debug("webhook_post is being called")
4244
return await handle_request(request)
45+
46+
47+
@app.post("/flows", dependencies=[Depends(flows_signature_required)])
48+
async def handle_flows_webhook(request: Request) -> JSONResponse:
49+
try:
50+
body = await request.json()
51+
logger.debug(f"Received webhook: {body}")
52+
return await flow_client.handle_flow_webhook(body)
53+
except Exception as e:
54+
logger.error(f"Error handling webhook: {e}")
55+
raise HTTPException(status_code=500, detail="Internal Server Error")
56+
57+
58+
# use this when testing flows locally, the returned token will be the flow_token
59+
@app.post("/encrypt_flow_token")
60+
async def handle_encrypt_flow_token(request: Request) -> JSONResponse:
61+
try:
62+
body = await request.json()
63+
logger.debug(f"Received request to encrypt flow token: {body}")
64+
wa_id = body.get("wa_id")
65+
flow_id = body.get("flow_id")
66+
67+
logger.info(f"Encrypting flow token for wa_id {wa_id} and flow_id {flow_id}")
68+
69+
return await flow_client.encrypt_flow_token(wa_id, flow_id)
70+
except Exception as e:
71+
logger.error(f"Error encrypting flow token: {e}")
72+
raise HTTPException(status_code=500, detail="Internal Server Error")
73+
74+
75+
# decrypt_flow_token
76+
@app.post("/decrypt_flow_token")
77+
async def handle_decrypt_flow_token(request: Request) -> JSONResponse:
78+
try:
79+
body = await request.json()
80+
logger.debug(f"Received request to decrypt flow token: {body}")
81+
encrypted_flow_token = body.get("encrypted-flow-token")
82+
83+
return await flow_client.decrypt_flow_token(encrypted_flow_token)
84+
except Exception as e:
85+
logger.error(f"Error decrypting flow token: {e}")
86+
raise HTTPException(status_code=500, detail="Internal Server Error")

app/models/message_models.py

+66-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from pydantic import BaseModel, constr, Field
2-
from typing import List, Dict, Literal, Optional
1+
from pydantic import BaseModel, constr, Field, root_validator
2+
from typing import List, Dict, Literal, Optional, Union
33
import json
4+
from pydantic import BaseModel, constr, Field, model_validator
5+
from typing import List, Dict, Literal, Optional, Union
46

57

68
class TextObject(BaseModel):
@@ -68,25 +70,80 @@ class TextMessage(BaseModel):
6870

6971

7072
"""
71-
Main model for interactive messages
73+
Main model for template messages
7274
"""
7375

7476

75-
class InteractiveMessage(BaseModel):
77+
class TemplateMessage(BaseModel):
78+
messaging_product: Literal["whatsapp"] = "whatsapp"
79+
to: str
80+
type: Literal["template"] = "template"
81+
template: Dict[Literal["name", "language"], str]
82+
83+
84+
"""
85+
Models for flow interactive messages
86+
"""
87+
88+
89+
class FlowActionPayload(BaseModel):
90+
screen: str
91+
data: Dict[str, Union[str, int]]
92+
93+
94+
class FlowActionPayload(BaseModel):
95+
screen: str
96+
data: Dict[str, Union[str, int]]
97+
98+
99+
class FlowParameters(BaseModel):
100+
flow_message_version: str
101+
flow_token: str
102+
flow_name: Optional[str] = None
103+
flow_id: Optional[str] = None
104+
flow_cta: str
105+
flow_action: str
106+
flow_action_payload: FlowActionPayload
107+
108+
@model_validator(mode="before")
109+
def check_flow_name_or_id(cls, values):
110+
flow_name, flow_id = values.get("flow_name"), values.get("flow_id")
111+
if not flow_name and not flow_id:
112+
raise ValueError("Either flow_name or flow_id must be provided")
113+
if flow_name and flow_id:
114+
raise ValueError("Only one of flow_name or flow_id should be provided")
115+
return values
116+
117+
118+
class FlowAction(BaseModel):
119+
name: Literal["flow"]
120+
parameters: FlowParameters
121+
122+
123+
class FlowInteractive(BaseModel):
124+
type: Literal["flow"] = "flow"
125+
header: TextObject
126+
body: TextObject
127+
footer: TextObject
128+
action: FlowAction
129+
130+
131+
class FlowInteractiveMessage(BaseModel):
76132
messaging_product: Literal["whatsapp"] = "whatsapp"
77133
recipient_type: Literal["individual"] = "individual"
78134
to: str
79135
type: Literal["interactive"] = "interactive"
80-
interactive: InteractiveButton | InteractiveList
136+
interactive: FlowInteractive
81137

82138

83139
"""
84-
Main model for template messaged
140+
Main model for interactive messages
85141
"""
86142

87143

88-
class TemplateMessage(BaseModel):
144+
class InteractiveMessage(BaseModel):
89145
messaging_product: Literal["whatsapp"] = "whatsapp"
146+
recipient_type: Literal["individual"] = "individual"
90147
to: str
91-
type: Literal["template"] = "template"
92-
template: Dict[Literal["name", "language"], str]
148+
type: Literal["interactive"] = "interactive"
149+
interactive: Union[InteractiveButton, InteractiveList, FlowInteractive]

app/security.py

+11
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,14 @@ async def signature_required(request: Request) -> None:
3232
if not validate_signature(payload.decode("utf-8"), signature):
3333
logger.error("Signature verification failed!")
3434
raise HTTPException(status_code=403, detail="Invalid signature")
35+
36+
37+
# Dependency to ensure that incoming requests to our flows webhook are signed with the correct signature.
38+
async def flows_signature_required(request: Request) -> None:
39+
signature = request.headers.get("X-Hub-Signature-256", "")[7:] # Removing 'sha256='
40+
payload = await request.body()
41+
42+
if not validate_signature(payload.decode("utf-8"), signature):
43+
logger.error("Business signature verification failed!")
44+
# NOTE : We are using a custom status code here, 432. And user will see A generic error on the client.
45+
raise HTTPException(status_code=432, detail="Invalid business signature")

0 commit comments

Comments
 (0)