Skip to content

Commit 773d5a6

Browse files
committed
added streaming
1 parent d93b315 commit 773d5a6

File tree

3 files changed

+132
-197
lines changed

3 files changed

+132
-197
lines changed

app/app.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,11 @@
1-
from flask import Flask, request, jsonify, render_template
2-
import subprocess
3-
import os
1+
from flask import Flask, request, jsonify, render_template, Response, stream_with_context
42
from rag_system import rag_system
53

64
app = Flask(__name__)
75

86
@app.route('/', methods=['GET', 'POST'])
97
def index():
10-
if request.method == 'POST':
11-
query = request.form.get('query')
12-
if not query:
13-
return render_template('index.html', query=None, response="No query provided")
14-
15-
try:
16-
response = rag_system.answer_query(query)
17-
return render_template('index.html', query=query, response=response)
18-
except Exception as e:
19-
print(f"Error in /ask endpoint: {e}")
20-
return render_template('index.html', query=query, response="Internal Server Error")
21-
return render_template('index.html', query=None, response=None)
8+
return render_template('index.html')
229

2310
@app.route('/ask', methods=['POST'])
2411
def ask():
@@ -27,24 +14,15 @@ def ask():
2714
if not query:
2815
return jsonify({"error": "No query provided"}), 400
2916

30-
try:
31-
response = rag_system.answer_query(query)
32-
return jsonify({"response": response})
33-
except Exception as e:
34-
print(f"Error in /ask endpoint: {e}")
35-
return jsonify({"error": "Internal Server Error"}), 500
36-
37-
38-
# # New endpoint for triggering the rebuild
39-
# def run_get_knowledge_base_script():
40-
# """ Function to run the get_knowledge_base.py script from the parent directory """
41-
# try:
42-
# subprocess.run(['python', 'get_knowledge_base.py'], check=True)
17+
def generate():
18+
try:
19+
for token in rag_system.answer_query_stream(query):
20+
yield token
21+
except Exception as e:
22+
print(f"Error in /ask endpoint: {e}")
23+
yield "Internal Server Error"
4324

44-
# except subprocess.CalledProcessError as e:
45-
# print(f"Error running get_knowledge_base.py: {e}")
46-
# except Exception as e:
47-
# print(f"An error occurred: {e}")
25+
return Response(stream_with_context(generate()), content_type='text/plain')
4826

4927
@app.route('/trigger-rebuild', methods=['POST'])
5028
def trigger_rebuild():
@@ -73,6 +51,6 @@ def trigger_rebuild():
7351
except Exception as e:
7452
print(f"Error in /trigger-rebuild endpoint: {e}")
7553
return jsonify({"error": "Internal Server Error"}), 500
76-
54+
7755
if __name__ == '__main__':
78-
app.run(host='0.0.0.0', port=5000)
56+
app.run(host='0.0.0.0', port=5000)

app/rag_system.py

Lines changed: 36 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66
from sklearn.metrics.pairwise import cosine_similarity
77

8-
# Ensure you have set the OPENAI_API_KEY in your environment variables
98
openai.api_key = os.getenv("OPENAI_API_KEY")
109

1110
class RAGSystem:
@@ -14,43 +13,23 @@ def __init__(self, knowledge_base_path='knowledge_base.json'):
1413
self.knowledge_base = self.load_knowledge_base()
1514
self.model = SentenceTransformer('all-MiniLM-L6-v2')
1615
self.doc_embeddings = self.embed_knowledge_base()
17-
self.conversation_history = [] # To store the conversation history
16+
self.conversation_history = []
1817

1918
def load_knowledge_base(self):
20-
"""
21-
Load the knowledge base from a JSON file.
22-
"""
2319
with open(self.knowledge_base_path, 'r') as kb_file:
2420
return json.load(kb_file)
2521

2622
def embed_knowledge_base(self):
27-
"""
28-
Embed the knowledge base using the SentenceTransformer model.
29-
Combines 'about' and 'text' fields for each document for embedding.
30-
"""
3123
docs = [f'{doc["about"]}. {doc["text"]}' for doc in self.knowledge_base]
3224
return self.model.encode(docs, convert_to_tensor=True)
3325

3426
def normalize_query(self, query):
35-
"""
36-
Normalize the query by converting it to lowercase and stripping whitespace.
37-
"""
3827
return query.lower().strip()
3928

4029
def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, max_docs=5):
41-
"""
42-
Retrieve relevant documents from the knowledge base using cosine similarity.
43-
"""
4430
normalized_query = self.normalize_query(query)
45-
print(f"Retrieving context for query: '{normalized_query}'")
46-
47-
# Query embedding
4831
query_embedding = self.model.encode([normalized_query], convert_to_tensor=True)
49-
50-
# Calculate similarities
5132
similarities = cosine_similarity(query_embedding, self.doc_embeddings)[0]
52-
53-
# Initialize relevance scores
5433
relevance_scores = []
5534

