1
- import sys
2
- import time
1
+ import simplejson
2
+ import enum
3
+ import inspect
3
4
import openai
5
+ import sys
4
6
import threading
7
+ import time
8
+ import typing
9
+ from src .spin import spin
5
10
6
- def create_chat_with_spinner (messages , temperature ):
7
- chats = []
11
+ def get_json_type_name (value ):
12
+ if isinstance (value , str ):
13
+ return "string"
14
+ elif isinstance (value , int ):
15
+ return "integer"
16
+ elif isinstance (value , float ):
17
+ return "number"
18
+ elif isinstance (value , bool ):
19
+ return "boolean"
20
+ else :
21
+ return "null"
8
22
9
- def show_spinner ():
10
- spinner = ["|" , "/" , "-" , "\\ " ]
11
- i = 0
12
- while len (chats ) == 0 :
13
- sys .stdout .write ("\r " + spinner [i % 4 ])
14
- sys .stdout .flush ()
15
- time .sleep (0.1 )
16
- i += 1
23
+ def _json_schema (func ):
24
+ api_info = {}
17
25
26
+ # Get function name
27
+ api_info ["name" ] = func .__name__
18
28
19
- def create_chat_model ():
20
- chat = openai .ChatCompletion .create (
21
- model = "gpt-3.5-turbo" , messages = messages , stream = True , temperature = temperature
22
- )
23
- chats .append (chat )
29
+ # Get function description from docstring
30
+ docstring = inspect .getdoc (func )
31
+ if docstring :
32
+ api_info ["description" ] = docstring .split ("\n \n " )[0 ]
33
+
34
+ # Get function parameters
35
+ parameters = {}
36
+ parameters ["type" ] = "object"
37
+
38
+ properties = {}
39
+
40
+ signature = inspect .signature (func )
41
+ required = []
42
+ for param_name , param in signature .parameters .items ():
43
+ param_info = {}
24
44
25
- thread = threading .Thread (target = create_chat_model )
26
- thread .start ()
45
+ # Get parameter type from type hints
46
+ if typing .get_origin (param .annotation ) is typing .Union and type (None ) in typing .get_args (param .annotation ):
47
+ # Handle Optional case
48
+ inner_type = typing .get_args (param .annotation )[0 ]
49
+ if issubclass (inner_type , enum .Enum ):
50
+ param_info ["type" ] = get_json_type_name (inner_type .__members__ [list (inner_type .__members__ )[0 ]].value )
51
+ param_info ["enum" ] = [member .value for member in inner_type ]
52
+ else :
53
+ param_info ["type" ] = get_json_type_name (inner_type .__name__ )
54
+ elif issubclass (param .annotation , enum .Enum ):
55
+ param_info ["type" ] = get_json_type_name (param .annotation .__members__ [list (param .annotation .__members__ )[0 ]].value )
56
+ param_info ["enum" ] = [member .value for member in param .annotation ]
57
+ else :
58
+ param_info ["type" ] = get_json_type_name (param .annotation .__name__ )
27
59
28
- spinner_thread = threading .Thread (target = show_spinner )
29
- spinner_thread .start ()
60
+ # Get parameter description from docstring
61
+ if docstring and param_name in docstring :
62
+ param_info ["description" ] = docstring .split (param_name + ":" )[1 ].split ("\n " )[0 ].strip ()
30
63
31
- thread .join ()
32
- spinner_thread .join ()
64
+ # Check if parameter is required
65
+ if param .default == inspect .Parameter .empty :
66
+ required .append (param_name )
33
67
34
- sys .stdout .write ("\r \r " )
35
- return chats [0 ]
68
+ # Add parameter info to parameters dict
69
+ properties [param_name ] = param_info
70
+
71
+ # Add parameters to api_info
72
+ parameters ["properties" ] = properties
73
+ parameters ["required" ] = required
74
+ api_info ["parameters" ] = parameters
75
+
76
+ return api_info
77
+
78
+ def create_chat_with_spinner (messages , temperature , functions ):
79
+ return create_chat (messages , temperature , functions , True )
80
+
81
+ def create_chat (messages , temperature , functions , spinner = False ):
82
+ def create_chat_model ():
83
+ args = {
84
+ "model" : "gpt-4" ,
85
+ "messages" : messages ,
86
+ "stream" : True ,
87
+ "temperature" : temperature ,
88
+ }
89
+ if functions :
90
+ args ['functions' ] = list (map (lambda f : _json_schema (f ), functions ))
91
+ return openai .ChatCompletion .create (** args )
92
+
93
+ if not spinner :
94
+ chats = []
95
+ return create_chat_model ()
96
+ else :
97
+ return spin (create_chat_model )
36
98
37
99
38
100
def suggest_name (chat_id , message ):
@@ -48,3 +110,16 @@ def suggest_name(chat_id, message):
48
110
name = chat_completion_resp .choices [0 ].message .content
49
111
return (chat_id , name )
50
112
113
+ def invoke (functions , name , args_str ):
114
+ try :
115
+ args = simplejson .loads (args_str , strict = False )
116
+ except Exception as e :
117
+ print ("" )
118
+ print (f"ChatGPT called { name } with bad input: { args_str } " )
119
+ raise ValueError ("The function arguments did not form a valid JSON document" )
120
+
121
+ for func in functions :
122
+ if func .__name__ == name :
123
+ return func (** args )
124
+ else :
125
+ raise ValueError (f"Function '{ name } ' not found." )
0 commit comments