Skip to content

Commit 171dd2f

Browse files
committed
Add function execution
1 parent 4eec697 commit 171dd2f

File tree

6 files changed

+358
-72
lines changed

6 files changed

+358
-72
lines changed

Makefile

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
.PHONY: test upload
22

3-
test:
4-
rm -rf testenv/ ; python3 -m venv testenv ; source testenv/bin/activate.csh ; pip install -e . ; testenv/bin/gptline
3+
run:
4+
source testenv/bin/activate ; pip install -e . ; testenv/bin/gptline
5+
6+
rebuild:
7+
rm -rf testenv/ ; python3 -m venv testenv ; source testenv/bin/activate ; pip install -e . ; testenv/bin/gptline
58

69
upload:
710
rm -f dist/*

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
openai
22
prompt-toolkit
3+
simplejson

src/chat.py

+100-25
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,100 @@
1-
import sys
2-
import time
1+
import simplejson
2+
import enum
3+
import inspect
34
import openai
5+
import sys
46
import threading
7+
import time
8+
import typing
9+
from src.spin import spin
510

6-
def create_chat_with_spinner(messages, temperature):
7-
chats = []
11+
def get_json_type_name(value):
12+
if isinstance(value, str):
13+
return "string"
14+
elif isinstance(value, int):
15+
return "integer"
16+
elif isinstance(value, float):
17+
return "number"
18+
elif isinstance(value, bool):
19+
return "boolean"
20+
else:
21+
return "null"
822

9-
def show_spinner():
10-
spinner = ["|", "/", "-", "\\"]
11-
i = 0
12-
while len(chats) == 0:
13-
sys.stdout.write("\r" + spinner[i % 4])
14-
sys.stdout.flush()
15-
time.sleep(0.1)
16-
i += 1
23+
def _json_schema(func):
24+
api_info = {}
1725

26+
# Get function name
27+
api_info["name"] = func.__name__
1828

19-
def create_chat_model():
20-
chat = openai.ChatCompletion.create(
21-
model="gpt-3.5-turbo", messages=messages, stream=True, temperature=temperature
22-
)
23-
chats.append(chat)
29+
# Get function description from docstring
30+
docstring = inspect.getdoc(func)
31+
if docstring:
32+
api_info["description"] = docstring.split("\n\n")[0]
33+
34+
# Get function parameters
35+
parameters = {}
36+
parameters["type"] = "object"
37+
38+
properties = {}
39+
40+
signature = inspect.signature(func)
41+
required = []
42+
for param_name, param in signature.parameters.items():
43+
param_info = {}
2444

25-
thread = threading.Thread(target=create_chat_model)
26-
thread.start()
45+
# Get parameter type from type hints
46+
if typing.get_origin(param.annotation) is typing.Union and type(None) in typing.get_args(param.annotation):
47+
# Handle Optional case
48+
inner_type = typing.get_args(param.annotation)[0]
49+
if issubclass(inner_type, enum.Enum):
50+
param_info["type"] = get_json_type_name(inner_type.__members__[list(inner_type.__members__)[0]].value)
51+
param_info["enum"] = [member.value for member in inner_type]
52+
else:
53+
param_info["type"] = get_json_type_name(inner_type.__name__)
54+
elif issubclass(param.annotation, enum.Enum):
55+
param_info["type"] = get_json_type_name(param.annotation.__members__[list(param.annotation.__members__)[0]].value)
56+
param_info["enum"] = [member.value for member in param.annotation]
57+
else:
58+
param_info["type"] = get_json_type_name(param.annotation.__name__)
2759

28-
spinner_thread = threading.Thread(target=show_spinner)
29-
spinner_thread.start()
60+
# Get parameter description from docstring
61+
if docstring and param_name in docstring:
62+
param_info["description"] = docstring.split(param_name + ":")[1].split("\n")[0].strip()
3063

31-
thread.join()
32-
spinner_thread.join()
64+
# Check if parameter is required
65+
if param.default == inspect.Parameter.empty:
66+
required.append(param_name)
3367

34-
sys.stdout.write("\r \r")
35-
return chats[0]
68+
# Add parameter info to parameters dict
69+
properties[param_name] = param_info
70+
71+
# Add parameters to api_info
72+
parameters["properties"] = properties
73+
parameters["required"] = required
74+
api_info["parameters"] = parameters
75+
76+
return api_info
77+
78+
def create_chat_with_spinner(messages, temperature, functions):
79+
return create_chat(messages, temperature, functions, True)
80+
81+
def create_chat(messages, temperature, functions, spinner=False):
82+
def create_chat_model():
83+
args = {
84+
"model": "gpt-4",
85+
"messages": messages,
86+
"stream": True,
87+
"temperature": temperature,
88+
}
89+
if functions:
90+
args['functions'] = list(map(lambda f: _json_schema(f), functions))
91+
return openai.ChatCompletion.create(**args)
92+
93+
if not spinner:
94+
chats = []
95+
return create_chat_model()
96+
else:
97+
return spin(create_chat_model)
3698

3799

38100
def suggest_name(chat_id, message):
@@ -48,3 +110,16 @@ def suggest_name(chat_id, message):
48110
name = chat_completion_resp.choices[0].message.content
49111
return (chat_id, name)
50112

113+
def invoke(functions, name, args_str):
114+
try:
115+
args = simplejson.loads(args_str, strict=False)
116+
except Exception as e:
117+
print("")
118+
print(f"ChatGPT called {name} with bad input: {args_str}")
119+
raise ValueError("The function arguments did not form a valid JSON document")
120+
121+
for func in functions:
122+
if func.__name__ == name:
123+
return func(**args)
124+
else:
125+
raise ValueError(f"Function '{name}' not found.")

src/db.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sqlite3
2+
from typing import Optional
23
import fcntl
34
import os
45
import sys
@@ -45,9 +46,11 @@ def create_schema(self):
4546
id INTEGER PRIMARY KEY AUTOINCREMENT,
4647
chat_id INTEGER,
4748
role TEXT,
48-
content TEXT,
49+
content TEXT NULL,
4950
time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
5051
deleted INTEGER DEFAULT 0,
52+
function_call_name TEXT,
53+
function_call_arguments TEXT,
5154
FOREIGN KEY (chat_id) REFERENCES chats (id)
5255
)
5356
"""
@@ -70,16 +73,17 @@ def create_chat(self, name=None):
7073
self.conn.commit()
7174
return chat_id
7275

73-
def add_message(self, chat_id: int, role: str, content: str):
76+
def add_message(self, chat_id: int, role: str, content: str, function_call_name: Optional[str], function_call_arguments: Optional[str]):
7477
cursor = self.conn.cursor() # Create a cursor
75-
query = "INSERT INTO messages (chat_id, role, content) VALUES (?, ?, ?)"
76-
cursor.execute(query, (chat_id, role, content))
78+
query = "INSERT INTO messages (chat_id, role, content, function_call_name, function_call_arguments) VALUES (?, ?, ?, ?, ?)"
79+
cursor.execute(query, (chat_id, role, content, function_call_name, function_call_arguments))
7780
self.conn.commit()
7881
last_message_id = cursor.lastrowid
7982

80-
fts_query = "INSERT INTO messages_fts (message_id, content) VALUES (?, ?)"
81-
self.conn.execute(fts_query, (last_message_id, content.lower()))
82-
self.conn.commit()
83+
if content is not None and role != "function":
84+
fts_query = "INSERT INTO messages_fts (message_id, content) VALUES (?, ?)"
85+
self.conn.execute(fts_query, (last_message_id, content.lower()))
86+
self.conn.commit()
8387

8488
query = f"UPDATE chats SET last_update = CURRENT_TIMESTAMP WHERE id = ?"
8589
cursor.execute(query, (chat_id, ))
@@ -100,7 +104,7 @@ def get_message_by_id(self, message_id: int):
100104
raise IndexError("Index out of range")
101105

102106
def get_message_by_index(self, chat_id: int, index: int):
103-
query = "SELECT role, content, time, id, deleted FROM messages WHERE chat_id = ? ORDER BY id LIMIT 1 OFFSET ?"
107+
query = "SELECT role, content, time, id, deleted, function_call_name, function_call_arguments FROM messages WHERE chat_id = ? ORDER BY id LIMIT 1 OFFSET ?"
104108
result = self.conn.execute(query, (chat_id, index)).fetchone()
105109
if result:
106110
return result

src/input_reader.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class UserInput:
2121
query: Optional[str] = None
2222
regenerate = False
2323
edit = False
24+
allow_execution = False
2425

2526
@dataclass
2627
class Chat:
@@ -30,8 +31,9 @@ class Chat:
3031

3132

3233
# Returns UserInput
33-
def read_input(chats, current_chat_name, can_regen, placeholder):
34+
def read_input(chats, current_chat_name, can_regen, placeholder, allow_execution):
3435
result = UserInput()
36+
result.allow_execution = allow_execution
3537

3638
session = PromptSession()
3739
kb = KeyBindings()
@@ -55,6 +57,7 @@ def _(event):
5557
SEARCH = "$$$SEARCH"
5658
REGENERATE = "$$$REGENERATE"
5759
EDIT = "$$$EDIT"
60+
TOGGLE_SETTING = "$$$TOGGLE_SETTING"
5861

5962
@kb.add(Keys.F2)
6063
def _(event):
@@ -81,6 +84,15 @@ def _(event):
8184
app = get_app()
8285
app.exit(result=EDIT)
8386

87+
@kb.add(Keys.F7)
88+
def _(event):
89+
result.allow_execution = not result.allow_execution
90+
if result.allow_execution:
91+
print("Command execution enabled.")
92+
else:
93+
print("Command execution disabled.")
94+
app = get_app()
95+
app.exit(result=TOGGLE_SETTING)
8496
def read_search_query():
8597
search_session = PromptSession()
8698
return search_session.prompt(HTML("<b>Search: </b>"))
@@ -148,6 +160,10 @@ def bottom_toolbar():
148160
if can_regen:
149161
text += " <b>F5</b>: Regenerate"
150162
text += " <b>F6</b>: Edit Last"
163+
if result.allow_execution:
164+
text += " <b>F7</b>: Disable Execution"
165+
else:
166+
text += " <b>F7</b>: Enable Execution"
151167
return HTML(text)
152168
value = session.prompt(
153169
"",
@@ -166,6 +182,8 @@ def bottom_toolbar():
166182
elif value == REGENERATE:
167183
result.regenerate = True
168184
return result
185+
elif value == TOGGLE_SETTING:
186+
continue
169187
elif value == EDIT:
170188
result.edit = True
171189
return result

0 commit comments

Comments
 (0)