Skip to content

Commit 5aaa447

Browse files
authored
Migrate ChatGPT function to openai v1.0 (#1368)
Migrate ChatGPT function to openai v1.0. The test is skipped in circleCI because we must supply the `OPENAI_API_KEY`. The test passes on local machine. - [x] Upgrade ChatGPT function. - [x] Upgrade Dall-e function. - [x] Update unit test cases. - [x] Verify that notebooks work correctly.
1 parent 0c25a44 commit 5aaa447

File tree

5 files changed

+59
-30
lines changed

5 files changed

+59
-30
lines changed

evadb/functions/chatgpt.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,21 @@ def setup(
115115
)
116116
def forward(self, text_df):
117117
try_to_import_openai()
118-
import openai
118+
from openai import OpenAI
119119

120-
@retry(tries=6, delay=20)
121-
def completion_with_backoff(**kwargs):
122-
return openai.ChatCompletion.create(**kwargs)
123-
124-
openai.api_key = self.openai_api_key
125-
if len(openai.api_key) == 0:
126-
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
120+
api_key = self.openai_api_key
121+
if len(self.openai_api_key) == 0:
122+
api_key = os.environ.get("OPENAI_API_KEY", "")
127123
assert (
128-
len(openai.api_key) != 0
124+
len(api_key) != 0
129125
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"
130126

127+
client = OpenAI(api_key=api_key)
128+
129+
@retry(tries=6, delay=20)
130+
def completion_with_backoff(**kwargs):
131+
return client.chat.completions.create(**kwargs)
132+
131133
queries = text_df[text_df.columns[0]]
132134
content = text_df[text_df.columns[0]]
133135
if len(text_df.columns) > 1:

evadb/functions/dalle.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,25 @@ def setup(self, openai_api_key="") -> None:
5656
)
5757
def forward(self, text_df):
5858
try_to_import_openai()
59-
import openai
59+
from openai import OpenAI
6060

61-
openai.api_key = self.openai_api_key
62-
# If not found, try OS Environment Variable
63-
if len(openai.api_key) == 0:
64-
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
61+
api_key = self.openai_api_key
62+
if len(self.openai_api_key) == 0:
63+
api_key = os.environ.get("OPENAI_API_KEY", "")
6564
assert (
66-
len(openai.api_key) != 0
67-
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"
65+
len(api_key) != 0
66+
), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)"
67+
68+
client = OpenAI(api_key=api_key)
6869

6970
def generate_image(text_df: PandasDataframe):
7071
results = []
7172
queries = text_df[text_df.columns[0]]
7273
for query in queries:
73-
response = openai.Image.create(prompt=query, n=1, size="1024x1024")
74+
response = client.images.generate(prompt=query, n=1, size="1024x1024")
7475

7576
# Download the image from the link
76-
image_response = requests.get(response["data"][0]["url"])
77+
image_response = requests.get(response.data[0].url)
7778
image = Image.open(BytesIO(image_response.content))
7879

7980
# Convert the image to an array format suitable for the DataFrame

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def read(path, encoding="utf-8"):
8080
"sentence-transformers",
8181
"protobuf",
8282
"bs4",
83-
"openai==0.28", # CHATGPT
83+
"openai>=1.0", # CHATGPT
8484
"gpt4all", # PRIVATE GPT
8585
"sentencepiece", # TRANSFORMERS
8686
]

test/integration_tests/long/functions/test_chatgpt.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import os
1617
import unittest
1718
from test.markers import chatgpt_skip_marker
1819
from test.util import get_evadb_for_testing
@@ -22,9 +23,8 @@
2223
from evadb.server.command_handler import execute_query_fetch_all
2324

2425

