diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java index 55a01d35..91f27bea 100644 --- a/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java +++ b/src/demo/java/io/github/sashirestela/openai/demo/AbstractDemo.java @@ -1,20 +1,19 @@ package io.github.sashirestela.openai.demo; -import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.openai.BaseSimpleOpenAI; import io.github.sashirestela.openai.SimpleOpenAI; import java.util.ArrayList; import java.util.List; -import java.util.function.UnaryOperator; import lombok.NonNull; public abstract class AbstractDemo { private String apiKey; private String organizationId; - protected SimpleOpenAI openAI; + protected BaseSimpleOpenAI openAI; private static List titleActions = new ArrayList<>(); - private int times = 80; + private final int times = 80; protected AbstractDemo() { apiKey = System.getenv("OPENAI_API_KEY"); @@ -25,14 +24,8 @@ protected AbstractDemo() { .build(); } - protected AbstractDemo(@NonNull String baseUrl, - @NonNull String apiKey, - @NonNull UnaryOperator requestInterceptor) { - openAI = SimpleOpenAI.builder() - .apiKey(apiKey) - .baseUrl(baseUrl) - .requestInterceptor(requestInterceptor) - .build(); + protected AbstractDemo(@NonNull BaseSimpleOpenAI openAI) { + this.openAI = openAI; } public void addTitleAction(String title, Action action) { diff --git a/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java deleted file mode 100644 index 23c384c7..00000000 --- a/src/demo/java/io/github/sashirestela/openai/demo/AzureOpenAIChatServiceDemo.java +++ /dev/null @@ -1,81 +0,0 @@ -package io.github.sashirestela.openai.demo; - - -import io.github.sashirestela.cleverclient.support.ContentType; -import io.github.sashirestela.openai.domain.chat.ChatRequest; -import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem; -import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser; -import java.util.Map; -import java.util.Optional; - -public class AzureOpenAIChatServiceDemo extends AbstractDemo { - private static final String AZURE_OPENAI_API_KEY_HEADER = "api-key"; - private final ChatRequest chatRequest; - - @SuppressWarnings("unchecked") - public AzureOpenAIChatServiceDemo(String baseUrl, String apiKey, String model) { - super(baseUrl, apiKey, request -> { - var url = request.getUrl(); - var contentType = request.getContentType(); - var body = request.getBody(); - - // add a header to the request - var headers = request.getHeaders(); - headers.put(AZURE_OPENAI_API_KEY_HEADER, apiKey); - request.setHeaders(headers); - - // add a query parameter to url - url += (url.contains("?") ? "&" : "?") + "api-version=2023-05-15"; - // remove '/vN' or '/vN.M' from url - url = url.replaceFirst("(\\/v\\d+\\.*\\d*)", ""); - request.setUrl(url); - - if (contentType != null) { - if (contentType.equals(ContentType.APPLICATION_JSON)) { - var bodyJson = (String) request.getBody(); - // remove a field from body (as Json) - bodyJson = bodyJson.replaceFirst(",?\"model\":\"[^\"]*\",?", ""); - bodyJson = bodyJson.replaceFirst("\"\"", "\",\""); - body = bodyJson; - } - if (contentType.equals(ContentType.MULTIPART_FORMDATA)) { - Map bodyMap = (Map) request.getBody(); - // remove a field from body (as Map) - bodyMap.remove("model"); - body = bodyMap; - } - request.setBody(body); - } - - return request; - }); - - chatRequest = ChatRequest.builder() - .model(model) - .message(new ChatMsgSystem("You are an expert in AI.")) - .message( - new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) - .temperature(0.0) - .maxTokens(300) - .build(); - } - - public void demoCallChatBlocking() { - var futureChat = openAI.chatCompletions().create(chatRequest); - var chatResponse = futureChat.join(); - System.out.println(chatResponse.firstContent()); - } - - public static void main(String[] args) { - var baseUrl = System.getenv("CUSTOM_OPENAI_BASE_URL"); - var apiKey = System.getenv("CUSTOM_OPENAI_API_KEY"); - // Services like Azure OpenAI don't require a model (endpoints have built-in model) - var model = Optional.ofNullable(System.getenv("CUSTOM_OPENAI_MODEL")) - .orElse("N/A"); - var demo = new AzureOpenAIChatServiceDemo(baseUrl, apiKey, model); - - demo.addTitleAction("Call Completion (Blocking Approach)", demo::demoCallChatBlocking); - - demo.run(); - } -} diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java new file mode 100644 index 00000000..3f57bece --- /dev/null +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAnyscaleServiceDemo.java @@ -0,0 +1,108 @@ +package io.github.sashirestela.openai.demo; + + +import io.github.sashirestela.openai.SimpleOpenAIAnyscale; +import io.github.sashirestela.openai.demo.ChatServiceDemo.Product; +import io.github.sashirestela.openai.demo.ChatServiceDemo.RunAlarm; +import io.github.sashirestela.openai.demo.ChatServiceDemo.Weather; +import io.github.sashirestela.openai.domain.chat.ChatRequest; +import io.github.sashirestela.openai.domain.chat.ChatResponse; +import io.github.sashirestela.openai.domain.chat.message.ChatMsg; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgTool; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser; +import io.github.sashirestela.openai.domain.chat.tool.ChatFunction; +import io.github.sashirestela.openai.function.FunctionExecutor; +import java.util.ArrayList; + +public class ChatAnyscaleServiceDemo extends AbstractDemo { + + public static final String MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"; + + + private ChatRequest chatRequest; + + + public ChatAnyscaleServiceDemo(String apiKey, String model) { + super(SimpleOpenAIAnyscale.builder().apiKey(apiKey).build()); + chatRequest = ChatRequest.builder() + .model(model) + .message(new ChatMsgSystem("You are an expert in AI.")) + .message( + new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) + .temperature(0.0) + .maxTokens(300) + .build(); + } + + public void demoCallChatStreaming() { + var futureChat = openAI.chatCompletions().createStream(chatRequest); + var chatResponse = futureChat.join(); + chatResponse.filter(chatResp -> chatResp.firstContent() != null) + .map(ChatResponse::firstContent) + .forEach(System.out::print); + System.out.println(); + } + + public void demoCallChatBlocking() { + var futureChat = openAI.chatCompletions().create(chatRequest); + var chatResponse = futureChat.join(); + System.out.println(chatResponse.firstContent()); + } + + public void demoCallChatWithFunctions() { + var functionExecutor = new FunctionExecutor(); + functionExecutor.enrollFunction( + ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .functionalClass(Weather.class) + .build()); + functionExecutor.enrollFunction( + ChatFunction.builder() + .name("product") + .description("Get the product of two numbers") + .functionalClass(Product.class) + .build()); + functionExecutor.enrollFunction( + ChatFunction.builder() + .name("run_alarm") + .description("Run an alarm") + .functionalClass(RunAlarm.class) + .build()); + var messages = new ArrayList(); + messages.add(new ChatMsgUser("What is the product of 123 and 456?")); + var chatRequest = ChatRequest.builder() + .model(MODEL) + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); + var futureChat = openAI.chatCompletions().create(chatRequest); + var chatResponse = futureChat.join(); + var chatMessage = chatResponse.firstMessage(); + var chatToolCall = chatMessage.getToolCalls().get(0); + var result = functionExecutor.execute(chatToolCall.getFunction()); + messages.add(chatMessage); + messages.add(new ChatMsgTool(result.toString(), chatToolCall.getId())); + chatRequest = ChatRequest.builder() + .model(MODEL) + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); + futureChat = openAI.chatCompletions().create(chatRequest); + chatResponse = futureChat.join(); + System.out.println(chatResponse.firstContent()); + } + + public static void main(String[] args) { + var apiKey = System.getenv("ANYSCALE_API_KEY"); + // Services like Azure OpenAI don't require a model (endpoints have built-in model) + var demo = new ChatAnyscaleServiceDemo(apiKey, MODEL); + + demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); + demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); + demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions); + + demo.run(); + } +} diff --git a/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java new file mode 100644 index 00000000..ee6cc746 --- /dev/null +++ b/src/demo/java/io/github/sashirestela/openai/demo/ChatAzureServiceDemo.java @@ -0,0 +1,174 @@ +package io.github.sashirestela.openai.demo; + + +import io.github.sashirestela.openai.SimpleOpenAIAzure; +import io.github.sashirestela.openai.demo.ChatServiceDemo.Product; +import io.github.sashirestela.openai.demo.ChatServiceDemo.RunAlarm; +import io.github.sashirestela.openai.demo.ChatServiceDemo.Weather; +import io.github.sashirestela.openai.domain.chat.ChatRequest; +import io.github.sashirestela.openai.domain.chat.ChatResponse; +import io.github.sashirestela.openai.domain.chat.content.ContentPartImage; +import io.github.sashirestela.openai.domain.chat.content.ContentPartText; +import io.github.sashirestela.openai.domain.chat.content.ImageUrl; +import io.github.sashirestela.openai.domain.chat.message.ChatMsg; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgSystem; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgTool; +import io.github.sashirestela.openai.domain.chat.message.ChatMsgUser; +import io.github.sashirestela.openai.domain.chat.tool.ChatFunction; +import io.github.sashirestela.openai.function.FunctionExecutor; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; + +public class ChatAzureServiceDemo extends AbstractDemo { + private ChatRequest chatRequest; + + public ChatAzureServiceDemo(String baseUrl, String apiKey, String apiVersion) { + super(SimpleOpenAIAzure.builder() + .apiKey(apiKey) + .baseUrl(baseUrl) + .apiVersion(apiVersion) + .build()); + chatRequest = ChatRequest.builder() + .model("N/A") + .message(new ChatMsgSystem("You are an expert in AI.")) + .message( + new ChatMsgUser("Write a technical article about ChatGPT, no more than 100 words.")) + .temperature(0.0) + .maxTokens(300) + .build(); + } + + public void demoCallChatStreaming() { + var futureChat = openAI.chatCompletions().createStream(chatRequest); + var chatResponse = futureChat.join(); + chatResponse.filter(chatResp -> chatResp.firstContent() != null) + .map(ChatResponse::firstContent) + .forEach(System.out::print); + System.out.println(); + } + + public void demoCallChatBlocking() { + var futureChat = openAI.chatCompletions().create(chatRequest); + var chatResponse = futureChat.join(); + System.out.println(chatResponse.firstContent()); + } + + public void demoCallChatWithFunctions() { + var functionExecutor = new FunctionExecutor(); + functionExecutor.enrollFunction( + ChatFunction.builder() + .name("get_weather") + .description("Get the current weather of a location") + .functionalClass(Weather.class) + .build()); + functionExecutor.enrollFunction( + ChatFunction.builder() + .name("product") + .description("Get the product of two numbers") + .functionalClass(Product.class) + .build()); + functionExecutor.enrollFunction( + ChatFunction.builder() + .name("run_alarm") + .description("Run an alarm") + .functionalClass(RunAlarm.class) + .build()); + var messages = new ArrayList(); + messages.add(new ChatMsgUser("What is the product of 123 and 456?")); + chatRequest = ChatRequest.builder() + .model("N/A") + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); + var futureChat = openAI.chatCompletions().create(chatRequest); + var chatResponse = futureChat.join(); + var chatMessage = chatResponse.firstMessage(); + var chatToolCall = chatMessage.getToolCalls().get(0); + var result = functionExecutor.execute(chatToolCall.getFunction()); + messages.add(chatMessage); + messages.add(new ChatMsgTool(result.toString(), chatToolCall.getId())); + chatRequest = ChatRequest.builder() + .model("N/A") + .messages(messages) + .tools(functionExecutor.getToolFunctions()) + .build(); + futureChat = openAI.chatCompletions().create(chatRequest); + chatResponse = futureChat.join(); + System.out.println(chatResponse.firstContent()); + } + + public void demoCallChatWithVisionExternalImage() { + var chatRequest = ChatRequest.builder() + .model("N/A") + .messages(List.of( + new ChatMsgUser(List.of( + new ContentPartText( + "What do you see in the image? Give in details in no more than 100 words."), + new ContentPartImage(new ImageUrl( + "https://upload.wikimedia.org/wikipedia/commons/e/eb/Machu_Picchu%2C_Peru.jpg")))))) + .temperature(0.0) + .maxTokens(500) + .build(); + var chatResponse = openAI.chatCompletions().createStream(chatRequest).join(); + chatResponse.filter(chatResp -> chatResp.firstContent() != null) + .map(chatResp -> chatResp.firstContent()) + .forEach(System.out::print); + System.out.println(); + } + + public void demoCallChatWithVisionLocalImage() { + var chatRequest = ChatRequest.builder() + .model("N/A") + .messages(List.of( + new ChatMsgUser(List.of( + new ContentPartText( + "What do you see in the image? Give in details in no more than 100 words."), + new ContentPartImage(loadImageAsBase64("src/demo/resources/machupicchu.jpg")))))) + .temperature(0.0) + .maxTokens(500) + .build(); + var chatResponse = openAI.chatCompletions().createStream(chatRequest).join(); + chatResponse.filter(chatResp -> chatResp.firstContent() != null) + .map(chatResp -> chatResp.firstContent()) + .forEach(System.out::print); + System.out.println(); + } + + private static ImageUrl loadImageAsBase64(String imagePath) { + try { + Path path = Paths.get(imagePath); + byte[] imageBytes = Files.readAllBytes(path); + String base64String = Base64.getEncoder().encodeToString(imageBytes); + var extension = imagePath.substring(imagePath.lastIndexOf(".") + 1); + var prefix = "data:image/" + extension + ";base64,"; + return new ImageUrl(prefix + base64String); + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + public static void main(String[] args) { + var baseUrl = System.getenv("AZURE_OPENAI_BASE_URL"); + var apiKey = System.getenv("AZURE_OPENAI_API_KEY"); + var apiVersion = System.getenv("AZURE_OPENAI_API_VERSION"); + // Services like Azure OpenAI don't require a model (endpoints have built-in model) + var demo = new ChatAzureServiceDemo(baseUrl, apiKey, apiVersion); + + + demo.addTitleAction("Call Chat (Blocking Approach)", demo::demoCallChatBlocking); + if (baseUrl.contains("gpt-35-turbo")) { + demo.addTitleAction("Call Chat with Functions", demo::demoCallChatWithFunctions); + } else if (baseUrl.contains("gpt-4")){ + demo.addTitleAction("Call Chat (Streaming Approach)", demo::demoCallChatStreaming); + demo.addTitleAction("Call Chat with Vision (External image)", demo::demoCallChatWithVisionExternalImage); + demo.addTitleAction("Call Chat with Vision (Local image)", demo::demoCallChatWithVisionLocalImage); + } + + demo.run(); + } +} diff --git a/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java new file mode 100644 index 00000000..965004f0 --- /dev/null +++ b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAI.java @@ -0,0 +1,121 @@ +package io.github.sashirestela.openai; + +import io.github.sashirestela.cleverclient.CleverClient; +import java.net.http.HttpClient; +import java.util.Optional; +import lombok.NonNull; +import lombok.Setter; + + +/** + * The base abstract class that all providers extend. It generates + * an implementation to the chatCompletions() interface of {@link OpenAI OpenAI} interfaces. + * It throws a "Not implemented" exception for all other interfaces + */ + + +public abstract class BaseSimpleOpenAI { + + private static final String END_OF_STREAM = "[DONE]"; + + @Setter + protected CleverClient cleverClient; + + protected OpenAI.ChatCompletions chatCompletionService; + + BaseSimpleOpenAI(@NonNull BaseSimpleOpenAIArgs args) { + var httpClient = + Optional.ofNullable(args.getHttpClient()).orElse(HttpClient.newHttpClient()); + this.cleverClient = CleverClient.builder() + .httpClient(httpClient) + .baseUrl(args.getBaseUrl()) + .headers(args.getHeaders()) + .endOfStream(END_OF_STREAM) + .requestInterceptor(args.getRequestInterceptor()) + .build(); + } + + /** + * Throw not implemented + */ + public OpenAI.Audios audios() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Generates an implementation of the ChatCompletions interface to handle + * requests. + * + * @return An instance of the interface. It is created only once. + */ + public OpenAI.ChatCompletions chatCompletions() { + if (this.chatCompletionService == null) { + this.chatCompletionService = this.cleverClient.create(OpenAI.ChatCompletions.class); + } + return this.chatCompletionService; + + } + + /** + * Throw not implemented + */ + public OpenAI.Completions completions() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Embeddings embeddings() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Files files() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.FineTunings fineTunings() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Images images() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Models models() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Moderations moderations() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Assistants assistants() { + throw new SimpleUncheckedException("Not implemented"); + } + + /** + * Throw not implemented + */ + public OpenAI.Threads threads() { + throw new SimpleUncheckedException("Not implemented"); + } +} diff --git a/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java new file mode 100644 index 00000000..6e60bd89 --- /dev/null +++ b/src/main/java/io/github/sashirestela/openai/BaseSimpleOpenAIArgs.java @@ -0,0 +1,19 @@ +package io.github.sashirestela.openai; + +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import java.net.http.HttpClient; +import java.util.Map; +import java.util.function.UnaryOperator; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; + +@Getter +@Builder +public class BaseSimpleOpenAIArgs { + @NonNull + private final String baseUrl; + private final Map headers; + private final HttpClient httpClient; + private final UnaryOperator requestInterceptor; +} diff --git a/src/main/java/io/github/sashirestela/openai/OpenAI.java b/src/main/java/io/github/sashirestela/openai/OpenAI.java index 7c27b3d6..1636b5d5 100644 --- a/src/main/java/io/github/sashirestela/openai/OpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/OpenAI.java @@ -1,5 +1,9 @@ package io.github.sashirestela.openai; +import static io.github.sashirestela.cleverclient.util.CommonUtil.isNullOrEmpty; + +import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoice; +import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoiceType; import java.io.InputStream; import java.util.EnumSet; import java.util.List; @@ -69,6 +73,25 @@ */ public interface OpenAI { + static ChatRequest updateRequest(ChatRequest chatRequest, Boolean useStream) { + var toolChoice = chatRequest.getToolChoice(); + + if (!isNullOrEmpty(chatRequest.getTools())) { + if (toolChoice == null) { + toolChoice = ChatToolChoiceType.AUTO; + } else if (!(toolChoice instanceof ChatToolChoice) && + !(toolChoice instanceof ChatToolChoiceType)) { + throw new SimpleUncheckedException( + "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes.", + null, null); + } + } + return chatRequest + .withStream(useStream) + .withToolChoice(toolChoice); + } + + /** * Turn audio into text (speech to text). * @@ -183,7 +206,6 @@ private AudioRespFmt getResponseFormat(AudioRespFmt currValue, AudioRespFmt orDe */ @Resource("/v1/chat/completions") interface ChatCompletions { - /** * Creates a model response for the given chat conversation. Blocking mode. * @@ -192,7 +214,7 @@ interface ChatCompletions { * @return Response is delivered as a full text when is ready. */ default CompletableFuture create(@Body ChatRequest chatRequest) { - var request = chatRequest.withStream(Boolean.FALSE); + var request = updateRequest(chatRequest, Boolean.FALSE); return __create(request); } @@ -207,7 +229,7 @@ default CompletableFuture create(@Body ChatRequest chatRequest) { * @return Response is delivered as a continues flow of tokens. */ default CompletableFuture> createStream(@Body ChatRequest chatRequest) { - var request = chatRequest.withStream(Boolean.TRUE); + var request = updateRequest(chatRequest, Boolean.TRUE); return __createStream(request); } diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java index db42971b..c5468c84 100644 --- a/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAI.java @@ -1,71 +1,58 @@ package io.github.sashirestela.openai; -import io.github.sashirestela.cleverclient.CleverClient; -import io.github.sashirestela.cleverclient.http.HttpRequestData; import java.net.http.HttpClient; import java.util.HashMap; import java.util.Optional; -import java.util.function.UnaryOperator; -import lombok.AccessLevel; import lombok.Builder; -import lombok.Getter; import lombok.NonNull; -import lombok.Setter; /** - * The factory that generates implementations of the {@link OpenAI OpenAI} - * interfaces. + * This class provides the implements additional {@link OpenAI OpenAI} interfaces + * targeting the OpenAI service. */ -@Getter -public class SimpleOpenAI { +public class SimpleOpenAI extends BaseSimpleOpenAI { public static final String OPENAI_BASE_URL = "https://api.openai.com"; - private static final String AUTHORIZATION_HEADER = "Authorization"; - private static final String ORGANIZATION_HEADER = "OpenAI-Organization"; - private static final String BEARER_AUTHORIZATION = "Bearer "; - private static final String END_OF_STREAM = "[DONE]"; - - @NonNull - private final String apiKey; - private final String organizationId; - private final String baseUrl; - private final HttpClient httpClient; - @Setter - private CleverClient cleverClient; - - @Getter(AccessLevel.NONE) - private OpenAI.Audios audioService; - - @Getter(AccessLevel.NONE) - private OpenAI.ChatCompletions chatCompletionService; + public static final String AUTHORIZATION_HEADER = "Authorization"; + public static final String ORGANIZATION_HEADER = "OpenAI-Organization"; + public static final String BEARER_AUTHORIZATION = "Bearer "; - @Getter(AccessLevel.NONE) + private OpenAI.Audios audioService; private OpenAI.Completions completionService; - @Getter(AccessLevel.NONE) private OpenAI.Embeddings embeddingService; - @Getter(AccessLevel.NONE) private OpenAI.Files fileService; - @Getter(AccessLevel.NONE) private OpenAI.FineTunings fineTuningService; - @Getter(AccessLevel.NONE) private OpenAI.Images imageService; - @Getter(AccessLevel.NONE) private OpenAI.Models modelService; - @Getter(AccessLevel.NONE) private OpenAI.Moderations moderationService; - @Getter(AccessLevel.NONE) private OpenAI.Assistants assistantService; - @Getter(AccessLevel.NONE) private OpenAI.Threads threadService; + + public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs( + String apiKey, String organizationId, String baseUrl, HttpClient httpClient) { + + var headers = new HashMap(); + headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey); + if (organizationId != null) { + headers.put(ORGANIZATION_HEADER, organizationId); + } + + return BaseSimpleOpenAIArgs.builder() + .baseUrl(Optional.ofNullable(baseUrl).orElse(OPENAI_BASE_URL)) + .headers(headers) + .httpClient(httpClient) + .build(); + } + /** * Constructor used to generate a builder. * @@ -77,25 +64,8 @@ public class SimpleOpenAI { * One is created by default if not provided. Optional. */ @Builder - public SimpleOpenAI(@NonNull String apiKey, String organizationId, String baseUrl, HttpClient httpClient, - UnaryOperator requestInterceptor) { - this.apiKey = apiKey; - this.organizationId = organizationId; - this.baseUrl = Optional.ofNullable(baseUrl).orElse(OPENAI_BASE_URL); - this.httpClient = Optional.ofNullable(httpClient).orElse(HttpClient.newHttpClient()); - - var headers = new HashMap(); - headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey); - if (organizationId != null) { - headers.put(ORGANIZATION_HEADER, organizationId); - } - this.cleverClient = CleverClient.builder() - .httpClient(this.httpClient) - .baseUrl(this.baseUrl) - .headers(headers) - .endOfStream(END_OF_STREAM) - .requestInterceptor(requestInterceptor) - .build(); + public SimpleOpenAI(@NonNull String apiKey, String organizationId, String baseUrl, HttpClient httpClient) { + super(prepareBaseSimpleOpenAIArgs(apiKey, organizationId, baseUrl, httpClient)); } /** @@ -110,19 +80,6 @@ public OpenAI.Audios audios() { return audioService; } - /** - * Generates an implementation of the ChatCompletions interface to handle - * requests. - * - * @return An instance of the interface. It is created only once. - */ - public OpenAI.ChatCompletions chatCompletions() { - if (chatCompletionService == null) { - chatCompletionService = cleverClient.create(OpenAI.ChatCompletions.class); - } - return chatCompletionService; - } - /** * Generates an implementation of the Completions interface to handle requests. * @@ -140,6 +97,7 @@ public OpenAI.Completions completions() { * * @return An instance of the interface. It is created only once. */ + public OpenAI.Embeddings embeddings() { if (embeddingService == null) { embeddingService = cleverClient.create(OpenAI.Embeddings.class); diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java new file mode 100644 index 00000000..70c8b580 --- /dev/null +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAnyscale.java @@ -0,0 +1,42 @@ +package io.github.sashirestela.openai; + +import java.net.http.HttpClient; +import java.util.HashMap; +import java.util.Optional; +import lombok.Builder; +import lombok.NonNull; + +/** + * This class provides the chatCompletion() service for the Anyscale provider + */ +public class SimpleOpenAIAnyscale extends BaseSimpleOpenAI { + public static final String DEFAULT_BASE_URL = "https://api.endpoints.anyscale.com"; + + public static final String AUTHORIZATION_HEADER = "Authorization"; + public static final String BEARER_AUTHORIZATION = "Bearer "; + + /** + * Constructor used to generate a builder. + * + * @param apiKey Identifier to be used for authentication. Mandatory. + * @param baseUrl Host's url + * @param httpClient A {@link java.net.http.HttpClient HttpClient} object. + * One is created by default if not provided. Optional. + */ + public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, HttpClient httpClient) { + baseUrl = Optional.ofNullable(baseUrl).orElse(DEFAULT_BASE_URL); + var headers = new HashMap(); + headers.put(AUTHORIZATION_HEADER, BEARER_AUTHORIZATION + apiKey); + + return BaseSimpleOpenAIArgs.builder() + .baseUrl(baseUrl) + .headers(headers) + .httpClient(httpClient) + .build(); + } + + @Builder + public SimpleOpenAIAnyscale(@NonNull String apiKey, String baseUrl, HttpClient httpClient) { + super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, httpClient)); + } +} diff --git a/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java new file mode 100644 index 00000000..e9c9c223 --- /dev/null +++ b/src/main/java/io/github/sashirestela/openai/SimpleOpenAIAzure.java @@ -0,0 +1,83 @@ +package io.github.sashirestela.openai; + +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.cleverclient.support.ContentType; +import java.net.http.HttpClient; +import java.util.Map; +import java.util.function.UnaryOperator; +import lombok.Builder; +import lombok.NonNull; + +/** + * This class provides the chatCompletion() service for the Azure OpenAI provider + * Note that each instance of SimpleOpenAIAzure is linked to a single specific model. + * The capabilities of the model determine which chatCompletion() methods are available. + */ +public class SimpleOpenAIAzure extends BaseSimpleOpenAI { + + public static final String API_KEY_HEADER = "api-key"; + public static final String API_VERSION = "api-version"; + + private static final String ENDPOINT_VERSION_REGEX = "(\\/v\\d+\\.*\\d*)"; + private static final String MODEL_REGEX = ",?\"model\":\"[^\"]*\",?"; + + private static final String EMPTY_REGEX = "\"\""; + private static final String QUOTED_COMMA = "\",\""; + + private static final String MODEL_LITERAL = "model"; + + public static BaseSimpleOpenAIArgs prepareBaseSimpleOpenAIArgs(String apiKey, String baseUrl, String apiVersion, HttpClient httpClient) { + + var headers = Map.of(API_KEY_HEADER, apiKey); + + var requestInterceptor = (UnaryOperator) request -> { + var url = request.getUrl(); + var contentType = request.getContentType(); + var body = request.getBody(); + + url += (url.contains("?") ? "&" : "?") + API_VERSION + "=" + apiVersion; + url = url.replaceFirst(ENDPOINT_VERSION_REGEX, ""); + request.setUrl(url); + + if (contentType != null) { + if (contentType.equals(ContentType.APPLICATION_JSON)) { + var bodyJson = (String) request.getBody(); + bodyJson = bodyJson.replaceFirst(MODEL_REGEX, ""); + bodyJson = bodyJson.replaceFirst(EMPTY_REGEX, QUOTED_COMMA); + body = bodyJson; + } + if (contentType.equals(ContentType.MULTIPART_FORMDATA)) { + @SuppressWarnings("unchecked") + var bodyMap = (Map) request.getBody(); + bodyMap.remove(MODEL_LITERAL); + body = bodyMap; + } + request.setBody(body); + } + + return request; + }; + + return BaseSimpleOpenAIArgs.builder() + .baseUrl(baseUrl) + .headers(headers) + .httpClient(httpClient) + .requestInterceptor(requestInterceptor) + .build(); + } + + /** + * Constructor used to generate a builder. + * + * @param apiKey Identifier to be used for authentication. Mandatory. + * @param baseUrl The URL of the Azure OpenAI deployment. Mandatory. + * @param apiVersion Azure OpenAI API version. See: + * Azure OpenAI API versioning + * @param httpClient A {@link HttpClient HttpClient} object. + * One is created by default if not provided. Optional. + */ + @Builder + public SimpleOpenAIAzure(@NonNull String apiKey, @NonNull String baseUrl, @NonNull String apiVersion, HttpClient httpClient) { + super(prepareBaseSimpleOpenAIArgs(apiKey, baseUrl, apiVersion, httpClient)); + } +} diff --git a/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java b/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java index f08421ac..feef27bc 100644 --- a/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java +++ b/src/main/java/io/github/sashirestela/openai/domain/chat/ChatRequest.java @@ -11,8 +11,6 @@ import io.github.sashirestela.openai.SimpleUncheckedException; import io.github.sashirestela.openai.domain.chat.message.ChatMsg; import io.github.sashirestela.openai.domain.chat.tool.ChatTool; -import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoice; -import io.github.sashirestela.openai.domain.chat.tool.ChatToolChoiceType; import lombok.Builder; import lombok.Getter; import lombok.NonNull; @@ -29,7 +27,7 @@ public class ChatRequest { private ChatRespFmt responseFormat; private Integer seed; private List tools; - private Object toolChoice; + @With private Object toolChoice; private Double temperature; private Double topP; private Integer n; @@ -48,12 +46,6 @@ public ChatRequest(@NonNull String model, @NonNull @Singular List messa Integer seed, @Singular List tools, Object toolChoice, Double temperature, Double topP, Integer n, Boolean stream, Object stop, Integer maxTokens, Double presencePenalty, Double frequencyPenalty, Map logitBias, String user, Boolean logprobs, Integer topLogprobs) { - if (toolChoice != null && - !(toolChoice instanceof ChatToolChoiceType) && !(toolChoice instanceof ChatToolChoice)) { - throw new SimpleUncheckedException( - "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes.", - null, null); - } if (stop != null && !(stop instanceof String) && !(stop instanceof List && ((List) stop).get(0) instanceof String && ((List) stop).size() <= 4)) { throw new SimpleUncheckedException( diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java new file mode 100644 index 00000000..24cc077a --- /dev/null +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAnyscaleTest.java @@ -0,0 +1,46 @@ +package io.github.sashirestela.openai; + +import static io.github.sashirestela.openai.SimpleOpenAIAnyscale.AUTHORIZATION_HEADER; +import static io.github.sashirestela.openai.SimpleOpenAIAnyscale.BEARER_AUTHORIZATION; +import static io.github.sashirestela.openai.SimpleOpenAIAnyscale.DEFAULT_BASE_URL; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.net.http.HttpClient; +import org.junit.jupiter.api.Test; + +class SimpleOpenAIAnyscaleTest { + @Test + void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithCustomBaseURL() { + var args = SimpleOpenAIAnyscale.prepareBaseSimpleOpenAIArgs( + "the-api-key", + "https://example.org", + HttpClient.newHttpClient()); + + assertEquals("https://example.org", args.getBaseUrl()); + assertEquals(1, args.getHeaders().size()); + assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); + assertNotNull(args.getHttpClient()); + + // No request interceptor for SimpleOpenAIAnyscale + assertNull(args.getRequestInterceptor()); + } + + @Test + void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithDefaultBaseURL() { + var args = SimpleOpenAIAnyscale.prepareBaseSimpleOpenAIArgs( + "the-api-key", + null, + HttpClient.newHttpClient()); + + assertEquals(SimpleOpenAIAnyscale.DEFAULT_BASE_URL, args.getBaseUrl()); + assertEquals(1, args.getHeaders().size()); + assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); + assertNotNull(args.getHttpClient()); + + // No request interceptor for SimpleOpenAIAnyscale + assertNull(args.getRequestInterceptor()); + } + +} diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java new file mode 100644 index 00000000..a25cb1ba --- /dev/null +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAIAzureTest.java @@ -0,0 +1,57 @@ +package io.github.sashirestela.openai; + +import static io.github.sashirestela.openai.SimpleOpenAIAzure.API_KEY_HEADER; +import static io.github.sashirestela.openai.SimpleOpenAIAzure.API_VERSION; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import io.github.sashirestela.cleverclient.http.HttpRequestData; +import io.github.sashirestela.cleverclient.support.ContentType; +import java.net.http.HttpClient; +import java.util.Map; +import org.junit.jupiter.api.Test; + +class SimpleOpenAIAzureTest { + + + @Test + void shouldPrepareBaseOpenSimpleAIArgsCorrectly() { + var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs( + "the-api-key", + "https://example.org", + "12-34-5678", + HttpClient.newHttpClient()); + + assertEquals("https://example.org", args.getBaseUrl()); + assertEquals(1, args.getHeaders().size()); + assertEquals("the-api-key", args.getHeaders().get(API_KEY_HEADER)); + assertNotNull(args.getHttpClient()); + assertNotNull(args.getRequestInterceptor()); + } + + @Test + void shouldInterceptUrlCorrectly() { + var request = HttpRequestData.builder() + .url("https://example.org/v1/endpoint") + .contentType(ContentType.APPLICATION_JSON) + .headers(Map.of(API_KEY_HEADER, "the-api-key")) + .body("{\"model\":\"model1\"}") + .build(); + var expectedRequest = HttpRequestData.builder() + .url("https://example.org/endpoint?" + API_VERSION + "=12-34-5678") + .contentType(ContentType.APPLICATION_JSON) + .headers(Map.of(API_KEY_HEADER, "the-api-key")) + .body("{}") + .build(); + var args = SimpleOpenAIAzure.prepareBaseSimpleOpenAIArgs( + "the-api-key", + "https://example.org", + "12-34-5678", + null); + var actualRequest = args.getRequestInterceptor().apply(request); + assertEquals(expectedRequest.getUrl() , actualRequest.getUrl()); + assertEquals(expectedRequest.getContentType(), actualRequest.getContentType()); + assertEquals(expectedRequest.getHeaders(), actualRequest.getHeaders()); + assertEquals(expectedRequest.getBody(), actualRequest.getBody()); + } +} diff --git a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java index caeae9f4..e7fd73bf 100644 --- a/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java +++ b/src/test/java/io/github/sashirestela/openai/SimpleOpenAITest.java @@ -1,11 +1,13 @@ package io.github.sashirestela.openai; +import static io.github.sashirestela.openai.SimpleOpenAI.AUTHORIZATION_HEADER; +import static io.github.sashirestela.openai.SimpleOpenAI.BEARER_AUTHORIZATION; import static io.github.sashirestela.openai.SimpleOpenAI.OPENAI_BASE_URL; +import static io.github.sashirestela.openai.SimpleOpenAI.ORGANIZATION_HEADER; import static java.util.concurrent.CompletableFuture.completedFuture; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -32,45 +34,37 @@ class SimpleOpenAITest { CleverClient cleverClient = mock(CleverClient.class); @Test - void shouldSetPropertiesToDefaultValuesWhenBuilderIsCalledWithoutThoseProperties() { - var openAI = SimpleOpenAI.builder() - .apiKey("apiKey") - .build(); - assertEquals(HttpClient.Version.HTTP_2, openAI.getHttpClient().version()); - assertEquals(OPENAI_BASE_URL, openAI.getBaseUrl()); - assertNotNull(openAI.getCleverClient()); + void shouldPrepareBaseOpenSimpleAIArgsCorrectly() { + + var args = SimpleOpenAI.prepareBaseSimpleOpenAIArgs( + "the-api-key", + "orgId", + "https://example.org", + HttpClient.newHttpClient()); + + assertEquals("https://example.org", args.getBaseUrl()); + assertEquals(2, args.getHeaders().size()); + assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); + assertEquals("orgId", args.getHeaders().get(ORGANIZATION_HEADER)); + assertNotNull(args.getHttpClient()); + + // No request interceptor for SimpleOpenAI + assertNull(args.getRequestInterceptor()); } - @Test - void shouldSetPropertiesWhenBuilderIsCalledWithThoseProperties() { - var otherUrl = "https://openai.com/api"; - var openAI = SimpleOpenAI.builder() - .apiKey("apiKey") - .baseUrl(otherUrl) - .httpClient(httpClient) - .build(); - assertEquals("apiKey", openAI.getApiKey()); - assertEquals(otherUrl, openAI.getBaseUrl()); - assertEquals(httpClient, openAI.getHttpClient()); - } + void shouldPrepareBaseOpenSimpleAIArgsCorrectlyWithOnlyApiKey() { + var args = SimpleOpenAI.prepareBaseSimpleOpenAIArgs("the-api-key", null, null, null); - @Test - void shouldNotAddOrganizationToHeadersWhenBuilderIsCalledWithoutOrganizationId() { - var openAI = SimpleOpenAI.builder() - .apiKey("apiKey") - .build(); - assertFalse(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId())); - } + assertEquals(OPENAI_BASE_URL, args.getBaseUrl()); + assertEquals(1, args.getHeaders().size()); + assertEquals(BEARER_AUTHORIZATION + "the-api-key", args.getHeaders().get(AUTHORIZATION_HEADER)); + assertNotNull(args.getHttpClient()); - @Test - void shouldAddOrganizationToHeadersWhenBuilderIsCalledWithOrganizationId() { - var openAI = SimpleOpenAI.builder() - .apiKey("apiKey") - .organizationId("orgId") - .build(); - assertTrue(openAI.getCleverClient().getHeaders().containsValue(openAI.getOrganizationId())); + // No request interceptor for SimpleOpenAI + assertNull(args.getRequestInterceptor()); } + @Test @SuppressWarnings("unchecked") void shouldNotDuplicateContentTypeHeaderWhenCallingSimpleOpenAI() { @@ -104,7 +98,7 @@ void shouldNotDuplicateContentTypeHeaderWhenCallingSimpleOpenAI() { } @Test - void shouldInstanceServiceOnlyOnceWhenItIsCalledSeverlaTimes() { + void shouldInstanceServiceOnlyOnceWhenItIsCalledSeveralTimes() { final int NUMBER_CALLINGS = 3; final int NUMBER_INVOCATIONS = 1; diff --git a/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java b/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java index d6fd0a56..a5c7ec57 100644 --- a/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java +++ b/src/test/java/io/github/sashirestela/openai/domain/chat/ChatDomainTest.java @@ -3,9 +3,13 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; +import io.github.sashirestela.openai.OpenAI; +import io.github.sashirestela.openai.OpenAI.ChatCompletions; +import io.github.sashirestela.openai.domain.chat.tool.ChatTool; import java.io.IOException; import java.net.http.HttpClient; import java.util.List; @@ -201,16 +205,32 @@ void shouldCreateChatRequestWhenToolChoiceIsRightClass() { } } + @Test + void shouldUpdateChatRequestWithAutoToolChoiceWhenToolsAreProvidedWithoutToolChoice() { + var charRequest = ChatRequest.builder() + .model("model") + .message(new ChatMsgUser("content")) + .tools(functionExecutor.getToolFunctions()) + .build(); + + assertNull(charRequest.getToolChoice()); + var updatedChatRequest = OpenAI.updateRequest(charRequest, Boolean.TRUE); + assertEquals(ChatToolChoiceType.AUTO, updatedChatRequest.getToolChoice()); + } + @Test void shouldThrownExceptionWhenCreatingChatRequestWithToolChoiceWrongClass() { - var chatRequestBuilder = ChatRequest.builder() - .model("model") - .message(new ChatMsgUser("My Content")) - .toolChoice("wrong value"); - var exception = assertThrows(SimpleUncheckedException.class, () -> chatRequestBuilder.build()); + var charRequest = ChatRequest.builder() + .model("model") + .message(new ChatMsgUser("content")) + .tools(functionExecutor.getToolFunctions()) + .toolChoice("wrong value") + .build(); + + var exception = assertThrows(SimpleUncheckedException.class, () -> OpenAI.updateRequest(charRequest, Boolean.TRUE)); var actualErrorMessage = exception.getMessage(); - var expectedErrorMessge = "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes."; - assertEquals(expectedErrorMessge, actualErrorMessage); + var expectedErrorMessage = "The field toolChoice must be ChatToolChoiceType or ChatToolChoice classes."; + assertEquals(expectedErrorMessage, actualErrorMessage); } @Test