|
| 1 | +import streamlit as st |
| 2 | +import os |
| 3 | +import uuid |
| 4 | + |
| 5 | +from langchain_core.messages import HumanMessage |
| 6 | +from langchain_openai import ChatOpenAI |
| 7 | +from langgraph.checkpoint.postgres import PostgresSaver |
| 8 | +from langgraph.graph import START, MessagesState, StateGraph |
| 9 | +from psycopg import Connection |
| 10 | + |
| 11 | +AI_PREFIX = "Assistant" |
| 12 | +HUMAN_PREFIX = "User" |
| 13 | + |
@st.cache_resource
def get_checkpointer():
    """Create the Postgres-backed LangGraph checkpointer (once per server process).

    Reads the connection string from the DB_URI environment variable and
    fails fast with a clear error when it is missing, instead of letting
    psycopg raise an opaque error on a None conninfo.

    Returns:
        A ready-to-use PostgresSaver whose tables have been created.

    Raises:
        RuntimeError: if DB_URI is not set in the environment.
    """
    db_uri = os.environ.get("DB_URI")
    if not db_uri:
        raise RuntimeError("DB_URI environment variable is not set")
    # autocommit + prepare_threshold=0 are the connection settings PostgresSaver
    # expects. The connection is intentionally NOT opened in a `with` block:
    # @st.cache_resource keeps it alive for the lifetime of the process.
    connection_kwargs = {"autocommit": True, "prepare_threshold": 0}
    conn = Connection.connect(conninfo=db_uri, **connection_kwargs)
    checkpointer = PostgresSaver(conn)
    checkpointer.setup()  # creates the checkpoint tables if they don't exist yet
    return checkpointer
| 22 | + |
@st.cache_resource
def get_model():
    """Return the shared ChatOpenAI client, configured from the environment.

    MODEL_BASE_URL points at an OpenAI-compatible endpoint and MODEL_NAME
    selects the model; the API key is a placeholder ("-") because the
    endpoint handles authentication out of band.
    """
    return ChatOpenAI(
        base_url=os.environ.get("MODEL_BASE_URL"),
        openai_api_key="-",
        model=os.environ.get("MODEL_NAME"),
    )
| 28 | + |
# --- Streamlit page setup -------------------------------------------------
st.set_page_config(page_title="Streamlit chatbot", page_icon="🤖")
st.title("Streamlit chatbot")
st.caption("Powered by Google Cloud, Langchain and PostgreSQL")

# --- Session state --------------------------------------------------------
# Derive a stable chat id: behind Google Cloud IAP the authenticated user id
# arrives in this header; otherwise fall back to a fresh UUID per session.
request_headers = st.context.headers
iap_user_id = request_headers.get("X-Goog-Authenticated-User-Id")

if "chat_id" not in st.session_state:
    st.session_state.chat_id = iap_user_id or str(uuid.uuid4())
if "messages" not in st.session_state:
    st.session_state.messages = []
| 42 | + |
# Shared chat model instance (cached across Streamlit reruns).
model = get_model()


def call_model(state: MessagesState):
    """LangGraph node: send the accumulated message history to the LLM.

    Returns a partial state update; MessagesState's reducer appends the
    returned message to the persisted conversation.
    """
    return {"messages": model.invoke(state["messages"])}
| 48 | + |
# --- LangGraph workflow: a single node wired START -> model ---------------
workflow = StateGraph(state_schema=MessagesState)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")
app = workflow.compile(checkpointer=get_checkpointer())

# Every read/write of the conversation is scoped to this thread id, so each
# user (or anonymous session) gets an isolated, persisted history.
config = {"configurable": {"thread_id": st.session_state.chat_id}}
| 55 | + |
# Replay any conversation already persisted in the checkpointer for this thread.
app_state = app.get_state(config)
for message in app_state.values.get("messages", []):
    st.chat_message(message.type).markdown(message.content)
| 61 | + |
# Get the user input and generate a response.
if prompt := st.chat_input("Enter your message"):
    prompt = prompt.strip()
    # chat_input can return a whitespace-only string, which strips to "";
    # guard so we never send an empty HumanMessage to the model.
    if prompt:
        st.chat_message("human").markdown(prompt)
        st.session_state.messages.append({"role": "human", "content": prompt})

        with st.spinner(text="Processing..."):
            # The checkpointer merges this new message into the persisted
            # history for config's thread_id before calling the model node.
            output = app.invoke({"messages": [HumanMessage(prompt)]}, config)

        # The last message in the returned state is the model's reply.
        response_content = output["messages"][-1].content.strip()
        st.chat_message("ai").markdown(response_content)
        st.session_state.messages.append({"role": "ai", "content": response_content})
| 74 | + |
# Offer a reset once there is any local history: clear the in-memory message
# list and delete this thread's persisted checkpoints from Postgres.
if st.session_state.messages and st.button("Restart chat"):
    st.session_state.messages = []
    chat_id = st.session_state.chat_id
    # Use the cursor as a context manager so it is always closed, even if one
    # of the DELETEs raises (the original leaked the cursor).
    with get_checkpointer().conn.cursor() as cursor:
        # Table names come from a fixed allowlist, never from user input, so
        # the f-string here is not an injection risk; thread_id is bound as a
        # proper query parameter.
        for table in ("checkpoints", "checkpoint_writes", "checkpoint_blobs"):
            cursor.execute(f"DELETE FROM {table} WHERE thread_id = %s", (chat_id,))
    st.rerun()
0 commit comments