5635
for i, doc in enumerate(self.knowledge_base):
@@ -72,77 +51,58 @@ def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, ma
7251
max_index = np.argmax(similarities)
7352
retrieved_docs.append(f'{self.knowledge_base[max_index]["about"]}. {self.knowledge_base[max_index]["text"]}')
7453

75-
context = "\n\n".join(retrieved_docs)
76-
print("Retrieved Context:\n", context)
54+
return "\n\n".join(retrieved_docs)
55+
56+
def answer_query_stream(self, query):
57+
try:
58+
normalized_query = self.normalize_query(query)
59+
context = self.retrieve(normalized_query)
60+
61+
self.conversation_history.append({"role": "user", "content": query})
7762

78-
return context
63+
messages = [{"role": "system", "content": "You are a helpful assistant."}]
64+
messages.extend(self.conversation_history)
7965

80-
def generate_response(self, query, context):
81-
"""
82-
Generate a response from OpenAI using conversation history and knowledge base context.
83-
"""
84-
# Add user query to conversation history
85-
self.conversation_history.append({"role": "user", "content": query})
86-
87-
# Include conversation history in the prompt
88-
messages = [{"role": "system", "content": "You are a helpful assistant."}]
89-
messages.extend(self.conversation_history) # Include past conversation history
90-
91-
# Prepare the context from the knowledge base
92-
system_message = {
93-
"role": "system",
94-
"content": (
95-
"You are a dedicated assistant designed to provide positive answers about Defang. "
96-
"When the user says 'you', 'your', or any pronoun, interpret it as referring to Defang with context of Defang. "
97-
"If the user's question involves comparisons with or references to other services, you may use external knowledge. "
98-
"However, if the question is strictly about Defang, you must ignore all external knowledge and only utilize the given context. "
99-
"Context: " + context
100-
)
101-
}
66+
system_message = {
67+
"role": "system",
68+
"content": (
69+
"You are a dedicated assistant designed to provide positive answers about Defang. "
70+
"When the user says 'you', 'your', or any pronoun, interpret it as referring to Defang with context of Defang. "
71+
"If the user's question involves comparisons with or references to other services, you may use external knowledge. "
72+
"However, if the question is strictly about Defang, you must ignore all external knowledge and only utilize the given context. "
73+
"Context: " + context
74+
)
75+
}
10276

103-
messages.append(system_message)
77+
messages.append(system_message)
10478

105-
try:
106-
response = openai.ChatCompletion.create(
79+
stream = openai.ChatCompletion.create(
10780
model="gpt-4-turbo",
10881
messages=messages,
10982
temperature=0.5,
11083
max_tokens=2048,
11184
top_p=1,
11285
frequency_penalty=0,
113-
presence_penalty=0
86+
presence_penalty=0,
87+
stream=True
11488
)
11589

116-
generated_response = response['choices'][0]['message']['content'].strip()
117-
118-
# Add the bot's response to the conversation history
119-
self.conversation_history.append({"role": "assistant", "content": generated_response})
120-
121-
print("Generated Response:\n", generated_response)
122-
return generated_response
90+
collected_messages = []
91+
for chunk in stream:
92+
if chunk['choices'][0]['finish_reason'] is not None:
93+
break
94+
content = chunk['choices'][0]['delta'].get('content', '')
95+
collected_messages.append(content)
96+
yield content
12397

124-
except openai.error.OpenAIError as e:
125-
print(f"Error generating response from OpenAI: {e}")
126-
return "An error occurred while generating the response."
98+
full_response = ''.join(collected_messages).strip()
99+
self.conversation_history.append({"role": "assistant", "content": full_response})
127100

128-
def answer_query(self, query):
129-
"""
130-
Answer the user query, leveraging knowledge base context and conversation history.
131-
"""
132-
try:
133-
normalized_query = self.normalize_query(query)
134-
context = self.retrieve(normalized_query)
135-
response = self.generate_response(normalized_query, context)
136-
return response
137101
except Exception as e:
138-
print(f"Error in answer_query: {e}")
139-
return "An error occurred while generating the response."
102+
print(f"Error in answer_query_stream: {e}")
103+
yield "An error occurred while generating the response."
140104

141105
def clear_conversation_history(self):
142-
"""
143-
Clear the stored conversation history.
144-
This can be called to reset the conversation for a new session.
145-
"""
146106
self.conversation_history = []
147107
print("Conversation history cleared.")
148108

0 commit comments

Comments (0)