 import logging
 import threading
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 from uuid import UUID

 from langchain_core.callbacks import BaseCallbackHandler
@@ -32,16 +32,20 @@ def __init__(
         splitting_chars: Optional[List[str]] = None,
         max_buffer_size: int = 200,
         logger: Optional[logging.Logger] = None,
+        stream_response: bool = True,
     ):
         self.connectors = connectors
         self.aggregate_chunks = aggregate_chunks
+        self.stream_response = stream_response
         self.splitting_chars = splitting_chars or ["\n", ".", "!", "?"]
         self.chunks_buffer = ""
         self.max_buffer_size = max_buffer_size
         self._buffer_lock = threading.Lock()
         self.logger = logger or logging.getLogger(__name__)
         self.current_conversation_id = None
         self.current_chunk_id = 0
+        self.working = False
+        self.hit_on_llm_new_token = False

     def _should_split(self, token: str) -> bool:
         return token in self.splitting_chars
@@ -63,8 +67,22 @@ def _send_all_targets(self, tokens: str, done: bool = False):
                     f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
                 )

+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        prompts: list[str],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        tags: Optional[list[str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Any:
+        self.working = True
+
     def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):
-        if token == "":
+        self.hit_on_llm_new_token = True
+        if token == "" or not self.stream_response:
             return
         if self.current_conversation_id != str(run_id):
             self.current_conversation_id = str(run_id)
@@ -93,7 +111,22 @@ def on_llm_end(
         **kwargs,
     ):
         self.current_conversation_id = str(run_id)
-        if self.aggregate_chunks and self.chunks_buffer:
+        if self.stream_response and not self.hit_on_llm_new_token:
+            self.logger.error(
+                (
+                    "No tokens were sent to the callback handler. "
+                    "LLM did not stream response. "
+                    "Is your BaseChatModel configured to stream? "
+                    "Sending generated text as a single message."
+                )
+            )
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif not self.stream_response:
+            msg = response.generations[0][0].message
+            self._send_all_targets(msg.content, done=True)
+        elif self.aggregate_chunks and self.chunks_buffer:
             with self._buffer_lock:
                 self._send_all_targets(self.chunks_buffer, done=True)
                 self.chunks_buffer = ""
+        self.working = False
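For context, a minimal usage sketch of the new stream_response flag follows. The handler class name (HRICallbackHandler), its import path, and the connector object are illustrative assumptions, not part of this commit; the LangChain calls (ChatOpenAI, callbacks passed via config) follow standard langchain-core usage. With stream_response=False, on_llm_new_token returns immediately and on_llm_end sends the whole generated message once; with the default stream_response=True but a model that never streams, the new hit_on_llm_new_token guard logs the "Is your BaseChatModel configured to stream?" error and falls back to the same single-message path. The working flag simply brackets a run: set in on_llm_start, cleared at the end of on_llm_end.

import logging

from langchain_openai import ChatOpenAI  # any BaseChatModel would do

# Assumed names for illustration only -- not defined in this diff.
from hri_callback import HRICallbackHandler  # wherever this handler class lives
from my_app.connectors import speaker_connector  # hypothetical HRI connector

handler = HRICallbackHandler(
    connectors={"speaker": speaker_connector},
    aggregate_chunks=True,
    stream_response=False,  # deliver one message from on_llm_end instead of token chunks
    logger=logging.getLogger("hri"),
)

llm = ChatOpenAI(model="gpt-4o-mini")
# Tokens emitted during generation are ignored (on_llm_new_token returns early);
# the full response text is sent to every connector once on_llm_end fires.
llm.invoke("Greet the user.", config={"callbacks": [handler]})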