25-
def create_dummy_csv_file(config) -> str:
26-
tmp_dir_from_config = config.get_value("storage", "tmp_dir")
27-
26+
def create_dummy_csv_file(catalog) -> str:
27+
tmp_dir_from_config = catalog.get_configuration_catalog_value("tmp_dir")
2828
df_dict = [
2929
{
3030
"prompt": "summarize",
@@ -49,17 +49,18 @@ def setUp(self) -> None:
4949
);"""
5050
execute_query_fetch_all(self.evadb, create_table_query)
5151

52-
self.csv_file_path = create_dummy_csv_file(self.evadb.config)
52+
self.csv_file_path = create_dummy_csv_file(self.evadb.catalog())
5353

5454
csv_query = f"""LOAD CSV '{self.csv_file_path}' INTO MyTextCSV;"""
5555
execute_query_fetch_all(self.evadb, csv_query)
56+
os.environ["OPENAI_API_KEY"] = "sk-..."
5657

5758
def tearDown(self) -> None:
5859
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyTextCSV;")
5960

6061
@chatgpt_skip_marker
6162
def test_openai_chat_completion_function(self):
62-
function_name = "OpenAIChatCompletion"
63+
function_name = "ChatGPT"
6364
execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")
6465

6566
create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name}
@@ -69,4 +70,4 @@ def test_openai_chat_completion_function(self):
6970

7071
gpt_query = f"SELECT {function_name}('summarize', content) FROM MyTextCSV;"
7172
output_batch = execute_query_fetch_all(self.evadb, gpt_query)
72-
self.assertEqual(output_batch.columns, ["openaichatcompletion.response"])
73+
self.assertEqual(output_batch.columns, ["chatgpt.response"])

test/unit_tests/test_dalle.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,26 @@
1616
import unittest
1717
from io import BytesIO
1818
from test.util import get_evadb_for_testing
19+
from typing import List, Optional
1920
from unittest.mock import MagicMock, patch
2021

21-
from PIL import Image
22+
from PIL import Image as PILImage
23+
from pydantic import AnyUrl, BaseModel
2224

2325
from evadb.server.command_handler import execute_query_fetch_all
2426

2527

28+
class Image(BaseModel):
29+
b64_json: Optional[str] # Replace with the actual type if different
30+
revised_prompt: Optional[str] # Replace with the actual type if different
31+
url: AnyUrl
32+
33+
34+
class ImagesResponse(BaseModel):
35+
created: Optional[int] # Replace with the actual type if different
36+
data: List[Image]
37+
38+
2639
class DallEFunctionTest(unittest.TestCase):
2740
def setUp(self) -> None:
2841
self.evadb = get_evadb_for_testing()
@@ -43,10 +56,10 @@ def tearDown(self) -> None:
4356

4457
@patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"})
4558
@patch("requests.get")
46-
@patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
47-
def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
59+
@patch("openai.OpenAI")
60+
def test_dalle_image_generation(self, mock_openai, mock_requests_get):
4861
# Generate a 1x1 white pixel PNG image in memory
49-
img = Image.new("RGB", (1, 1), color="white")
62+
img = PILImage.new("RGB", (1, 1), color="white")
5063
img_byte_array = BytesIO()
5164
img.save(img_byte_array, format="PNG")
5265
mock_image_content = img_byte_array.getvalue()
@@ -55,6 +68,18 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
5568
mock_response.content = mock_image_content
5669
mock_requests_get.return_value = mock_response
5770

71+
# Set up the mock for OpenAI instance
72+
mock_openai_instance = mock_openai.return_value
73+
mock_openai_instance.images.generate.return_value = ImagesResponse(
74+
data=[
75+
Image(
76+
b64_json=None,
77+
revised_prompt=None,
78+
url="https://images.openai.com/1234.png",
79+
)
80+
]
81+
)
82+
5883
function_name = "DallE"
5984

6085
execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")
@@ -67,6 +92,6 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get):
6792
gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
6893
execute_query_fetch_all(self.evadb, gpt_query)
6994

70-
mock_openai_create.assert_called_once_with(
95+
mock_openai_instance.images.generate.assert_called_once_with(
7196
prompt="a surreal painting of a cat", n=1, size="1024x1024"
7297
)

0 commit comments

Comments
 (0)