@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator
 from typing import Any, Optional, Union, cast
 
@@ -621,11 +622,19 @@ def _chat_generate(
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         # o1 compatibility
+        block_as_stream = False
         if model.startswith("o1"):
             if "max_tokens" in model_parameters:
                 model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
                 del model_parameters["max_tokens"]
 
+            if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
+                if stream:
+                    block_as_stream = True
+                    stream = False
+                    if "stream_options" in extra_model_kwargs:
+                        del extra_model_kwargs["stream_options"]
+
             if "stop" in extra_model_kwargs:
                 del extra_model_kwargs["stop"]
 
@@ -642,7 +651,45 @@ def _chat_generate(
         if stream:
             return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+        block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
+
+        if block_as_stream:
+            return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
+
+        return block_result
+
+    def _handle_chat_block_as_stream_response(
+        self,
+        block_result: LLMResult,
+        prompt_messages: list[PromptMessage],
+        stop: Optional[list[str]] = None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Handle llm chat response by converting a blocking result
+        into a single-chunk stream (for models without streaming support)
+
+        :param block_result: blocking LLM result
+        :param prompt_messages: prompt messages
+        :param stop: stop words
+        :return: llm response chunk generator
+        """
+        text = block_result.message.content
+        text = cast(str, text)
+
+        if stop:
+            block_result.message.content = self.enforce_stop_tokens(text, stop)  # write back so the yielded chunk honors the stop words
+
+        yield LLMResultChunk(
+            model=block_result.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=block_result.system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=block_result.message,
+                finish_reason="stop",
+                usage=block_result.usage,
+            ),
+        )
 
     def _handle_chat_generate_response(
         self,
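
For reference, the ^o1(-\d{4}-\d{2}-\d{2})?$ gate above matches only the base o1 model and its dated snapshots, so variants such as o1-mini or o1-preview keep the normal streaming path. A standalone sanity check (the model names are illustrative):

import re

# Pattern from the hunk above: base "o1" plus an optional dated snapshot.
O1_BLOCKING_ONLY = re.compile(r"^o1(-\d{4}-\d{2}-\d{2})?$")

for name in ["o1", "o1-2024-12-17", "o1-mini", "o1-preview"]:
    print(name, bool(O1_BLOCKING_ONLY.match(name)))
# o1 True             -> downgraded to a blocking request
# o1-2024-12-17 True  -> downgraded to a blocking request
# o1-mini False       -> streams as before
# o1-preview False    -> streams as before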
0 commit comments
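The new _handle_chat_block_as_stream_response helper implements a block-as-stream fallback: issue a blocking request, then re-emit the full result as a single terminal chunk so callers written against the streaming interface keep working. A minimal sketch of the same pattern with simplified stand-in types (assumptions, not the runtime's actual LLMResult / LLMResultChunk classes):

from collections.abc import Generator
from dataclasses import dataclass

# Simplified stand-ins for the runtime's result/chunk types.
@dataclass
class Result:
    text: str
    usage: dict

@dataclass
class Chunk:
    index: int
    text: str
    finish_reason: str
    usage: dict

def block_as_stream(result: Result) -> Generator[Chunk, None, None]:
    # Emit the whole blocking result as one terminal chunk so callers
    # written against the streaming interface keep working unchanged.
    yield Chunk(index=0, text=result.text, finish_reason="stop", usage=result.usage)

# Usage: a stream-oriented caller consumes the blocking result as a one-chunk stream.
for chunk in block_as_stream(Result(text="hello", usage={"total_tokens": 5})):
    print(chunk.finish_reason, chunk.text)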