5
5
import numpy as np
6
6
from sklearn .metrics .pairwise import cosine_similarity
7
7
8
- # Ensure you have set the OPENAI_API_KEY in your environment variables
9
8
openai .api_key = os .getenv ("OPENAI_API_KEY" )
10
9
11
10
class RAGSystem :
@@ -14,43 +13,23 @@ def __init__(self, knowledge_base_path='knowledge_base.json'):
14
13
self .knowledge_base = self .load_knowledge_base ()
15
14
self .model = SentenceTransformer ('all-MiniLM-L6-v2' )
16
15
self .doc_embeddings = self .embed_knowledge_base ()
17
- self .conversation_history = [] # To store the conversation history
16
+ self .conversation_history = []
18
17
19
18
def load_knowledge_base(self):
    """Load and return the knowledge base from the JSON file at
    ``self.knowledge_base_path``.

    Returns:
        The parsed JSON content — a sequence of document dicts; each is
        expected to carry 'about' and 'text' fields (consumed by
        ``embed_knowledge_base``).

    Raises:
        FileNotFoundError: if the knowledge-base file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Explicit encoding avoids platform-dependent defaults when the
    # knowledge base contains non-ASCII text.
    with open(self.knowledge_base_path, 'r', encoding='utf-8') as kb_file:
        return json.load(kb_file)
def embed_knowledge_base(self):
    """Encode every knowledge-base document into a dense embedding.

    Each document's 'about' and 'text' fields are joined into a single
    string ("about. text") before encoding with the sentence-transformer
    model; the embeddings are returned as a tensor.
    """
    combined = [doc["about"] + ". " + doc["text"] for doc in self.knowledge_base]
    return self.model.encode(combined, convert_to_tensor=True)
def normalize_query(self, query):
    """Return *query* with surrounding whitespace removed and lowercased."""
    return query.strip().lower()
def retrieve (self , query , similarity_threshold = 0.7 , high_match_threshold = 0.8 , max_docs = 5 ):
41
- """
42
- Retrieve relevant documents from the knowledge base using cosine similarity.
43
- """
44
30
normalized_query = self .normalize_query (query )
45
- print (f"Retrieving context for query: '{ normalized_query } '" )
46
-
47
- # Query embedding
48
31
query_embedding = self .model .encode ([normalized_query ], convert_to_tensor = True )
49
-
50
- # Calculate similarities
51
32
similarities = cosine_similarity (query_embedding , self .doc_embeddings )[0 ]
52
-
53
- # Initialize relevance scores
54
33
relevance_scores = []
55
34
56
35
for i , doc in enumerate (self .knowledge_base ):
@@ -72,77 +51,58 @@ def retrieve(self, query, similarity_threshold=0.7, high_match_threshold=0.8, ma
72
51
max_index = np .argmax (similarities )
73
52
retrieved_docs .append (f'{ self .knowledge_base [max_index ]["about" ]} . { self .knowledge_base [max_index ]["text" ]} ' )
74
53
75
- context = "\n \n " .join (retrieved_docs )
76
- print ("Retrieved Context:\n " , context )
54
+ return "\n \n " .join (retrieved_docs )
55
+
56
def answer_query_stream(self, query):
    """Stream an answer to *query*, grounded in retrieved context.

    Yields response fragments as they arrive from the OpenAI chat
    completion API. The user turn and the final assembled assistant
    turn are both recorded in ``self.conversation_history``. On any
    failure the error is printed and a single fallback message is
    yielded instead of raising.
    """
    try:
        normalized_query = self.normalize_query(query)
        context = self.retrieve(normalized_query)

        # Record the raw (un-normalized) user turn.
        self.conversation_history.append({"role": "user", "content": query})

        messages = [{"role": "system", "content": "You are a helpful assistant."}]
        messages.extend(self.conversation_history)

        # Grounding instructions plus the retrieved knowledge-base context.
        messages.append({
            "role": "system",
            "content": (
                "You are a dedicated assistant designed to provide positive answers about Defang. "
                "When the user says 'you', 'your', or any pronoun, interpret it as referring to Defang with context of Defang. "
                "If the user's question involves comparisons with or references to other services, you may use external knowledge. "
                "However, if the question is strictly about Defang, you must ignore all external knowledge and only utilize the given context. "
                "Context: " + context
            )
        })

        stream = openai.ChatCompletion.create(
            model="gpt-4-turbo",
            messages=messages,
            temperature=0.5,
            max_tokens=2048,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stream=True
        )

        pieces = []
        for chunk in stream:
            choice = chunk['choices'][0]
            # A non-null finish_reason marks the end of the stream.
            if choice['finish_reason'] is not None:
                break
            fragment = choice['delta'].get('content', '')
            pieces.append(fragment)
            yield fragment

        # Persist the complete assistant reply so later turns see it.
        full_response = ''.join(pieces).strip()
        self.conversation_history.append({"role": "assistant", "content": full_response})

    except Exception as e:
        print(f"Error in answer_query_stream: {e}")
        yield "An error occurred while generating the response."
140
104
141
105
def clear_conversation_history(self):
    """Drop every stored conversation turn, e.g. to start a new session."""
    self.conversation_history = []
    print("Conversation history cleared.")
0 commit comments