From 1405a317d9dc576fed896423eb6b48def32cbc68 Mon Sep 17 00:00:00 2001 From: Collin Date: Tue, 28 Mar 2023 14:21:14 -0400 Subject: [PATCH 01/12] Add request/response classes for completions --- .../com/cjcrafter/openai/chat/ChatChoice.kt | 2 +- .../com/cjcrafter/openai/chat/ChatRequest.kt | 31 ++- .../com/cjcrafter/openai/chat/ChatResponse.kt | 2 +- .../openai/completions/CompletionChoice.kt | 28 ++ .../openai/completions/CompletionRequest.kt | 255 ++++++++++++++++++ .../openai/completions/CompletionResponse.kt | 66 +++++ .../openai/completions/CompletionUsage.kt | 21 ++ 7 files changed, 387 insertions(+), 18 deletions(-) create mode 100644 src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoice.kt create mode 100644 src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt create mode 100644 src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponse.kt create mode 100644 src/main/kotlin/com/cjcrafter/openai/completions/CompletionUsage.kt diff --git a/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoice.kt b/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoice.kt index ac5733e..e281985 100644 --- a/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoice.kt +++ b/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoice.kt @@ -5,7 +5,7 @@ import com.google.gson.JsonObject import com.google.gson.annotations.SerializedName /** - * The OpenAI API returns a list of [ChatChoice]. Each chat choice has a + * The OpenAI API returns a list of `ChatChoice`. Each choice has a * generated message ([ChatChoice.message]) and a finish reason * ([ChatChoice.finishReason]). For most use cases, you only need the generated * message. diff --git a/src/main/kotlin/com/cjcrafter/openai/chat/ChatRequest.kt b/src/main/kotlin/com/cjcrafter/openai/chat/ChatRequest.kt index 5ac9a8e..fc228e2 100644 --- a/src/main/kotlin/com/cjcrafter/openai/chat/ChatRequest.kt +++ b/src/main/kotlin/com/cjcrafter/openai/chat/ChatRequest.kt @@ -6,7 +6,7 @@ import com.google.gson.annotations.SerializedName * [ChatRequest] holds the configurable options that can be sent to the OpenAI * Chat API. For most use cases, you only need to set [model] and [messages]. * For more detailed descriptions for each option, refer to the - * [Chat Wiki](https://platform.openai.com/docs/api-reference/chat) + * [Chat Wiki](https://platform.openai.com/docs/api-reference/chat). * * [messages] stores **ALL** previous messages from the conversation. It is * **YOUR RESPONSIBILITY** to store and update this list for your conversations @@ -49,7 +49,7 @@ data class ChatRequest @JvmOverloads constructor( var temperature: Float? = null, @field:SerializedName("top_p") var topP: Float? = null, var n: Int? = null, - @Deprecated("Use ChatBot#streamResponse") var stream: Boolean? = null, + @Deprecated("Use OpenAI#streamChatCompletion") var stream: Boolean? = null, var stop: String? = null, @field:SerializedName("max_tokens") var maxTokens: Int? = null, @field:SerializedName("presence_penalty") var presencePenalty: Float? = null, @@ -58,20 +58,8 @@ data class ChatRequest @JvmOverloads constructor( var user: String? = null ) { - companion object { - - /** - * A static method that provides a new [Builder] instance for the - * [ChatRequest] class. - * - * @return a new [Builder] instance for creating a [ChatRequest] object. - */ - @JvmStatic - fun builder(): Builder = Builder() - } - /** - * [Builder] is a helper class to build a [ChatRequest] instance with a fluent API. 
+ * [Builder] is a helper class to build a [ChatRequest] instance with a stable API. * It provides methods for setting the properties of the [ChatRequest] object. * The [build] method returns a new [ChatRequest] instance with the specified properties. * @@ -80,7 +68,6 @@ data class ChatRequest @JvmOverloads constructor( * val chatRequest = ChatRequest.builder() * .model("gpt-3.5-turbo") * .messages(mutableListOf("Be as helpful as possible".toSystemMessage())) - * .temperature(0.7f) * .build() * ``` * @@ -222,4 +209,16 @@ data class ChatRequest @JvmOverloads constructor( ) } } + + companion object { + + /** + * A static method that provides a new [Builder] instance for the + * [ChatRequest] class. + * + * @return a new [Builder] instance for creating a [ChatRequest] object. + */ + @JvmStatic + fun builder(): Builder = Builder() + } } \ No newline at end of file diff --git a/src/main/kotlin/com/cjcrafter/openai/chat/ChatResponse.kt b/src/main/kotlin/com/cjcrafter/openai/chat/ChatResponse.kt index 6a73b2e..fbaee00 100644 --- a/src/main/kotlin/com/cjcrafter/openai/chat/ChatResponse.kt +++ b/src/main/kotlin/com/cjcrafter/openai/chat/ChatResponse.kt @@ -7,7 +7,7 @@ import java.time.ZonedDateTime import java.util.* /** - * The [ChatResponse] contains all the data returned by the OpenAI Chat API. + * The `ChatResponse` contains all the data returned by the OpenAI Chat API. * For most use cases, [ChatResponse.get] (passing 0 to the index argument) is * all you need. * diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoice.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoice.kt new file mode 100644 index 0000000..4257b50 --- /dev/null +++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoice.kt @@ -0,0 +1,28 @@ +package com.cjcrafter.openai.completions + +import com.cjcrafter.openai.FinishReason +import com.google.gson.annotations.SerializedName + +/** + * The OpenAI API returns a list of `CompletionChoice`. Each choice has a + * generated message ([CompletionChoice.text]) and a finish reason + * ([CompletionChoice.finishReason]). For most use cases, you only need the + * generated text. + * + * By default, only 1 choice is generated (since [CompletionRequest.n] == 1). + * When you increase `n` or provide a list of prompts (called batching), + * there will be multiple choices. + * + * @property text The generated text. + * @property index The index in the list... This is 0 for most use cases. + * @property logprobs List of logarithmic probabilities for each token in the generated text. + * @property finishReason The reason the bot stopped generating tokens. + * @constructor Create empty Completion choice, for internal usage. + * @see FinishReason + */ +data class CompletionChoice( + val text: String, + val index: Int, + val logprobs: List?, + @field:SerializedName("finish_reason") val finishReason: FinishReason +) diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt new file mode 100644 index 0000000..f349dcf --- /dev/null +++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt @@ -0,0 +1,255 @@ +package com.cjcrafter.openai.completions + +import com.google.gson.annotations.SerializedName + +/** + * `CompletionRequest` holds the configurable options that can be sent to the OpenAI + * Completions API. For most use cases, you only need to set [model] and [prompt]. 
+ * For more detailed descriptions for each option, refer to the
+ * [Completions Wiki](https://platform.openai.com/docs/api-reference/completions/create).
+ *
+ * [prompt] can be either a singular `String`, a `List<String>`, or a `String[]`.
+ * Providing multiple prompts is called [batching](https://platform.openai.com/docs/guides/rate-limits/batching-requests),
+ * and it can be used to reduce rate limit errors. This will cause the [CompletionResponse]
+ * to have multiple choices.
+ *
+ * You should not set [stream]. The stream option is handled using `OpenAI#streamCompletion`.
+ *
+ * @property model ID of the model to use.
+ * @property prompt The prompt(s) to generate completions for (either a `String`, `List<String>`, or `String[]`)
+ * @property suffix The suffix that comes after a completion of inserted text.
+ * @property maxTokens The maximum number of tokens to generate in the completion.
+ * @property temperature What sampling temperature to use, between 0 and 2.
+ * @property topP An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
+ * @property n How many completions to generate for each prompt.
+ * @property stream Whether to stream back partial progress.
+ * @property logprobs Include the log probabilities on the logprobs most likely tokens, as well as the chosen tokens.
+ * @property echo Echo back the prompt in addition to the completion.
+ * @property stop Up to 4 sequences where the API will stop generating further tokens.
+ * @property presencePenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+ * @property frequencyPenalty Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+ * @property bestOf Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token).
+ * @property logitBias Modify the likelihood of specified tokens appearing in the completion.
+ * @property user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+ * @constructor Create a CompletionRequest instance. Recommend using [builder] instead.
+ */
+data class CompletionRequest @JvmOverloads constructor(
+    val model: String,
+    val prompt: Any,
+    val suffix: String? = null,
+    @field:SerializedName("max_tokens") val maxTokens: Int? = null,
+    val temperature: Number? = null,
+    @field:SerializedName("top_p") val topP: Number? = null,
+    val n: Int? = null,
+    @Deprecated("Use OpenAI#streamCompletion") val stream: Boolean? = null,
+    val logprobs: Int? = null,
+    val echo: Boolean? = null,
+    val stop: Any? = null,
+    @field:SerializedName("presence_penalty") val presencePenalty: Number? = null,
+    @field:SerializedName("frequency_penalty") val frequencyPenalty: Number? = null,
+    @field:SerializedName("best_of") val bestOf: Int? = null,
+    @field:SerializedName("logit_bias") val logitBias: Map<String, Int>? = null,
+    val user: String? = null
+) {
+
+    /**
+     * `Builder` is a helper class to build a [CompletionRequest] instance with a
+     * stable API. It provides methods for setting the properties of the [CompletionRequest]
+     * object. The [build] method returns a new [CompletionRequest] instance with the
+     * specified properties.
+ * + * Usage: + * ``` + * val completionRequest = CompletionRequest.builder() + * .model("davinci") + * .prompt("The wheels on the bus go") + * .build() + * ``` + */ + class Builder { + private var model: String? = null + private var prompt: Any? = null + private var suffix: String? = null + private var maxTokens: Int? = null + private var temperature: Number? = null + private var topP: Number? = null + private var n: Int? = null + private var stream: Boolean? = null + private var logprobs: Int? = null + private var echo: Boolean? = null + private var stop: Any? = null + private var presencePenalty: Number? = null + private var frequencyPenalty: Number? = null + private var bestOf: Int? = null + private var logitBias: Map? = null + private var user: String? = null + + /** + * Sets the model for the CompletionRequest. + * @param model The ID of the model to use. + * @return The updated Builder instance. + */ + fun model(model: String) = apply { this.model = model } + + /** + * Sets the prompt for the CompletionRequest. + * @param prompt The prompt to generate completions for, encoded as a string. + * @return The updated Builder instance. + */ + fun prompt(prompt: String?) = apply { this.prompt = prompt } + + /** + * Sets the list of prompts for the CompletionRequest. + * @param prompts The prompts to generate completions for, encoded as a list of strings. + * @return The updated Builder instance. + */ + fun prompts(prompts: List?) = apply { this.prompt = prompts } + + /** + * Sets the array of prompts for the CompletionRequest. + * @param prompts The prompts to generate completions for, encoded as an array of strings. + * @return The updated Builder instance. + */ + fun prompts(prompts: Array?) = apply { this.prompt = prompts } + + /** + * Sets the suffix for the CompletionRequest. + * @param suffix The suffix that comes after a completion of inserted text. + * @return The updated Builder instance. + */ + fun suffix(suffix: String?) = apply { this.suffix = suffix } + + /** + * Sets the maximum number of tokens for the CompletionRequest. + * @param maxTokens The maximum number of tokens to generate in the completion. + * @return The updated Builder instance. + */ + fun maxTokens(maxTokens: Int?) = apply { this.maxTokens = maxTokens } + + /** + * Sets the temperature for the CompletionRequest. + * @param temperature What sampling temperature to use, between 0 and 2. + * @return The updated Builder instance. + */ + fun temperature(temperature: Number?) = apply { this.temperature = temperature } + + /** + * Sets the top_p for the CompletionRequest. + * @param topP An alternative to sampling with temperature, called nucleus sampling. + * @return The updated Builder instance. + */ + fun topP(topP: Number?) = apply { this.topP = topP } + + /** + * Sets the number of completions for the CompletionRequest. + * @param n How many completions to generate for each prompt. + * @return The updated Builder instance. + */ + fun n(n: Int?) = apply { this.n = n } + + /** + * Sets the stream option for the CompletionRequest. + * @param stream Whether to stream back partial progress. + * @return The updated Builder instance. + */ + fun stream(stream: Boolean?) = apply { this.stream = stream } + + /** + * Sets the logprobs for the CompletionRequest. + * @param logprobs Include the log probabilities on the logprobs most likely tokens, as well as the chosen tokens. + * @return The updated Builder instance. + */ + fun logprobs(logprobs: Int?) 
= apply { this.logprobs = logprobs } + + /** + * Sets the echo option for the CompletionRequest. + * @param echo Echo back the prompt in addition to the completion. + * @return The updated Builder instance. + */ + fun echo(echo: Boolean?) = apply { this.echo = echo } + + /** + * Sets the stop sequences for the CompletionRequest. + * @param stop Up to 4 sequences where the API will stop generating further tokens. + * @return The updated Builder instance. + */ + fun stop(stop: Any?) = apply { this.stop = stop } + + /** + * Sets the presence penalty for the CompletionRequest. + * @param presencePenalty Number between -2.0 and 2.0 for penalizing new tokens based on whether they appear in the text so far. + * @return The updated Builder instance. + */ + fun presencePenalty(presencePenalty: Number?) = apply { this.presencePenalty = presencePenalty } + + /** + * Sets the frequency penalty for the CompletionRequest. + * @param frequencyPenalty Number between -2.0 and 2.0 for penalizing new tokens based on their existing frequency in the text so far. + * @return The updated Builder instance. + */ + fun frequencyPenalty(frequencyPenalty: Number?) = apply { this.frequencyPenalty = frequencyPenalty } + + /** + * Sets the best_of option for the CompletionRequest. + * @param bestOf Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token). + * @return The updated Builder instance. + */ + fun bestOf(bestOf: Int?) = apply { this.bestOf = bestOf } + + /** + * Sets the logit bias for the CompletionRequest. + * @param logitBias Modify the likelihood of specified tokens appearing in the completion. + * @return The updated Builder instance. + */ + fun logitBias(logitBias: Map?) = apply { this.logitBias = logitBias } + + /** + * Sets the user identifier for the CompletionRequest. + * @param user A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + * @return The updated Builder instance. + */ + fun user(user: String?) = apply { this.user = user } + + /** + * Builds the CompletionRequest instance with the provided parameters. + * + * @return The constructed CompletionRequest instance. + */ + fun build(): CompletionRequest { + require(model != null) { "Set CompletionRequest.Builder#model(String) before building" } + require(prompt != null) { "Set CompletionRequest.Builder#prompt(String) before building" } + + return CompletionRequest( + model = model!!, + prompt = prompt!!, + suffix = suffix, + maxTokens = maxTokens, + temperature = temperature, + topP = topP, + n = n, + stream = stream, + logprobs = logprobs, + echo = echo, + stop = stop, + presencePenalty = presencePenalty, + frequencyPenalty = frequencyPenalty, + bestOf = bestOf, + logitBias = logitBias, + user = user + ) + } + } + + companion object { + + /** + * A static method that provides a new [Builder] instance for the + * [CompletionRequest] class. + * + * @return a new [Builder] instance for creating a [CompletionRequest] object. 
+         */
+        @JvmStatic
+        fun builder(): Builder = Builder()
+    }
+}
+
diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponse.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponse.kt
new file mode 100644
index 0000000..0476eba
--- /dev/null
+++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponse.kt
@@ -0,0 +1,66 @@
+package com.cjcrafter.openai.completions
+
+import java.time.Instant
+import java.time.ZoneId
+import java.time.ZonedDateTime
+import java.util.*
+
+/**
+ * The `CompletionResponse` contains all the data returned by the OpenAI Completions
+ * API. For most use cases, [CompletionResponse.get] (passing 0 to the index argument)
+ * is all you need.
+ *
+ * @property id The unique id for your request.
+ * @property created The Unix timestamp (measured in seconds since 00:00:00 UTC on January 1, 1970) when the API response was created.
+ * @property model The model used to generate the completion.
+ * @property choices The generated completion(s).
+ * @property usage The number of tokens used in this request/response.
+ * @constructor Create Completion response (for internal usage)
+ */
+data class CompletionResponse(
+    val id: String,
+    val created: Long,
+    val model: String,
+    val choices: List<CompletionChoice>,
+    val usage: CompletionUsage
+) {
+
+    /**
+     * Returns the [Instant] time that the OpenAI Completion API sent this response.
+     * The time is measured as a unix timestamp (measured in seconds since
+     * 00:00:00 UTC on January 1, 1970).
+     *
+     * Note that users expect time to be measured in their timezone, so
+     * [getZonedTime] is preferred.
+     *
+     * @return The instant the api created this response.
+     * @see getZonedTime
+     */
+    fun getTime(): Instant {
+        return Instant.ofEpochSecond(created)
+    }
+
+    /**
+     * Returns the time-zoned instant that the OpenAI Completion API sent this
+     * response. By default, this method uses the system's timezone.
+     *
+     * @param timezone The user's timezone.
+     * @return The timezone adjusted date time.
+     * @see TimeZone.getDefault
+     */
+    @JvmOverloads
+    fun getZonedTime(timezone: ZoneId = TimeZone.getDefault().toZoneId()): ZonedDateTime {
+        return ZonedDateTime.ofInstant(getTime(), timezone)
+    }
+
+    /**
+     * Shorthand for accessing the generated messages (shorthand for
+     * [CompletionResponse.choices]).
+     *
+     * @param index The index of the message.
+     * @return The generated [CompletionChoice] at the index.
+     */
+    operator fun get(index: Int): CompletionChoice {
+        return choices[index]
+    }
+}
diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionUsage.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionUsage.kt
new file mode 100644
index 0000000..4648593
--- /dev/null
+++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionUsage.kt
@@ -0,0 +1,21 @@
+package com.cjcrafter.openai.completions
+
+import com.google.gson.annotations.SerializedName
+
+/**
+ * Holds how many tokens were used by your API request. Use these
+ * tokens to calculate how much money you have spent on each request.
+ *
+ * By monitoring your token usage, you can limit the amount of money charged
+ * for your requests. You can check the cost of the model you are using in
+ * OpenAI's billing page.
+ *
+ * @param promptTokens How many tokens the input used.
+ * @param completionTokens How many tokens the output used.
+ * @param totalTokens How many tokens in total.
+ */ +data class CompletionUsage( + @field:SerializedName("prompt_tokens") val promptTokens: Int, + @field:SerializedName("completion_tokens") val completionTokens: Int, + @field:SerializedName("total_tokens") val totalTokens: Int +) \ No newline at end of file From 4bcd10c237000ea25c4f4cd0073ab8634252326b Mon Sep 17 00:00:00 2001 From: Collin Date: Wed, 29 Mar 2023 13:03:27 -0400 Subject: [PATCH 02/12] make CompletionRequest mutable --- .../kotlin/com/cjcrafter/openai/OpenAI.kt | 35 ++++++++++++++++--- .../openai/completions/CompletionRequest.kt | 32 ++++++++--------- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt index 8eb1986..4d034f3 100644 --- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt +++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt @@ -2,6 +2,8 @@ package com.cjcrafter.openai import com.cjcrafter.openai.gson.ChatChoiceChunkAdapter import com.cjcrafter.openai.chat.* +import com.cjcrafter.openai.completions.CompletionRequest +import com.cjcrafter.openai.completions.CompletionResponse import com.cjcrafter.openai.exception.OpenAIError import com.cjcrafter.openai.exception.WrappedIOError import com.cjcrafter.openai.gson.ChatUserAdapter @@ -40,17 +42,42 @@ class OpenAI @JvmOverloads constructor( private val mediaType = "application/json; charset=utf-8".toMediaType() private val gson = createGson() - private fun buildRequest(request: Any): Request { + private fun buildRequest(request: Any, endpoint: String): Request { val json = gson.toJson(request) val body: RequestBody = json.toRequestBody(mediaType) return Request.Builder() - .url("https://api.openai.com/v1/chat/completions") + .url("https://api.openai.com/v1/$endpoint") .addHeader("Content-Type", "application/json") .addHeader("Authorization", "Bearer $apiKey") .apply { if (organization != null) addHeader("OpenAI-Organization", organization) } .post(body).build() } + @Throws(OpenAIError::class) + fun createCompletion(request: CompletionRequest): CompletionResponse { + @Suppress("DEPRECATION") + request.stream = false // use streamResponse for stream=true + val httpRequest = buildRequest(request, "completions") + + // Save the JsonObject to check for errors + var rootObject: JsonObject? + try { + client.newCall(httpRequest).execute().use { response -> + + // Servers respond to API calls with json blocks. Since raw JSON isn't + // very developer friendly, we wrap for easy data access. + rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject + if (rootObject!!.has("error")) + throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject) + + return gson.fromJson(rootObject, CompletionResponse::class.java) + //return ChatResponse(rootObject!!) + } + } catch (ex: IOException) { + throw WrappedIOError(ex) + } + } + /** * Blocks the current thread until OpenAI responds to https request. The * returned value includes information including tokens, generated text, @@ -65,7 +92,7 @@ class OpenAI @JvmOverloads constructor( fun createChatCompletion(request: ChatRequest): ChatResponse { @Suppress("DEPRECATION") request.stream = false // use streamResponse for stream=true - val httpRequest = buildRequest(request) + val httpRequest = buildRequest(request, "chat/completions") // Save the JsonObject to check for errors var rootObject: JsonObject? 
@@ -150,7 +177,7 @@ class OpenAI @JvmOverloads constructor(
     ) {
         @Suppress("DEPRECATION")
         request.stream = true // use requestResponse for stream=false
-        val httpRequest = buildRequest(request)
+        val httpRequest = buildRequest(request, "chat/completions")

         client.newCall(httpRequest).enqueue(object : Callback {
             var cache: ChatResponseChunk? = null
diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt
index f349dcf..9cfa3b9 100644
--- a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt
+++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt
@@ -34,22 +34,22 @@ import com.google.gson.annotations.SerializedName
  * @constructor Create a CompletionRequest instance. Recommend using [builder] instead.
  */
 data class CompletionRequest @JvmOverloads constructor(
-    val model: String,
-    val prompt: Any,
-    val suffix: String? = null,
-    @field:SerializedName("max_tokens") val maxTokens: Int? = null,
-    val temperature: Number? = null,
-    @field:SerializedName("top_p") val topP: Number? = null,
-    val n: Int? = null,
-    @Deprecated("Use OpenAI#streamCompletion") val stream: Boolean? = null,
-    val logprobs: Int? = null,
-    val echo: Boolean? = null,
-    val stop: Any? = null,
-    @field:SerializedName("presence_penalty") val presencePenalty: Number? = null,
-    @field:SerializedName("frequency_penalty") val frequencyPenalty: Number? = null,
-    @field:SerializedName("best_of") val bestOf: Int? = null,
-    @field:SerializedName("logit_bias") val logitBias: Map<String, Int>? = null,
-    val user: String? = null
+    var model: String,
+    var prompt: Any,
+    var suffix: String? = null,
+    @field:SerializedName("max_tokens") var maxTokens: Int? = null,
+    var temperature: Number? = null,
+    @field:SerializedName("top_p") var topP: Number? = null,
+    var n: Int? = null,
+    @Deprecated("Use OpenAI#streamCompletion") var stream: Boolean? = null,
+    var logprobs: Int? = null,
+    var echo: Boolean? = null,
+    var stop: Any? = null,
+    @field:SerializedName("presence_penalty") var presencePenalty: Number? = null,
+    @field:SerializedName("frequency_penalty") var frequencyPenalty: Number? = null,
+    @field:SerializedName("best_of") var bestOf: Int? = null,
+    @field:SerializedName("logit_bias") var logitBias: Map<String, Int>? = null,
+    var user: String? = null
 ) {

     /**

From 0017d9342f98eb7702d79e500c3ab244cc64a01a Mon Sep 17 00:00:00 2001
From: Collin
Date: Thu, 30 Mar 2023 13:52:34 -0400
Subject: [PATCH 03/12] improve javadocs in OpenAI

---
 .../kotlin/com/cjcrafter/openai/OpenAI.kt     | 46 +++++++++++++++----
 1 file changed, 37 insertions(+), 9 deletions(-)

diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
index 4d034f3..2f0dc02 100644
--- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
+++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
@@ -19,6 +19,10 @@ import java.io.IOException
 import java.util.function.Consumer

 /**
+ * The `OpenAI` class contains all the API calls to OpenAI's endpoint. Whether
+ * you are working with images, chat, or completions, you need to have an
+ * `OpenAI` instance to make the API requests.
+ *
  * To get your API key:
  * 1. Log in to your account: Go to [openai.com](https://www.openai.com/) and
  *    log in.
@@ -38,7 +42,6 @@ class OpenAI @JvmOverloads constructor(
     private val organization: String?
= null, private val client: OkHttpClient = OkHttpClient() ) { - private val mediaType = "application/json; charset=utf-8".toMediaType() private val gson = createGson() @@ -56,7 +59,7 @@ class OpenAI @JvmOverloads constructor( @Throws(OpenAIError::class) fun createCompletion(request: CompletionRequest): CompletionResponse { @Suppress("DEPRECATION") - request.stream = false // use streamResponse for stream=true + request.stream = false // use streamCompletion for stream=true val httpRequest = buildRequest(request, "completions") // Save the JsonObject to check for errors @@ -71,9 +74,9 @@ class OpenAI @JvmOverloads constructor( throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject) return gson.fromJson(rootObject, CompletionResponse::class.java) - //return ChatResponse(rootObject!!) } } catch (ex: IOException) { + // Wrap the IOException, so we don't need to catch multiple exceptions throw WrappedIOError(ex) } } @@ -106,9 +109,9 @@ class OpenAI @JvmOverloads constructor( throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject) return gson.fromJson(rootObject, ChatResponse::class.java) - //return ChatResponse(rootObject!!) } } catch (ex: IOException) { + // Wrap the IOException, so we don't need to catch multiple exceptions throw WrappedIOError(ex) } } @@ -214,13 +217,38 @@ class OpenAI @JvmOverloads constructor( companion object { + /** + * Returns a `Gson` object that can be used to read/write .json files. + * This can be used to save requests/responses to a file, so you can + * keep a history of all API calls you've made. + * + * This is especially important for [ChatRequest], since users will + * expect you to save their conversations to be continued at later + * times. + * + * If you want to add your own type adapters, use [createGsonBuilder] + * instead. + * + * @return Google gson serializer for json files. + */ @JvmStatic fun createGson(): Gson { - return GsonBuilder() - .registerTypeAdapter(ChatUser::class.java, ChatUserAdapter()) - .registerTypeAdapter(FinishReason::class.java, FinishReasonAdapter()) - .registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunkAdapter()) - .create() + return createGsonBuilder().create() + } + + /** + * Returns a `GsonBuilder` with all [com.google.gson.TypeAdapter] used + * by `com.cjcrafter.openai`. Unless you want to register your own + * adapters, I recommend using [createGson] instead of this method. + * + * @return Google gson builder for serializing json files. 
+ */ + @JvmStatic + fun createGsonBuilder(): GsonBuilder { + return GsonBuilder() + .registerTypeAdapter(ChatUser::class.java, ChatUserAdapter()) + .registerTypeAdapter(FinishReason::class.java, FinishReasonAdapter()) + .registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunkAdapter()) } } } \ No newline at end of file From 15fbf5883a99131ef259566a90613d5b76ac2265 Mon Sep 17 00:00:00 2001 From: Collin Date: Fri, 31 Mar 2023 01:35:00 -0400 Subject: [PATCH 04/12] add completion streaming support --- .../kotlin/com/cjcrafter/openai/OpenAI.kt | 104 ++++++++++++++++-- .../completions/CompletionChoiceChunk.kt | 28 +++++ .../completions/CompletionResponseChunk.kt | 64 +++++++++++ src/test/kotlin/KotlinCompletionStreamTest.kt | 19 ++++ 4 files changed, 203 insertions(+), 12 deletions(-) create mode 100644 src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt create mode 100644 src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponseChunk.kt create mode 100644 src/test/kotlin/KotlinCompletionStreamTest.kt diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt index 2f0dc02..e260189 100644 --- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt +++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt @@ -4,6 +4,7 @@ import com.cjcrafter.openai.gson.ChatChoiceChunkAdapter import com.cjcrafter.openai.chat.* import com.cjcrafter.openai.completions.CompletionRequest import com.cjcrafter.openai.completions.CompletionResponse +import com.cjcrafter.openai.completions.CompletionResponseChunk import com.cjcrafter.openai.exception.OpenAIError import com.cjcrafter.openai.exception.WrappedIOError import com.cjcrafter.openai.gson.ChatUserAdapter @@ -16,6 +17,7 @@ import okhttp3.* import okhttp3.MediaType.Companion.toMediaType import okhttp3.RequestBody.Companion.toRequestBody import java.io.IOException +import java.lang.IllegalStateException import java.util.function.Consumer /** @@ -56,22 +58,27 @@ class OpenAI @JvmOverloads constructor( .post(body).build() } + /** + * Create completion + * + * @param request + * @return + * @since 1.3.0 + */ @Throws(OpenAIError::class) fun createCompletion(request: CompletionRequest): CompletionResponse { @Suppress("DEPRECATION") request.stream = false // use streamCompletion for stream=true val httpRequest = buildRequest(request, "completions") - // Save the JsonObject to check for errors - var rootObject: JsonObject? try { client.newCall(httpRequest).execute().use { response -> // Servers respond to API calls with json blocks. Since raw JSON isn't // very developer friendly, we wrap for easy data access. - rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject - if (rootObject!!.has("error")) - throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject) + val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject + if (rootObject.has("error")) + throw OpenAIError.fromJson(rootObject.get("error").asJsonObject) return gson.fromJson(rootObject, CompletionResponse::class.java) } @@ -81,6 +88,78 @@ class OpenAI @JvmOverloads constructor( } } + /** + * Helper method to call [streamCompletion]. + * + * @param request The input information for ChatGPT. + * @param onResponse The method to call for each chunk. + * @since 1.3.0 + */ + fun streamCompletionKotlin(request: CompletionRequest, onResponse: CompletionResponseChunk.() -> Unit) { + streamCompletion(request, { it.onResponse() }) + } + + /** + * This method does not block the thread. 
Method calls to [onResponse] are + * not handled by the main thread. It is crucial to consider thread safety + * within the context of your program. + * + * @param request The input information for ChatGPT. + * @param onResponse The method to call for each chunk. + * @param onFailure The method to call if the HTTP fails. This method will + * not be called if OpenAI returns an error. + * @see createCompletion + * @see streamCompletionKotlin + * @since 1.3.0 + */ + @JvmOverloads + fun streamCompletion( + request: CompletionRequest, + onResponse: Consumer, // use Consumer instead of Kotlin for better Java syntax + onFailure: Consumer = Consumer { it.printStackTrace() } + ) { + @Suppress("DEPRECATION") + request.stream = true // use requestResponse for stream=false + val httpRequest = buildRequest(request, "completions") + + client.newCall(httpRequest).enqueue(object : Callback { + + override fun onFailure(call: Call, e: IOException) { + onFailure.accept(WrappedIOError(e)) + } + + override fun onResponse(call: Call, response: Response) { + response.body?.source()?.use { source -> + while (!source.exhausted()) { + + // Parse the JSON string as a map. Every string starts + // with "data: ", so we need to remove that. + var jsonResponse = source.readUtf8Line() ?: continue + if (jsonResponse.isEmpty()) + continue + + // TODO comment + if (!jsonResponse.startsWith("data: ")) { + System.err.println(jsonResponse) + continue + } + + jsonResponse = jsonResponse.substring("data: ".length) + if (jsonResponse == "[DONE]") + continue + + val rootObject = JsonParser.parseString(jsonResponse).asJsonObject + if (rootObject.has("error")) + throw OpenAIError.fromJson(rootObject.get("error").asJsonObject) + + val cache = gson.fromJson(rootObject, CompletionResponseChunk::class.java) + onResponse.accept(cache) + } + } + } + }) + } + /** * Blocks the current thread until OpenAI responds to https request. The * returned value includes information including tokens, generated text, @@ -97,16 +176,14 @@ class OpenAI @JvmOverloads constructor( request.stream = false // use streamResponse for stream=true val httpRequest = buildRequest(request, "chat/completions") - // Save the JsonObject to check for errors - var rootObject: JsonObject? try { client.newCall(httpRequest).execute().use { response -> // Servers respond to API calls with json blocks. Since raw JSON isn't // very developer friendly, we wrap for easy data access. - rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject - if (rootObject!!.has("error")) - throw OpenAIError.fromJson(rootObject!!.get("error").asJsonObject) + val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject + if (rootObject.has("error")) + throw OpenAIError.fromJson(rootObject.get("error").asJsonObject) return gson.fromJson(rootObject, ChatResponse::class.java) } @@ -176,7 +253,7 @@ class OpenAI @JvmOverloads constructor( fun streamChatCompletion( request: ChatRequest, onResponse: Consumer, // use Consumer instead of Kotlin for better Java syntax - onFailure: Consumer = Consumer { it.printStackTrace() } + onFailure: Consumer = Consumer { it.printStackTrace() } ) { @Suppress("DEPRECATION") request.stream = true // use requestResponse for stream=false @@ -186,7 +263,7 @@ class OpenAI @JvmOverloads constructor( var cache: ChatResponseChunk? 
= null

             override fun onFailure(call: Call, e: IOException) {
-                onFailure.accept(e)
+                onFailure.accept(WrappedIOError(e))
             }

             override fun onResponse(call: Call, response: Response) {
@@ -203,6 +280,9 @@ class OpenAI @JvmOverloads constructor(
                         continue

                     val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
+                    if (rootObject.has("error"))
+                        throw OpenAIError.fromJson(rootObject.get("error").asJsonObject)
+
                     if (cache == null)
                         cache = gson.fromJson(rootObject, ChatResponseChunk::class.java)
                     else
diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt
new file mode 100644
index 0000000..2adec63
--- /dev/null
+++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt
@@ -0,0 +1,28 @@
+package com.cjcrafter.openai.completions
+
+import com.cjcrafter.openai.FinishReason
+import com.google.gson.annotations.SerializedName
+
+/**
+ * The OpenAI API returns a list of `CompletionChoiceChunk`. Each choice has a
+ * generated message ([CompletionChoiceChunk.text]) and a finish reason
+ * ([CompletionChoiceChunk.finishReason]). For most use cases, you only need the
+ * generated text.
+ *
+ * By default, only 1 choice is generated (since [CompletionRequest.n] == 1).
+ * When you increase `n` or provide a list of prompts (called batching),
+ * there will be multiple choices.
+ *
+ * @property text The few generated tokens.
+ * @property index The index in the list... This is 0 for most use cases.
+ * @property logprobs List of logarithmic probabilities for each token in the generated text.
+ * @property finishReason The reason the bot stopped generating tokens.
+ * @constructor Create empty Completion choice chunk, for internal usage.
+ * @see FinishReason
+ */
+data class CompletionChoiceChunk(
+    val text: String,
+    val index: Int,
+    val logprobs: List<Float>?,
+    @field:SerializedName("finish_reason") val finishReason: FinishReason?
+)
diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponseChunk.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponseChunk.kt
new file mode 100644
index 0000000..1d791a0
--- /dev/null
+++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionResponseChunk.kt
@@ -0,0 +1,64 @@
+package com.cjcrafter.openai.completions
+
+import java.time.Instant
+import java.time.ZoneId
+import java.time.ZonedDateTime
+import java.util.*
+
+/**
+ * The `CompletionResponseChunk` contains all the data returned by the OpenAI Completions
+ * API. For most use cases, [CompletionResponseChunk.get] (passing 0 to the index argument)
+ * is all you need.
+ *
+ * @property id The unique id for your request.
+ * @property created The Unix timestamp (measured in seconds since 00:00:00 UTC on January 1, 1970) when the API response was created.
+ * @property model The model used to generate the completion.
+ * @property choices The generated completion(s).
+ * @constructor Create Completion response chunk (for internal usage)
+ */
+data class CompletionResponseChunk(
+    val id: String,
+    val created: Long,
+    val model: String,
+    val choices: List<CompletionChoiceChunk>,
+) {
+
+    /**
+     * Returns the [Instant] time that the OpenAI Completion API sent this response.
+     * The time is measured as a unix timestamp (measured in seconds since
+     * 00:00:00 UTC on January 1, 1970).
+     *
+     * Note that users expect time to be measured in their timezone, so
+     * [getZonedTime] is preferred.
+     *
+     * @return The instant the api created this response.
+ * @see getZonedTime + */ + fun getTime(): Instant { + return Instant.ofEpochSecond(created) + } + + /** + * Returns the time-zoned instant that the OpenAI Completion API sent this + * response. By default, this method uses the system's timezone. + * + * @param timezone The user's timezone. + * @return The timezone adjusted date time. + * @see TimeZone.getDefault + */ + @JvmOverloads + fun getZonedTime(timezone: ZoneId = TimeZone.getDefault().toZoneId()): ZonedDateTime { + return ZonedDateTime.ofInstant(getTime(), timezone) + } + + /** + * Shorthand for accessing the generated messages (shorthand for + * [CompletionResponseChunk.choices]). + * + * @param index The index of the message. + * @return The generated [CompletionChoiceChunk] at the index. + */ + operator fun get(index: Int): CompletionChoiceChunk { + return choices[index] + } +} diff --git a/src/test/kotlin/KotlinCompletionStreamTest.kt b/src/test/kotlin/KotlinCompletionStreamTest.kt new file mode 100644 index 0000000..5bbf449 --- /dev/null +++ b/src/test/kotlin/KotlinCompletionStreamTest.kt @@ -0,0 +1,19 @@ +import com.cjcrafter.openai.OpenAI +import com.cjcrafter.openai.completions.CompletionRequest +import io.github.cdimascio.dotenv.dotenv + +fun main(args: Array) { + + // Prepare the ChatRequest + val request = CompletionRequest(model="davinci", prompt="Hello darkness", maxTokens = 1024) + + // Loads the API key from the .env file in the root directory. + val key = dotenv()["OPENAI_TOKEN"] + val openai = OpenAI(key) + + // Generate a response, and print it to the user + //println(openai.createCompletion(request)) + openai.streamCompletionKotlin(request) { + print(choices[0].text) + } +} \ No newline at end of file From 4a41ef5ec119ff8b267df29e0e0cebf6e3e37d5d Mon Sep 17 00:00:00 2001 From: Collin Date: Fri, 31 Mar 2023 01:37:07 -0400 Subject: [PATCH 05/12] prompt is not a required argument for Completions --- .../com/cjcrafter/openai/completions/CompletionRequest.kt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt index 9cfa3b9..4127547 100644 --- a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt +++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionRequest.kt @@ -35,7 +35,7 @@ import com.google.gson.annotations.SerializedName */ data class CompletionRequest @JvmOverloads constructor( var model: String, - var prompt: Any, + var prompt: Any? = null, var suffix: String? = null, @field:SerializedName("max_tokens") var maxTokens: Int? = null, var temperature: Number? 
= null,
@@ -217,11 +217,10 @@ data class CompletionRequest @JvmOverloads constructor(
          */
         fun build(): CompletionRequest {
             require(model != null) { "Set CompletionRequest.Builder#model(String) before building" }
-            require(prompt != null) { "Set CompletionRequest.Builder#prompt(String) before building" }

             return CompletionRequest(
                 model = model!!,
-                prompt = prompt!!,
+                prompt = prompt,
                 suffix = suffix,

From f99e07459077edf41e0a04279aab7cb6d4694cce Mon Sep 17 00:00:00 2001
From: Collin
Date: Sun, 2 Apr 2023 18:25:56 -0400
Subject: [PATCH 06/12] Update OpenAI with sync/async methods

---
 .../kotlin/com/cjcrafter/openai/MyCallback.kt |  77 ++++
 .../kotlin/com/cjcrafter/openai/OpenAI.kt     | 349 +++++++++---------
 src/test/kotlin/KotlinChatStreamTest.kt       |  19 +-
 src/test/kotlin/KotlinCompletionStreamTest.kt |   4 +-
 4 files changed, 273 insertions(+), 176 deletions(-)
 create mode 100644 src/main/kotlin/com/cjcrafter/openai/MyCallback.kt

diff --git a/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt b/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt
new file mode 100644
index 0000000..6465536
--- /dev/null
+++ b/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt
@@ -0,0 +1,77 @@
+package com.cjcrafter.openai
+
+import com.cjcrafter.openai.exception.OpenAIError
+import com.cjcrafter.openai.exception.WrappedIOError
+import com.google.gson.JsonObject
+import com.google.gson.JsonParser
+import okhttp3.Call
+import okhttp3.Callback
+import okhttp3.Response
+import java.io.IOException
+import java.util.function.Consumer
+
+internal class MyCallback(
+    private val isStream: Boolean,
+    private val onFailure: Consumer<OpenAIError>,
+    private val onResponse: Consumer<JsonObject>
+) : Callback {
+
+    override fun onFailure(call: Call, e: IOException) {
+        onFailure.accept(WrappedIOError(e))
+    }
+
+    override fun onResponse(call: Call, response: Response) {
+        onResponse(response)
+    }
+
+    fun onResponse(response: Response) {
+        if (isStream) {
+            handleStream(response)
+            return
+        }
+
+        val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
+
+        // Sometimes OpenAI will respond with an error code for malformed
+        // requests, timeouts, rate limits, etc. We need to let the dev
+        // know that an error occurred.
+        if (rootObject.has("error")) {
+            onFailure.accept(OpenAIError.fromJson(rootObject.get("error").asJsonObject))
+            return
+        }
+
+        onResponse.accept(rootObject)
+    }
+
+    private fun handleStream(response: Response) {
+        response.body?.source()?.use { source ->
+            while (!source.exhausted()) {
+                // Read one SSE line at a time, and skip the blank separator lines.
+                var jsonResponse = source.readUtf8Line() ?: continue
+                if (jsonResponse.isEmpty())
+                    continue
+
+                // OpenAI returns a json string, but they prepend the content with
+                // "data: " (which is not valid json). In order to parse this into
+                // a JsonObject, we have to strip away this extra string.
+                jsonResponse = jsonResponse.substring("data: ".length)
+
+                // After OpenAI's final message (which already contains a non-null
+                // finish reason), they redundantly send "data: [DONE]". Ignore it.
+                if (jsonResponse == "[DONE]")
+                    continue
+
+                val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
+
+                // Sometimes OpenAI will respond with an error code for malformed
+                // requests, timeouts, rate limits, etc. We need to let the dev
+                // know that an error occurred.
+                if (rootObject.has("error")) {
+                    onFailure.accept(OpenAIError.fromJson(rootObject.get("error").asJsonObject))
+                    continue
+                }
+
+                // Developer defined code to run
+                onResponse.accept(rootObject)
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
index e260189..3f13220 100644
--- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
+++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
@@ -18,6 +18,7 @@ import okhttp3.MediaType.Companion.toMediaType
 import okhttp3.RequestBody.Companion.toRequestBody
 import java.io.IOException
 import java.lang.IllegalStateException
+import java.util.ArrayList
 import java.util.function.Consumer

 /**
@@ -59,244 +60,256 @@ class OpenAI @JvmOverloads constructor(
     }

     /**
-     * Create completion
      *
-     * @param request
-     * @return
+     * @param request The input information for the Completions API.
+     * @return The value returned by the Completions API.
      * @since 1.3.0
      */
     @Throws(OpenAIError::class)
     fun createCompletion(request: CompletionRequest): CompletionResponse {
         @Suppress("DEPRECATION")
         request.stream = false // use streamCompletion for stream=true
-        val httpRequest = buildRequest(request, "completions")
+        val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)

         try {
-            client.newCall(httpRequest).execute().use { response ->
-
-                // Servers respond to API calls with json blocks. Since raw JSON isn't
-                // very developer friendly, we wrap for easy data access.
-                val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject
-                if (rootObject.has("error"))
-                    throw OpenAIError.fromJson(rootObject.get("error").asJsonObject)
+            val httpResponse = client.newCall(httpRequest).execute()
+            lateinit var response: CompletionResponse
+            MyCallback(false, { throw it }) {
+                response = gson.fromJson(it, CompletionResponse::class.java)
+            }.onResponse(httpResponse)

-                return gson.fromJson(rootObject, CompletionResponse::class.java)
-            }
+            return response
         } catch (ex: IOException) {
             throw WrappedIOError(ex)
         }
     }

     /**
-     * Helper method to call [streamCompletion].
+     * Create completion async
      *
-     * @param request The input information for ChatGPT.
-     * @param onResponse The method to call for each chunk.
+     * @param request The input information for the Completions API.
+     * @param onResponse The code to execute when the response arrives.
+     * @param onFailure The code to execute when a failure occurs.
      * @since 1.3.0
      */
-    fun streamCompletionKotlin(request: CompletionRequest, onResponse: CompletionResponseChunk.() -> Unit) {
-        streamCompletion(request, { it.onResponse() })
+    @JvmOverloads
+    fun createCompletionAsync(
+        request: CompletionRequest,
+        onResponse: Consumer<CompletionResponse>,
+        onFailure: Consumer<OpenAIError> = Consumer { it.printStackTrace() }
+    ) {
+        @Suppress("DEPRECATION")
+        request.stream = false // use streamCompletionAsync for stream=true
+        val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
+
+        client.newCall(httpRequest).enqueue(MyCallback(false, onFailure) {
+            val response = gson.fromJson(it, CompletionResponse::class.java)
+            onResponse.accept(response)
+        })
     }

     /**
-     * This method does not block the thread. Method calls to [onResponse] are
-     * not handled by the main thread. It is crucial to consider thread safety
-     * within the context of your program.
-     *
-     * @param request The input information for ChatGPT.
-     * @param onResponse The method to call for each chunk.
-     * @param onFailure The method to call if the HTTP fails. This method will
-     * not be called if OpenAI returns an error.
-     * @see createCompletion
-     * @see streamCompletionKotlin
-     * @since 1.3.0
+     * Calls OpenAI's Completions API using a *stream* of data. Streams allow
+     * developers to access tokens in real time as they are generated. This is
+     * used to create the "scrolling text" or "living typing" effect.
Using + * `streamCompletion` gives users information immediately, as opposed to + * `createCompletion` where you have to wait for the entire message to + * generate. + * + * This method blocks the current thread until the stream is complete. For + * non-blocking options, use [streamCompletionAsync]. It is important to + * consider which thread you are currently running on. Running this method + * on [javax.swing]'s thread, for example, will cause your UI to freeze + * temporarily. * - * @param request The input information for ChatGPT. - * @param onResponse The method to call for each chunk. - * @param onFailure The method to call if the HTTP fails. This method will - * not be called if OpenAI returns an error. - * @see createCompletion - * @see streamCompletionKotlin + * @param request The data to send to the API endpoint. + * @param onResponse The code to execute for every chunk of text. + * @param onFailure The code to execute when a failure occurs. * @since 1.3.0 */ @JvmOverloads fun streamCompletion( request: CompletionRequest, - onResponse: Consumer, // use Consumer instead of Kotlin for better Java syntax + onResponse: Consumer, onFailure: Consumer = Consumer { it.printStackTrace() } ) { @Suppress("DEPRECATION") - request.stream = true // use requestResponse for stream=false - val httpRequest = buildRequest(request, "completions") - - client.newCall(httpRequest).enqueue(object : Callback { - - override fun onFailure(call: Call, e: IOException) { - onFailure.accept(WrappedIOError(e)) - } + request.stream = true // use createCompletion for stream=false + val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) - override fun onResponse(call: Call, response: Response) { - response.body?.source()?.use { source -> - while (!source.exhausted()) { - - // Parse the JSON string as a map. Every string starts - // with "data: ", so we need to remove that. - var jsonResponse = source.readUtf8Line() ?: continue - if (jsonResponse.isEmpty()) - continue - - // TODO comment - if (!jsonResponse.startsWith("data: ")) { - System.err.println(jsonResponse) - continue - } - - jsonResponse = jsonResponse.substring("data: ".length) - if (jsonResponse == "[DONE]") - continue + try { + val httpResponse = client.newCall(httpRequest).execute() + MyCallback(true, onFailure) { + val response = gson.fromJson(it, CompletionResponseChunk::class.java) + onResponse.accept(response) + }.onResponse(httpResponse) + } catch (ex: IOException) { + onFailure.accept(WrappedIOError(ex)) + } + } - val rootObject = JsonParser.parseString(jsonResponse).asJsonObject - if (rootObject.has("error")) - throw OpenAIError.fromJson(rootObject.get("error").asJsonObject) + /** + * Calls OpenAI's Completions API using a *stream* of data. Streams allow + * developers to access tokens in real time as they are generated. This is + * used to create the "scrolling text" or "living typing" effect. Using + * `streamCompletion` gives users information immediately, as opposed to + * `createCompletion` where you have to wait for the entire message to + * generate. + * + * This method will not block the current thread. The code block [onResponse] + * will be run later on a different thread. Due to the different thread, it + * is important to consider thread safety in the context of your program. To + * avoid thread safety issues, use [streamCompletion] to block the main thread. + * + * @param request The data to send to the API endpoint. + * @param onResponse The code to execute for every chunk of text. 
+ * @param onFailure The code to execute when a failure occurs. + * @since 1.3.0 + */ + @JvmOverloads + fun streamCompletionAsync( + request: CompletionRequest, + onResponse: Consumer, + onFailure: Consumer = Consumer { it.printStackTrace() } + ) { + @Suppress("DEPRECATION") + request.stream = true // use createCompletionAsync for stream=false + val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) - val cache = gson.fromJson(rootObject, CompletionResponseChunk::class.java) - onResponse.accept(cache) - } - } - } + client.newCall(httpRequest).enqueue(MyCallback(true, onFailure) { + val response = gson.fromJson(it, CompletionResponseChunk::class.java) + onResponse.accept(response) }) } /** - * Blocks the current thread until OpenAI responds to https request. The - * returned value includes information including tokens, generated text, - * and stop reason. You can access the generated message through - * [ChatResponse.choices]. * - * @param request The input information for ChatGPT. - * @return The returned response. - * @throws OpenAIError Invalid request/timeout/io/etc. + * @param request The input information for the Completions API. + * @return The value returned by the Completions API. + * @since 1.3.0 */ @Throws(OpenAIError::class) fun createChatCompletion(request: ChatRequest): ChatResponse { @Suppress("DEPRECATION") - request.stream = false // use streamResponse for stream=true - val httpRequest = buildRequest(request, "chat/completions") + request.stream = false // use streamChatCompletion for stream=true + val httpRequest = buildRequest(request, CHAT_ENDPOINT) try { - client.newCall(httpRequest).execute().use { response -> - - // Servers respond to API calls with json blocks. Since raw JSON isn't - // very developer friendly, we wrap for easy data access. - val rootObject = JsonParser.parseString(response.body!!.string()).asJsonObject - if (rootObject.has("error")) - throw OpenAIError.fromJson(rootObject.get("error").asJsonObject) + val httpResponse = client.newCall(httpRequest).execute(); + lateinit var response: ChatResponse + MyCallback(true, { throw it }) { + response = gson.fromJson(it, ChatResponse::class.java) + }.onResponse(httpResponse) - return gson.fromJson(rootObject, ChatResponse::class.java) - } + return response } catch (ex: IOException) { - // Wrap the IOException, so we don't need to catch multiple exceptions throw WrappedIOError(ex) } } /** - * This is a helper method that calls [streamChatCompletion], which lets you use - * the generated tokens in real time (As ChatGPT generates them). - * - * This method does not block the thread. Method calls to [onResponse] are - * not handled by the main thread. It is crucial to consider thread safety - * within the context of your program. - * - * Usage: - * ``` - * val messages = mutableListOf("Write a poem".toUserMessage()) - * val request = ChatRequest("gpt-3.5-turbo", messages) - * val bot = ChatBot(/* your key */) - - * bot.streamResponseKotlin(request) { - * print(choices[0].delta) - * - * // when finishReason != null, this is the last message (done generating new tokens) - * if (choices[0].finishReason != null) - * messages.add(choices[0].message) - * } - * ``` + * Create completion async * - * @param request The input information for ChatGPT. - * @param onResponse The method to call for each chunk. 
- * @since 1.2.0 + * @param request + * @param onResponse + * @param onFailure + * @since 1.3.0 */ - fun streamChatCompletionKotlin(request: ChatRequest, onResponse: ChatResponseChunk.() -> Unit) { - streamChatCompletion(request, { it.onResponse() }) + @JvmOverloads + fun createChatCompletionAsync( + request: ChatRequest, + onResponse: Consumer, + onFailure: Consumer = Consumer { it.printStackTrace() } + ) { + @Suppress("DEPRECATION") + request.stream = false // use streamChatCompletionAsync for stream=true + val httpRequest = buildRequest(request, CHAT_ENDPOINT) + + client.newCall(httpRequest).enqueue(MyCallback(false, onFailure) { + val response = gson.fromJson(it, ChatResponse::class.java) + onResponse.accept(response) + }) } /** - * Uses ChatGPT to generate tokens in real time. As ChatGPT generates - * content, those tokens are sent in a stream in real time. This allows you - * to update the user without long delays between their input and OpenAI's - * response. - * - * For *"simpler"* calls, you can use [createChatCompletion] which will block - * the thread until the entire response is generated. + * Calls OpenAI's Completions API using a *stream* of data. Streams allow + * developers to access tokens in real time as they are generated. This is + * used to create the "scrolling text" or "living typing" effect. Using + * `streamCompletion` gives users information immediately, as opposed to + * `createCompletion` where you have to wait for the entire message to + * generate. * - * Instead of using the [ChatResponse], this method uses [ChatResponseChunk]. - * This means that it is not possible to retrieve the number of tokens from - * this method, + * This method blocks the current thread until the stream is complete. For + * non-blocking options, use [streamCompletionAsync]. It is important to + * consider which thread you are currently running on. Running this method + * on [javax.swing]'s thread, for example, will cause your UI to freeze + * temporarily. * - * This method does not block the thread. Method calls to [onResponse] are - * not handled by the main thread. It is crucial to consider thread safety - * within the context of your program. - * - * @param request The input information for ChatGPT. - * @param onResponse The method to call for each chunk. - * @param onFailure The method to call if the HTTP fails. This method will - * not be called if OpenAI returns an error. - * @see createChatCompletion - * @see streamChatCompletionKotlin - * @since 1.2.0 + * @param request The data to send to the API endpoint. + * @param onResponse The code to execute for every chunk of text. + * @param onFailure The code to execute when a failure occurs. + * @since 1.3.0 */ @JvmOverloads fun streamChatCompletion( request: ChatRequest, - onResponse: Consumer, // use Consumer instead of Kotlin for better Java syntax - onFailure: Consumer = Consumer { it.printStackTrace() } + onResponse: Consumer, + onFailure: Consumer = Consumer { it.printStackTrace() } ) { @Suppress("DEPRECATION") request.stream = true // use requestResponse for stream=false - val httpRequest = buildRequest(request, "chat/completions") - - client.newCall(httpRequest).enqueue(object : Callback { - var cache: ChatResponseChunk? 
= null - - override fun onFailure(call: Call, e: IOException) { - onFailure.accept(WrappedIOError(e)) - } - - override fun onResponse(call: Call, response: Response) { - response.body?.source()?.use { source -> - while (!source.exhausted()) { + val httpRequest = buildRequest(request, CHAT_ENDPOINT) - // Parse the JSON string as a map. Every string starts - // with "data: ", so we need to remove that. - var jsonResponse = source.readUtf8Line() ?: continue - if (jsonResponse.isEmpty()) - continue - jsonResponse = jsonResponse.substring("data: ".length) - if (jsonResponse == "[DONE]") - continue - - val rootObject = JsonParser.parseString(jsonResponse).asJsonObject - if (rootObject.has("error")) - throw OpenAIError.fromJson(rootObject.get("error").asJsonObject) + try { + val httpResponse = client.newCall(httpRequest).execute() + MyCallback(true, onFailure) { + val response = gson.fromJson(it, ChatResponseChunk::class.java) + onResponse.accept(response) + }.onResponse(httpResponse) + } catch (ex: IOException) { + onFailure.accept(WrappedIOError(ex)) + } + } - if (cache == null) - cache = gson.fromJson(rootObject, ChatResponseChunk::class.java) - else - cache!!.update(rootObject) + /** + * Calls OpenAI's Completions API using a *stream* of data. Streams allow + * developers to access tokens in real time as they are generated. This is + * used to create the "scrolling text" or "living typing" effect. Using + * `streamCompletion` gives users information immediately, as opposed to + * `createCompletion` where you have to wait for the entire message to + * generate. + * + * This method will not block the current thread. The code block [onResponse] + * will be run later on a different thread. Due to the different thread, it + * is important to consider thread safety in the context of your program. To + * avoid thread safety issues, use [streamCompletion] to block the main thread. + * + * @param request The data to send to the API endpoint. + * @param onResponse The code to execute for every chunk of text. + * @param onFailure The code to execute when a failure occurs. + * @since 1.3.0 + */ + @JvmOverloads + fun streamChatCompletionAsync( + request: CompletionRequest, + onResponse: Consumer, + onFailure: Consumer = Consumer { it.printStackTrace() } + ) { + @Suppress("DEPRECATION") + request.stream = true // use requestResponse for stream=false + val httpRequest = buildRequest(request, CHAT_ENDPOINT) - onResponse.accept(cache!!) - } - } - } + client.newCall(httpRequest).enqueue(MyCallback(true, onFailure) { + val response = gson.fromJson(it, ChatResponseChunk::class.java) + onResponse.accept(response) }) } companion object { + const val COMPLETIONS_ENDPOINT = "completions" + const val CHAT_ENDPOINT = "chat/completions" + /** * Returns a `Gson` object that can be used to read/write .json files. 
* This can be used to save requests/responses to a file, so you can @@ -325,10 +338,10 @@ class OpenAI @JvmOverloads constructor( */ @JvmStatic fun createGsonBuilder(): GsonBuilder { - return GsonBuilder() - .registerTypeAdapter(ChatUser::class.java, ChatUserAdapter()) - .registerTypeAdapter(FinishReason::class.java, FinishReasonAdapter()) - .registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunkAdapter()) + return GsonBuilder() + .registerTypeAdapter(ChatUser::class.java, ChatUserAdapter()) + .registerTypeAdapter(FinishReason::class.java, FinishReasonAdapter()) + .registerTypeAdapter(ChatChoiceChunk::class.java, ChatChoiceChunkAdapter()) } } } \ No newline at end of file diff --git a/src/test/kotlin/KotlinChatStreamTest.kt b/src/test/kotlin/KotlinChatStreamTest.kt index 6dffc11..ac296e4 100644 --- a/src/test/kotlin/KotlinChatStreamTest.kt +++ b/src/test/kotlin/KotlinChatStreamTest.kt @@ -2,8 +2,15 @@ import com.cjcrafter.openai.OpenAI import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage import com.cjcrafter.openai.chat.ChatMessage.Companion.toUserMessage import com.cjcrafter.openai.chat.ChatRequest +import com.cjcrafter.openai.completions.CompletionResponseChunk +import com.cjcrafter.openai.exception.OpenAIError import io.github.cdimascio.dotenv.dotenv +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody +import okhttp3.RequestBody.Companion.toRequestBody import java.util.* +import java.util.function.Consumer fun main(args: Array) { val scan = Scanner(System.`in`) @@ -23,12 +30,12 @@ fun main(args: Array) { // Generate a response, and print it to the user. messages.add(input.toUserMessage()) - openai.streamChatCompletionKotlin(request) { - print(choices[0].delta) + openai.streamChatCompletion(request, { + print(it[0].delta) // Once the message is complete, we should save the message to our // conversation (In case you want to generate more responses). - if (choices[0].finishReason != null) - messages.add(choices[0].message) - } -} \ No newline at end of file + if (it[0].finishReason != null) + messages.add(it[0].message) + }) +} diff --git a/src/test/kotlin/KotlinCompletionStreamTest.kt b/src/test/kotlin/KotlinCompletionStreamTest.kt index 5bbf449..83dba7d 100644 --- a/src/test/kotlin/KotlinCompletionStreamTest.kt +++ b/src/test/kotlin/KotlinCompletionStreamTest.kt @@ -5,7 +5,7 @@ import io.github.cdimascio.dotenv.dotenv fun main(args: Array) { // Prepare the ChatRequest - val request = CompletionRequest(model="davinci", prompt="Hello darkness", maxTokens = 1024) + val request = CompletionRequest(model="davinci", prompt="The wheels on the bus", maxTokens = 128) // Loads the API key from the .env file in the root directory. 
val key = dotenv()["OPENAI_TOKEN"] @@ -13,7 +13,7 @@ fun main(args: Array) { // Generate a response, and print it to the user //println(openai.createCompletion(request)) - openai.streamCompletionKotlin(request) { + val list = openai.streamCompletionKotlin(request) { print(choices[0].text) } } \ No newline at end of file From c75b97f20e2ee48049652e77aab344bd3bbaaa32 Mon Sep 17 00:00:00 2001 From: Collin Date: Mon, 3 Apr 2023 11:16:15 -0400 Subject: [PATCH 07/12] add isFinished methods for chunks --- .../kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt | 7 +++++++ .../openai/completions/CompletionChoiceChunk.kt | 9 ++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt b/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt index 13d4624..75ac8ac 100644 --- a/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt +++ b/src/main/kotlin/com/cjcrafter/openai/chat/ChatChoiceChunk.kt @@ -37,6 +37,13 @@ data class ChatChoiceChunk( message.content += delta finishReason = if (json["finish_reason"].isJsonNull) null else FinishReason.valueOf(json["finish_reason"].asString.uppercase()) } + + /** + * Returns `true` if this message chunk is complete. Once complete, no more + * tokens will be generated, and [ChatChoiceChunk.message] will contain the + * complete message. + */ + fun isFinished() = finishReason != null } /* diff --git a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt index 2adec63..c29ae96 100644 --- a/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt +++ b/src/main/kotlin/com/cjcrafter/openai/completions/CompletionChoiceChunk.kt @@ -1,6 +1,7 @@ package com.cjcrafter.openai.completions import com.cjcrafter.openai.FinishReason +import com.cjcrafter.openai.chat.ChatChoiceChunk import com.google.gson.annotations.SerializedName /** @@ -25,4 +26,10 @@ data class CompletionChoiceChunk( val index: Int, val logprobs: List?, @field:SerializedName("finish_reason") val finishReason: FinishReason? -) +) { + /** + * Returns `true` if this message chunk is complete. Once complete, no more + * tokens will be generated. 
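+     *
+     * For example, a stream consumer can use this flag to detect the final
+     * chunk (a sketch; `openai` and `request` are assumed to be set up as in
+     * the test files):
+     * ```
+     * openai.streamCompletion(request, { chunk ->
+     *     print(chunk[0].text)
+     *     if (chunk[0].isFinished())
+     *         println()  // the completion is done generating
+     * })
+     * ```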
+     */
+    fun isFinished() = finishReason != null
+}

From 305f37190ca90c2227e0b13ef175191929aef686 Mon Sep 17 00:00:00 2001
From: Collin 
Date: Mon, 3 Apr 2023 11:32:01 -0400
Subject: [PATCH 08/12] add basic testing code for java and kotlin

---
 .../kotlin/com/cjcrafter/openai/MyCallback.kt |  30 ++--
 .../kotlin/com/cjcrafter/openai/OpenAI.kt     |  20 ++-
 src/test/java/JavaChatStreamTest.java         |  38 ----
 src/test/java/JavaChatTest.java               |  57 ------
 src/test/java/JavaTest.java                   | 170 ++++++++++++++++++
 src/test/kotlin/KotlinChatStreamTest.kt       |  41 -----
 src/test/kotlin/KotlinChatTest.kt             |  48 -----
 src/test/kotlin/KotlinCompletionStreamTest.kt |  19 --
 src/test/kotlin/KotlinTest.kt                 | 150 ++++++++++++++++
 9 files changed, 355 insertions(+), 218 deletions(-)
 delete mode 100644 src/test/java/JavaChatStreamTest.java
 delete mode 100644 src/test/java/JavaChatTest.java
 create mode 100644 src/test/java/JavaTest.java
 delete mode 100644 src/test/kotlin/KotlinChatStreamTest.kt
 delete mode 100644 src/test/kotlin/KotlinChatTest.kt
 delete mode 100644 src/test/kotlin/KotlinCompletionStreamTest.kt
 create mode 100644 src/test/kotlin/KotlinTest.kt

diff --git a/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt b/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt
index 6465536..81ac203 100644
--- a/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt
+++ b/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt
@@ -3,6 +3,7 @@ package com.cjcrafter.openai
 import com.cjcrafter.openai.exception.OpenAIError
 import com.cjcrafter.openai.exception.WrappedIOError
 import com.google.gson.JsonObject
+import com.google.gson.JsonParseException
 import com.google.gson.JsonParser
 import okhttp3.Call
 import okhttp3.Callback
@@ -45,20 +46,29 @@ internal class MyCallback(
 
     private fun handleStream(response: Response) {
         response.body?.source()?.use { source ->
-            while (!source.exhausted()) {
-                var jsonResponse = source.readUtf8()
 
-                // OpenAI returns a json string, but they prepend the content with
-                // "data: " (which is not valid json). In order to parse this into
-                // a JsonObject, we have to strip away this extra string.
-                jsonResponse = jsonResponse.substring("data: ".length)
+            while (!source.exhausted()) {
+                var jsonResponse = source.readUtf8Line()
 
-                // After OpenAI's final message (which already contains a non-null
-                // finish reason), they redundantly send "data: [DONE]". Ignore it.
-                if (jsonResponse == "[DONE]")
+                // Our data is separated by empty lines; ignore them. The final
+                // line is always "data: [DONE]", ignore it.
+                if (jsonResponse.isNullOrEmpty() || jsonResponse == "data: [DONE]")
                     continue
 
-                val rootObject = JsonParser.parseString(jsonResponse).asJsonObject
+                // The CHAT API returns a json string, but they prepend the content
+                // with "data: " (which is not valid json). In order to parse this
+                // into a JsonObject, we have to strip away this extra string.
+                if (jsonResponse.startsWith("data: "))
+                    jsonResponse = jsonResponse.substring("data: ".length)
+
+                lateinit var rootObject: JsonObject
+                try {
+                    rootObject = JsonParser.parseString(jsonResponse).asJsonObject
+                } catch (ex: JsonParseException) {
+                    println(jsonResponse)
+                    ex.printStackTrace()
+                    continue
+                }
 
                 // Sometimes OpenAI will respond with an error code for malformed
                 // requests, timeouts, rate limits, etc. 
We need to let the dev diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt index 3f13220..129c869 100644 --- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt +++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt @@ -262,9 +262,14 @@ class OpenAI @JvmOverloads constructor( try { val httpResponse = client.newCall(httpRequest).execute() + var response: ChatResponseChunk? = null MyCallback(true, onFailure) { - val response = gson.fromJson(it, ChatResponseChunk::class.java) - onResponse.accept(response) + if (response == null) + response = gson.fromJson(it, ChatResponseChunk::class.java) + else + response!!.update(it) + + onResponse.accept(response!!) }.onResponse(httpResponse) } catch (ex: IOException) { onFailure.accept(WrappedIOError(ex)) @@ -291,7 +296,7 @@ class OpenAI @JvmOverloads constructor( */ @JvmOverloads fun streamChatCompletionAsync( - request: CompletionRequest, + request: ChatRequest, onResponse: Consumer, onFailure: Consumer = Consumer { it.printStackTrace() } ) { @@ -299,9 +304,14 @@ class OpenAI @JvmOverloads constructor( request.stream = true // use requestResponse for stream=false val httpRequest = buildRequest(request, CHAT_ENDPOINT) + var response: ChatResponseChunk? = null client.newCall(httpRequest).enqueue(MyCallback(true, onFailure) { - val response = gson.fromJson(it, ChatResponseChunk::class.java) - onResponse.accept(response) + if (response == null) + response = gson.fromJson(it, ChatResponseChunk::class.java) + else + response!!.update(it) + + onResponse.accept(response!!) }) } diff --git a/src/test/java/JavaChatStreamTest.java b/src/test/java/JavaChatStreamTest.java deleted file mode 100644 index 2cf6ed7..0000000 --- a/src/test/java/JavaChatStreamTest.java +++ /dev/null @@ -1,38 +0,0 @@ -import com.cjcrafter.openai.OpenAI; -import com.cjcrafter.openai.chat.*; -import io.github.cdimascio.dotenv.Dotenv; - -import java.util.*; - -public class JavaChatStreamTest { - - public static void main(String[] args) { - Scanner scan = new Scanner(System.in); - - // Prepare the ChatRequest - ChatMessage prompt = ChatMessage.toSystemMessage("Be as unhelpful as possible"); - List messages = new ArrayList<>(Collections.singletonList(prompt)); - ChatRequest request = ChatRequest.builder() - .model("gpt-3.5-turbo") - .messages(messages).build(); - - // Load TOKEN from .env file - String key = Dotenv.load().get("OPENAI_TOKEN"); - OpenAI openai = new OpenAI(key); - - // Ask the user for input - System.out.println("Enter text below:\n\n"); - String input = scan.nextLine(); - - // Stream the response. Print out each 'delta' (new tokens) - messages.add(new ChatMessage(ChatUser.USER, input)); - openai.streamChatCompletion(request, message -> { - System.out.print(message.get(0).getDelta()); - - // Once the message is complete, we should save the message to our - // conversation (In case you want to generate more responses). 
- if (message.get(0).getFinishReason() != null) - messages.add(message.get(0).getMessage()); - }); - } -} diff --git a/src/test/java/JavaChatTest.java b/src/test/java/JavaChatTest.java deleted file mode 100644 index efe9466..0000000 --- a/src/test/java/JavaChatTest.java +++ /dev/null @@ -1,57 +0,0 @@ -import com.cjcrafter.openai.OpenAI; -import com.cjcrafter.openai.chat.*; -import com.cjcrafter.openai.exception.OpenAIError; -import io.github.cdimascio.dotenv.Dotenv; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Scanner; - -public class JavaChatTest { - - public static void main(String[] args) throws OpenAIError { - Scanner scan = new Scanner(System.in); - - // This is the prompt that the bot will refer back to for every message. - ChatMessage prompt = ChatMessage.toSystemMessage("You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database."); - - // Use a mutable (modifiable) list! Always! You should be reusing the - // ChatRequest variable, so in order for a conversation to continue - // you need to be able to modify the list. - List messages = new ArrayList<>(Collections.singletonList(prompt)); - - // ChatRequest is the request we send to OpenAI API. You can modify the - // model, temperature, maxTokens, etc. This should be saved, so you can - // reuse it for a conversation. - ChatRequest request = ChatRequest.builder() - .model("gpt-3.5-turbo") - .messages(messages).build(); - - // Loads the API key from the .env file in the root directory. - String key = Dotenv.load().get("OPENAI_TOKEN"); - OpenAI openai = new OpenAI(key); - - // The conversation lasts until the user quits the program - while (true) { - - // Prompt the user to enter a response - System.out.println("Enter text below:\n\n"); - String input = scan.nextLine(); - - // Add the newest user message to the conversation - messages.add(ChatMessage.toUserMessage(input)); - - // Use the OpenAI API to generate a response to the current - // conversation. Print the resulting message. - ChatResponse response = openai.createChatCompletion(request); - System.out.println("\n" + response.get(0).getMessage().getContent()); - - // Save the generated message to the conversational memory. It is - // crucial to save this message, otherwise future requests will be - // confused that there was no response. 
- messages.add(response.get(0).getMessage()); - } - } -} diff --git a/src/test/java/JavaTest.java b/src/test/java/JavaTest.java new file mode 100644 index 0000000..249d349 --- /dev/null +++ b/src/test/java/JavaTest.java @@ -0,0 +1,170 @@ +import com.cjcrafter.openai.OpenAI; +import com.cjcrafter.openai.chat.ChatMessage; +import com.cjcrafter.openai.chat.ChatRequest; +import com.cjcrafter.openai.chat.ChatResponse; +import com.cjcrafter.openai.completions.CompletionRequest; +import com.cjcrafter.openai.exception.OpenAIError; +import io.github.cdimascio.dotenv.Dotenv; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Scanner; + +public class JavaTest { + + // Colors for pretty formatting + public static final String RESET = "\033[0m"; + public static final String BLACK = "\033[0;30m"; + public static final String RED = "\033[0;31m"; + public static final String GREEN = "\033[0;32m"; + public static final String YELLOW = "\033[0;33m"; + public static final String BLUE = "\033[0;34m"; + public static final String PURPLE = "\033[0;35m"; + public static final String CYAN = "\033[0;36m"; + public static final String WHITE = "\033[0;37m"; + + public static void main(String[] args) throws OpenAIError { + Scanner scanner = new Scanner(System.in); + + // Print out the menu of options + System.out.println(GREEN + "Please select one of the options below by typing a number."); + System.out.println(); + System.out.println(GREEN + " 1. Completion (create, sync)"); + System.out.println(GREEN + " 2. Completion (stream, sync)"); + System.out.println(GREEN + " 3. Completion (create, async)"); + System.out.println(GREEN + " 4. Completion (stream, async)"); + System.out.println(GREEN + " 5. Chat (create, sync)"); + System.out.println(GREEN + " 6. Chat (stream, sync)"); + System.out.println(GREEN + " 7. Chat (create, async)"); + System.out.println(GREEN + " 8. Chat (stream, async)"); + System.out.println(); + + // Determine which method to call + switch (scanner.nextLine()) { + case "1": + doCompletion(false, false); + break; + case "2": + doCompletion(true, false); + break; + case "3": + doCompletion(false, true); + break; + case "4": + doCompletion(true, true); + break; + case "5": + doChat(false, false); + break; + case "6": + doChat(true, false); + break; + case "7": + doChat(false, true); + break; + case "8": + doChat(true, true); + break; + default: + System.err.println("Invalid option"); + break; + } + } + + public static void doCompletion(boolean stream, boolean async) throws OpenAIError { + Scanner scan = new Scanner(System.in); + System.out.println(YELLOW + "Enter completion: "); + String input = scan.nextLine(); + + // CompletionRequest contains the data we sent to the OpenAI API. We use + // 128 tokens, so we have a bit of a delay before the response (for testing). + CompletionRequest request = CompletionRequest.builder() + .model("davinci") + .prompt(input) + .maxTokens(128).build(); + + // Loads the API key from the .env file in the root directory. 
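+        // (A sketch of the assumed setup: the .env file in the project root
+        // contains a line like OPENAI_TOKEN=sk-XXXX, which Dotenv reads here.)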
+ String key = Dotenv.load().get("OPENAI_TOKEN"); + OpenAI openai = new OpenAI(key); + + System.out.println(RESET + "Generating Response" + PURPLE); + if (stream) { + if (async) + openai.streamCompletionAsync(request, response -> System.out.print(response.get(0).getText())); + else + openai.streamCompletion(request, response -> System.out.print(response.get(0).getText())); + } else { + if (async) + openai.createCompletionAsync(request, response -> System.out.println(response.get(0).getText())); + else + System.out.println(openai.createCompletion(request).get(0).getText()); + } + + System.out.println(CYAN + " !!! Code has finished executing. Wait for async code to complete." + RESET); + } + + public static void doChat(boolean stream, boolean async) throws OpenAIError { + Scanner scan = new Scanner(System.in); + + // This is the prompt that the bot will refer back to for every message. + ChatMessage prompt = ChatMessage.toSystemMessage("You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database."); + + // Use a mutable (modifiable) list! Always! You should be reusing the + // ChatRequest variable, so in order for a conversation to continue + // you need to be able to modify the list. + List messages = new ArrayList<>(Collections.singletonList(prompt)); + + // ChatRequest is the request we send to OpenAI API. You can modify the + // model, temperature, maxTokens, etc. This should be saved, so you can + // reuse it for a conversation. + ChatRequest request = ChatRequest.builder() + .model("gpt-3.5-turbo") + .messages(messages).build(); + + // Loads the API key from the .env file in the root directory. + String key = Dotenv.load().get("OPENAI_TOKEN"); + OpenAI openai = new OpenAI(key); + + // The conversation lasts until the user quits the program + while (true) { + + // Prompt the user to enter a response + System.out.println(YELLOW + "Enter text below:\n\n"); + String input = scan.nextLine(); + + // Add the newest user message to the conversation + messages.add(ChatMessage.toUserMessage(input)); + + System.out.println(RESET + "Generating Response" + PURPLE); + if (stream) { + if (async) { + openai.streamChatCompletionAsync(request, response -> { + System.out.print(response.get(0).getDelta()); + if (response.get(0).isFinished()) + messages.add(response.get(0).getMessage()); + }); + } else { + openai.streamChatCompletion(request, response -> { + System.out.print(response.get(0).getDelta()); + if (response.get(0).isFinished()) + messages.add(response.get(0).getMessage()); + }); + } + } else { + if (async) { + openai.createChatCompletionAsync(request, response -> { + System.out.println(response.get(0).getMessage().getContent()); + messages.add(response.get(0).getMessage()); + }); + } else { + ChatResponse response = openai.createChatCompletion(request); + System.out.println(response.get(0).getMessage().getContent()); + messages.add(response.get(0).getMessage()); + } + } + + System.out.println(CYAN + " !!! Code has finished executing. 
Wait for async code to complete."); + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/KotlinChatStreamTest.kt b/src/test/kotlin/KotlinChatStreamTest.kt deleted file mode 100644 index ac296e4..0000000 --- a/src/test/kotlin/KotlinChatStreamTest.kt +++ /dev/null @@ -1,41 +0,0 @@ -import com.cjcrafter.openai.OpenAI -import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage -import com.cjcrafter.openai.chat.ChatMessage.Companion.toUserMessage -import com.cjcrafter.openai.chat.ChatRequest -import com.cjcrafter.openai.completions.CompletionResponseChunk -import com.cjcrafter.openai.exception.OpenAIError -import io.github.cdimascio.dotenv.dotenv -import okhttp3.OkHttpClient -import okhttp3.Request -import okhttp3.RequestBody -import okhttp3.RequestBody.Companion.toRequestBody -import java.util.* -import java.util.function.Consumer - -fun main(args: Array) { - val scan = Scanner(System.`in`) - - // Prepare the ChatRequest - val prompt = "Be as unhelpful as possible" - val messages = mutableListOf(prompt.toSystemMessage()) - val request = ChatRequest(model="gpt-3.5-turbo", messages=messages) - - // Loads the API key from the .env file in the root directory. - val key = dotenv()["OPENAI_TOKEN"] - val openai = OpenAI(key) - - // Ask the user for input - println("Enter text below:\n") - val input = scan.nextLine() - - // Generate a response, and print it to the user. - messages.add(input.toUserMessage()) - openai.streamChatCompletion(request, { - print(it[0].delta) - - // Once the message is complete, we should save the message to our - // conversation (In case you want to generate more responses). - if (it[0].finishReason != null) - messages.add(it[0].message) - }) -} diff --git a/src/test/kotlin/KotlinChatTest.kt b/src/test/kotlin/KotlinChatTest.kt deleted file mode 100644 index f120a60..0000000 --- a/src/test/kotlin/KotlinChatTest.kt +++ /dev/null @@ -1,48 +0,0 @@ -import com.cjcrafter.openai.OpenAI -import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage -import com.cjcrafter.openai.chat.ChatMessage.Companion.toUserMessage -import com.cjcrafter.openai.chat.ChatRequest -import io.github.cdimascio.dotenv.dotenv -import java.util.* - -fun main(args: Array) { - val scan = Scanner(System.`in`) - - // This is the prompt that the bot will refer back to for every message. - val prompt = "You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database." - - // Use a mutable (modifiable) list! Always! You should be reusing the - // ChatRequest variable, so in order for a conversation to continue you - // need to be able to modify the list. - val messages = mutableListOf(prompt.toSystemMessage()) - - // ChatRequest is the request we send to OpenAI API. You can modify the - // model, temperature, maxTokens, etc. This should be saved, so you can - // reuse it for a conversation. - val request = ChatRequest(model="gpt-3.5-turbo", messages=messages) - - // Loads the API key from the .env file in the root directory. - val key = dotenv()["OPENAI_TOKEN"] - val openai = OpenAI(key) - - // The conversation lasts until the user quits the program - while (true) { - - // Prompt the user to enter a response - println("Enter text below:\n") - val input = scan.nextLine() - - // Add the newest user message to the conversation - messages.add(input.toUserMessage()) - - // Use the OpenAI API to generate a response to the current - // conversation. Print the resulting message. 
- val response = openai.createChatCompletion(request) - println("\n${response[0].message.content}\n") - - // Save the generated message to the conversational memory. It is - // crucial to save this message, otherwise future requests will be - // confused that there was no response. - messages.add(response[0].message) - } -} \ No newline at end of file diff --git a/src/test/kotlin/KotlinCompletionStreamTest.kt b/src/test/kotlin/KotlinCompletionStreamTest.kt deleted file mode 100644 index 83dba7d..0000000 --- a/src/test/kotlin/KotlinCompletionStreamTest.kt +++ /dev/null @@ -1,19 +0,0 @@ -import com.cjcrafter.openai.OpenAI -import com.cjcrafter.openai.completions.CompletionRequest -import io.github.cdimascio.dotenv.dotenv - -fun main(args: Array) { - - // Prepare the ChatRequest - val request = CompletionRequest(model="davinci", prompt="The wheels on the bus", maxTokens = 128) - - // Loads the API key from the .env file in the root directory. - val key = dotenv()["OPENAI_TOKEN"] - val openai = OpenAI(key) - - // Generate a response, and print it to the user - //println(openai.createCompletion(request)) - val list = openai.streamCompletionKotlin(request) { - print(choices[0].text) - } -} \ No newline at end of file diff --git a/src/test/kotlin/KotlinTest.kt b/src/test/kotlin/KotlinTest.kt new file mode 100644 index 0000000..bc69546 --- /dev/null +++ b/src/test/kotlin/KotlinTest.kt @@ -0,0 +1,150 @@ +import com.cjcrafter.openai.OpenAI +import com.cjcrafter.openai.chat.ChatMessage +import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage +import com.cjcrafter.openai.chat.ChatMessage.Companion.toUserMessage +import com.cjcrafter.openai.chat.ChatRequest +import com.cjcrafter.openai.chat.ChatResponse +import com.cjcrafter.openai.chat.ChatResponseChunk +import com.cjcrafter.openai.completions.CompletionRequest +import com.cjcrafter.openai.completions.CompletionResponse +import com.cjcrafter.openai.completions.CompletionResponseChunk +import com.cjcrafter.openai.exception.OpenAIError +import io.github.cdimascio.dotenv.Dotenv +import io.github.cdimascio.dotenv.dotenv +import java.util.* + +object KotlinTest { + + // Colors for pretty formatting + const val RESET = "\u001b[0m" + const val BLACK = "\u001b[0;30m" + const val RED = "\u001b[0;31m" + const val GREEN = "\u001b[0;32m" + const val YELLOW = "\u001b[0;33m" + const val BLUE = "\u001b[0;34m" + const val PURPLE = "\u001b[0;35m" + const val CYAN = "\u001b[0;36m" + const val WHITE = "\u001b[0;37m" + + @Throws(OpenAIError::class) + @JvmStatic + fun main(args: Array) { + val scanner = Scanner(System.`in`) + + // Print out the menu of options + println("""$GREEN + Please select one of the options below by typing a number. + + 1. Completion (create, sync) + 2. Completion (stream, sync) + 3. Completion (create, async) + 4. Completion (stream, async) + 5. Chat (create, sync) + 6. Chat (stream, sync) + 7. Chat (create, async) + 8. 
Chat (stream, async) + """.trimIndent()) + + when (scanner.nextLine().trim()) { + "1" -> doCompletion(stream = false, async = false) + "2" -> doCompletion(stream = true, async = false) + "3" -> doCompletion(stream = false, async = true) + "4" -> doCompletion(stream = true, async = true) + "5" -> doChat(stream = false, async = false) + "6" -> doChat(stream = true, async = false) + "7" -> doChat(stream = false, async = true) + "8" -> doChat(stream = true, async = true) + else -> System.err.println("Invalid option") + } + } + + @Throws(OpenAIError::class) + fun doCompletion(stream: Boolean, async: Boolean) { + val scan = Scanner(System.`in`) + println(YELLOW + "Enter completion: ") + val input = scan.nextLine() + + // CompletionRequest contains the data we sent to the OpenAI API. We use + // 128 tokens, so we have a bit of a delay before the response (for testing). + val request = CompletionRequest.builder() + .model("davinci") + .prompt(input) + .maxTokens(128).build() + + // Loads the API key from the .env file in the root directory. + val key = Dotenv.load()["OPENAI_TOKEN"] + val openai = OpenAI(key) + println(RESET + "Generating Response" + PURPLE) + if (stream) { + if (async) + openai.streamCompletionAsync(request, { print(it[0].text) }) + else + openai.streamCompletion(request, { print(it[0].text) }) + } else { + if (async) + openai.createCompletionAsync(request, { println(it[0].text) }) + else + println(openai.createCompletion(request)[0].text) + } + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$RESET") + } + + @Throws(OpenAIError::class) + fun doChat(stream: Boolean, async: Boolean) { + val scan = Scanner(System.`in`) + + // This is the prompt that the bot will refer back to for every message. + val prompt = "You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database.".toSystemMessage() + + // Use a mutable (modifiable) list! Always! You should be reusing the + // ChatRequest variable, so in order for a conversation to continue + // you need to be able to modify the list. + val messages: MutableList = ArrayList(listOf(prompt)) + + // ChatRequest is the request we send to OpenAI API. You can modify the + // model, temperature, maxTokens, etc. This should be saved, so you can + // reuse it for a conversation. + val request = ChatRequest(model="gpt-3.5-turbo", messages=messages) + + // Loads the API key from the .env file in the root directory. 
+ val key = dotenv()["OPENAI_TOKEN"] + val openai = OpenAI(key) + + // The conversation lasts until the user quits the program + while (true) { + + // Prompt the user to enter a response + println("\n${YELLOW}Enter text below:\n") + val input = scan.nextLine() + + // Add the newest user message to the conversation + messages.add(input.toUserMessage()) + println(RESET + "Generating Response" + PURPLE) + if (stream) { + if (async) { + openai.streamChatCompletionAsync(request, { response: ChatResponseChunk -> + print(response[0].delta) + if (response[0].isFinished()) messages.add(response[0].message) + }) + } else { + openai.streamChatCompletion(request, { response: ChatResponseChunk -> + print(response[0].delta) + if (response[0].isFinished()) messages.add(response[0].message) + }) + } + } else { + if (async) { + openai.createChatCompletionAsync(request, { response: ChatResponse -> + println(response[0].message.content) + messages.add(response[0].message) + }) + } else { + val response = openai.createChatCompletion(request) + println(response[0].message.content) + messages.add(response[0].message) + } + } + println("$CYAN !!! Code has finished executing. Wait for async code to complete.") + } + } +} \ No newline at end of file From d5d4e30dc2823badae71e3b9aee32abdc40e5dc6 Mon Sep 17 00:00:00 2001 From: Collin Date: Mon, 3 Apr 2023 11:43:33 -0400 Subject: [PATCH 09/12] fix kotlin test typos --- src/test/kotlin/KotlinTest.kt | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/test/kotlin/KotlinTest.kt b/src/test/kotlin/KotlinTest.kt index bc69546..ea5e917 100644 --- a/src/test/kotlin/KotlinTest.kt +++ b/src/test/kotlin/KotlinTest.kt @@ -6,8 +6,6 @@ import com.cjcrafter.openai.chat.ChatRequest import com.cjcrafter.openai.chat.ChatResponse import com.cjcrafter.openai.chat.ChatResponseChunk import com.cjcrafter.openai.completions.CompletionRequest -import com.cjcrafter.openai.completions.CompletionResponse -import com.cjcrafter.openai.completions.CompletionResponseChunk import com.cjcrafter.openai.exception.OpenAIError import io.github.cdimascio.dotenv.Dotenv import io.github.cdimascio.dotenv.dotenv @@ -32,9 +30,8 @@ object KotlinTest { val scanner = Scanner(System.`in`) // Print out the menu of options - println("""$GREEN - Please select one of the options below by typing a number. - + println(""" + ${GREEN}Please select one of the options below by typing a number. 1. Completion (create, sync) 2. Completion (stream, sync) 3. Completion (create, async) @@ -76,17 +73,20 @@ object KotlinTest { val openai = OpenAI(key) println(RESET + "Generating Response" + PURPLE) if (stream) { - if (async) + if (async) { openai.streamCompletionAsync(request, { print(it[0].text) }) - else + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") + } else { openai.streamCompletion(request, { print(it[0].text) }) + } } else { - if (async) + if (async) { openai.createCompletionAsync(request, { println(it[0].text) }) - else + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") + } else { println(openai.createCompletion(request)[0].text) + } } - println("$CYAN !!! Code has finished executing. Wait for async code to complete.$RESET") } @Throws(OpenAIError::class) @@ -126,6 +126,7 @@ object KotlinTest { print(response[0].delta) if (response[0].isFinished()) messages.add(response[0].message) }) + println("$CYAN !!! Code has finished executing. 
Wait for async code to complete.$PURPLE")
             } else {
                 openai.streamChatCompletion(request, { response: ChatResponseChunk ->
                     print(response[0].delta)
@@ -138,13 +139,13 @@
                     println(response[0].message.content)
                     messages.add(response[0].message)
                 })
+                println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE")
             } else {
                 val response = openai.createChatCompletion(request)
                 println(response[0].message.content)
                 messages.add(response[0].message)
             }
         }
-        println("$CYAN !!! Code has finished executing. Wait for async code to complete.")
     }
 }
\ No newline at end of file

From d9ddc738e4ec3644c1ed858d65a6e6192a13ec17 Mon Sep 17 00:00:00 2001
From: Collin 
Date: Mon, 3 Apr 2023 13:47:29 -0400
Subject: [PATCH 10/12] add completion kdocs

---
 .../kotlin/com/cjcrafter/openai/OpenAI.kt     | 111 +++++++++++++++---
 .../kotlin/com/cjcrafter/openai/GsonTests.kt  |   6 +
 2 files changed, 98 insertions(+), 19 deletions(-)

diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
index 129c869..18c3459 100644
--- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
+++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
@@ -35,6 +35,22 @@ import java.util.function.Consumer
 * 4. Obtain your API key: After subscribing to a plan, you will be redirected
 * to the API dashboard, where you can find your unique API key. Copy and store it securely.
 *
+ * All API methods in this class have a non-blocking option which enqueues
+ * the HTTPS request on a different thread. These method names have `Async`
+ * appended to the end of their names.
+ *
+ * Completions API:
+ * * [createCompletion]
+ * * [streamCompletion]
+ * * [createCompletionAsync]
+ * * [streamCompletionAsync]
+ *
+ * Chat API:
+ * * [createChatCompletion]
+ * * [streamChatCompletion]
+ * * [createChatCompletionAsync]
+ * * [streamChatCompletionAsync]
+ *
 * @property apiKey Your OpenAI API key. It starts with `"sk-"` (without the quotes).
 * @property organization If you belong to multiple organizations, specify which one to use (else `null`).
 * @property client Controls proxies, timeouts, etc.
@@ -60,9 +76,20 @@ class OpenAI @JvmOverloads constructor(
     }
 
     /**
+     * Predicts which text comes after the prompt, thus "completing" the text.
+     *
+     * Calls OpenAI's Completions API and waits until the entire completion is
+     * generated. 
When [CompletionRequest.maxTokens] is a big number, it will + * take a long time to generate all the tokens, so it is recommended to use + * [streamCompletionAsync] instead to allow users to see partial completions. * - * @param request - * @param onResponse - * @param onFailure + * This method will not block the current thread. The code block [onResponse] + * will be run later on a different thread. Due to the different thread, it + * is important to consider thread safety in the context of your program. To + * avoid thread safety issues, use [streamCompletion] to block the main thread. + * + * @param request The data to send to the API endpoint. + * @param onResponse The code to execute for every chunk of text. + * @param onFailure The code to execute when a failure occurs. * @since 1.3.0 */ @JvmOverloads @@ -109,6 +146,8 @@ class OpenAI @JvmOverloads constructor( } /** + * Predicts which text comes after the prompt, thus "completing" the text. + * * Calls OpenAI's Completions API using a *stream* of data. Streams allow * developers to access tokens in real time as they are generated. This is * used to create the "scrolling text" or "living typing" effect. Using @@ -149,11 +188,13 @@ class OpenAI @JvmOverloads constructor( } /** + * Predicts which text comes after the prompt, thus "completing" the text. + * * Calls OpenAI's Completions API using a *stream* of data. Streams allow * developers to access tokens in real time as they are generated. This is * used to create the "scrolling text" or "living typing" effect. Using - * `streamCompletion` gives users information immediately, as opposed to - * `createCompletion` where you have to wait for the entire message to + * `streamCompletionAsync` gives users information immediately, as opposed to + * `createCompletionAsync` where you have to wait for the entire message to * generate. * * This method will not block the current thread. The code block [onResponse] @@ -183,9 +224,22 @@ class OpenAI @JvmOverloads constructor( } /** + * Responds to the input in a conversational manner. Chat can "remember" + * older parts of the conversation by looking at the different messages in + * the list. * - * @param request The input information for the Completions API. - * @return The value returned by the Completions API. + * Calls OpenAI's Completions API and waits until the entire message is + * generated. Since generating an entire CHAT message can be time-consuming, + * it is preferred to use [streamChatCompletionAsync] instead. + * + * This method blocks the current thread until the stream is complete. For + * non-blocking options, use [createChatCompletionAsync]. It is important to + * consider which thread you are currently running on. Running this method + * on [javax.swing]'s thread, for example, will cause your UI to freeze + * temporarily. + * + * @param request The data to send to the API endpoint. + * @return The generated response. * @since 1.3.0 */ @Throws(OpenAIError::class) @@ -195,7 +249,7 @@ class OpenAI @JvmOverloads constructor( val httpRequest = buildRequest(request, CHAT_ENDPOINT) try { - val httpResponse = client.newCall(httpRequest).execute(); + val httpResponse = client.newCall(httpRequest).execute() lateinit var response: ChatResponse MyCallback(true, { throw it }) { response = gson.fromJson(it, ChatResponse::class.java) @@ -208,11 +262,22 @@ class OpenAI @JvmOverloads constructor( } /** - * Create completion async + * Responds to the input in a conversational manner. 
Chat can "remember" + * older parts of the conversation by looking at the different messages in + * the list. * - * @param request - * @param onResponse - * @param onFailure + * Calls OpenAI's Completions API and waits until the entire message is + * generated. Since generating an entire CHAT message can be time-consuming, + * it is preferred to use [streamChatCompletionAsync] instead. + * + * This method will not block the current thread. The code block [onResponse] + * will be run later on a different thread. Due to the different thread, it + * is important to consider thread safety in the context of your program. To + * avoid thread safety issues, use [streamChatCompletion] to block the main thread. + * + * @param request The data to send to the API endpoint. + * @param onResponse The code to execute for every chunk of text. + * @param onFailure The code to execute when a failure occurs. * @since 1.3.0 */ @JvmOverloads @@ -232,6 +297,10 @@ class OpenAI @JvmOverloads constructor( } /** + * Responds to the input in a conversational manner. Chat can "remember" + * older parts of the conversation by looking at the different messages in + * the list. + * * Calls OpenAI's Completions API using a *stream* of data. Streams allow * developers to access tokens in real time as they are generated. This is * used to create the "scrolling text" or "living typing" effect. Using @@ -277,17 +346,21 @@ class OpenAI @JvmOverloads constructor( } /** + * Responds to the input in a conversational manner. Chat can "remember" + * older parts of the conversation by looking at the different messages in + * the list. + * * Calls OpenAI's Completions API using a *stream* of data. Streams allow * developers to access tokens in real time as they are generated. This is - * used to create the "scrolling text" or "living typing" effect. Using - * `streamCompletion` gives users information immediately, as opposed to - * `createCompletion` where you have to wait for the entire message to + * used to create the "scrolling text" or "live typing" effect. Using + * `streamChatCompletionAsync` gives users information immediately, as opposed to + * [createChatCompletionAsync] where you have to wait for the entire message to * generate. * * This method will not block the current thread. The code block [onResponse] * will be run later on a different thread. Due to the different thread, it * is important to consider thread safety in the context of your program. To - * avoid thread safety issues, use [streamCompletion] to block the main thread. + * avoid thread safety issues, use [streamChatCompletion] to block the main thread. * * @param request The data to send to the API endpoint. * @param onResponse The code to execute for every chunk of text. 
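
(A concrete sketch of the blocking vs. non-blocking contract described in the
kdocs above. It uses only methods defined in OpenAI.kt; `key` is assumed to be
loaded from `.env` as in the test files.)

```
import com.cjcrafter.openai.OpenAI
import com.cjcrafter.openai.chat.ChatMessage.Companion.toUserMessage
import com.cjcrafter.openai.chat.ChatRequest

fun demo(key: String) {
    val openai = OpenAI(key)
    val request = ChatRequest(
        model = "gpt-3.5-turbo",
        messages = mutableListOf("Say hello".toUserMessage())
    )

    // Blocking: the full response is available as soon as this call returns.
    val response = openai.createChatCompletion(request)
    println(response[0].message.content)

    // Non-blocking: the lambda runs later, on a different thread.
    openai.createChatCompletionAsync(request, { println(it[0].message.content) })
}
```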
diff --git a/src/test/kotlin/com/cjcrafter/openai/GsonTests.kt b/src/test/kotlin/com/cjcrafter/openai/GsonTests.kt index 745d085..77e4001 100644 --- a/src/test/kotlin/com/cjcrafter/openai/GsonTests.kt +++ b/src/test/kotlin/com/cjcrafter/openai/GsonTests.kt @@ -3,6 +3,7 @@ package com.cjcrafter.openai import com.cjcrafter.openai.chat.* import com.cjcrafter.openai.chat.ChatMessage.Companion.toAssistantMessage import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage +import com.cjcrafter.openai.completions.CompletionRequest import com.google.gson.Gson import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions @@ -51,6 +52,11 @@ class GsonTests { "{\"id\":\"chatcmpl-123\",\"created\":1677652288,\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Hello there, how may I assist you today?\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":12,\"total_tokens\":21}}", ChatResponse("chatcmpl-123", 1677652288L, mutableListOf(ChatChoice(0, "Hello there, how may I assist you today?".toAssistantMessage(), FinishReason.STOP)), ChatUsage(9, 12, 21)), ChatResponse::class.java + ), + Arguments.of( + "{\"model\":\"davinci\",\"prompt\":[\"Hello\",\"Goodbye\"]}", + CompletionRequest(model="davinci", prompt=listOf("Hello", "Goodbye")), + CompletionRequest::class.java ) ) } From 94723785b3fbd2f5d1a4e7ae6bbfd5a854fde97c Mon Sep 17 00:00:00 2001 From: Collin Date: Mon, 3 Apr 2023 13:48:09 -0400 Subject: [PATCH 11/12] rename MyCallback to OpenAICallback --- .../kotlin/com/cjcrafter/openai/OpenAI.kt | 20 ++++++++----------- .../{MyCallback.kt => OpenAICallback.kt} | 2 +- 2 files changed, 9 insertions(+), 13 deletions(-) rename src/main/kotlin/com/cjcrafter/openai/{MyCallback.kt => OpenAICallback.kt} (98%) diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt index 18c3459..8433b87 100644 --- a/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt +++ b/src/main/kotlin/com/cjcrafter/openai/OpenAI.kt @@ -11,14 +11,10 @@ import com.cjcrafter.openai.gson.ChatUserAdapter import com.cjcrafter.openai.gson.FinishReasonAdapter import com.google.gson.Gson import com.google.gson.GsonBuilder -import com.google.gson.JsonObject -import com.google.gson.JsonParser import okhttp3.* import okhttp3.MediaType.Companion.toMediaType import okhttp3.RequestBody.Companion.toRequestBody import java.io.IOException -import java.lang.IllegalStateException -import java.util.ArrayList import java.util.function.Consumer /** @@ -101,7 +97,7 @@ class OpenAI @JvmOverloads constructor( try { val httpResponse = client.newCall(httpRequest).execute(); lateinit var response: CompletionResponse - MyCallback(true, { throw it }) { + OpenAICallback(true, { throw it }) { response = gson.fromJson(it, CompletionResponse::class.java) }.onResponse(httpResponse) @@ -139,7 +135,7 @@ class OpenAI @JvmOverloads constructor( request.stream = false // use streamCompletionAsync for stream=true val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) - client.newCall(httpRequest).enqueue(MyCallback(false, onFailure) { + client.newCall(httpRequest).enqueue(OpenAICallback(false, onFailure) { val response = gson.fromJson(it, CompletionResponse::class.java) onResponse.accept(response) }) @@ -178,7 +174,7 @@ class OpenAI @JvmOverloads constructor( try { val httpResponse = client.newCall(httpRequest).execute() - MyCallback(true, onFailure) { + OpenAICallback(true, onFailure) { val response = gson.fromJson(it, 
CompletionResponseChunk::class.java) onResponse.accept(response) }.onResponse(httpResponse) @@ -217,7 +213,7 @@ class OpenAI @JvmOverloads constructor( request.stream = true // use createCompletionAsync for stream=false val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) - client.newCall(httpRequest).enqueue(MyCallback(true, onFailure) { + client.newCall(httpRequest).enqueue(OpenAICallback(true, onFailure) { val response = gson.fromJson(it, CompletionResponseChunk::class.java) onResponse.accept(response) }) @@ -251,7 +247,7 @@ class OpenAI @JvmOverloads constructor( try { val httpResponse = client.newCall(httpRequest).execute() lateinit var response: ChatResponse - MyCallback(true, { throw it }) { + OpenAICallback(true, { throw it }) { response = gson.fromJson(it, ChatResponse::class.java) }.onResponse(httpResponse) @@ -290,7 +286,7 @@ class OpenAI @JvmOverloads constructor( request.stream = false // use streamChatCompletionAsync for stream=true val httpRequest = buildRequest(request, CHAT_ENDPOINT) - client.newCall(httpRequest).enqueue(MyCallback(false, onFailure) { + client.newCall(httpRequest).enqueue(OpenAICallback(false, onFailure) { val response = gson.fromJson(it, ChatResponse::class.java) onResponse.accept(response) }) @@ -332,7 +328,7 @@ class OpenAI @JvmOverloads constructor( try { val httpResponse = client.newCall(httpRequest).execute() var response: ChatResponseChunk? = null - MyCallback(true, onFailure) { + OpenAICallback(true, onFailure) { if (response == null) response = gson.fromJson(it, ChatResponseChunk::class.java) else @@ -378,7 +374,7 @@ class OpenAI @JvmOverloads constructor( val httpRequest = buildRequest(request, CHAT_ENDPOINT) var response: ChatResponseChunk? = null - client.newCall(httpRequest).enqueue(MyCallback(true, onFailure) { + client.newCall(httpRequest).enqueue(OpenAICallback(true, onFailure) { if (response == null) response = gson.fromJson(it, ChatResponseChunk::class.java) else diff --git a/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAICallback.kt similarity index 98% rename from src/main/kotlin/com/cjcrafter/openai/MyCallback.kt rename to src/main/kotlin/com/cjcrafter/openai/OpenAICallback.kt index 81ac203..4f9f917 100644 --- a/src/main/kotlin/com/cjcrafter/openai/MyCallback.kt +++ b/src/main/kotlin/com/cjcrafter/openai/OpenAICallback.kt @@ -11,7 +11,7 @@ import okhttp3.Response import java.io.IOException import java.util.function.Consumer -internal class MyCallback( +internal class OpenAICallback( private val isStream: Boolean, private val onFailure: Consumer, private val onResponse: Consumer From 953eedf3b772428e089ebdf827936228bc1c02bd Mon Sep 17 00:00:00 2001 From: Collin Date: Mon, 3 Apr 2023 14:04:06 -0400 Subject: [PATCH 12/12] fix KotlinTest formatting --- src/test/kotlin/KotlinTest.kt | 218 +++++++++++++++++----------------- 1 file changed, 107 insertions(+), 111 deletions(-) diff --git a/src/test/kotlin/KotlinTest.kt b/src/test/kotlin/KotlinTest.kt index ea5e917..7186e1a 100644 --- a/src/test/kotlin/KotlinTest.kt +++ b/src/test/kotlin/KotlinTest.kt @@ -7,30 +7,26 @@ import com.cjcrafter.openai.chat.ChatResponse import com.cjcrafter.openai.chat.ChatResponseChunk import com.cjcrafter.openai.completions.CompletionRequest import com.cjcrafter.openai.exception.OpenAIError -import io.github.cdimascio.dotenv.Dotenv import io.github.cdimascio.dotenv.dotenv import java.util.* -object KotlinTest { - - // Colors for pretty formatting - const val RESET = "\u001b[0m" - const val BLACK = 
"\u001b[0;30m" - const val RED = "\u001b[0;31m" - const val GREEN = "\u001b[0;32m" - const val YELLOW = "\u001b[0;33m" - const val BLUE = "\u001b[0;34m" - const val PURPLE = "\u001b[0;35m" - const val CYAN = "\u001b[0;36m" - const val WHITE = "\u001b[0;37m" - - @Throws(OpenAIError::class) - @JvmStatic - fun main(args: Array) { - val scanner = Scanner(System.`in`) - - // Print out the menu of options - println(""" +// Colors for pretty formatting +const val RESET = "\u001b[0m" +const val BLACK = "\u001b[0;30m" +const val RED = "\u001b[0;31m" +const val GREEN = "\u001b[0;32m" +const val YELLOW = "\u001b[0;33m" +const val BLUE = "\u001b[0;34m" +const val PURPLE = "\u001b[0;35m" +const val CYAN = "\u001b[0;36m" +const val WHITE = "\u001b[0;37m" + +@Throws(OpenAIError::class) +fun main(args: Array) { + val scanner = Scanner(System.`in`) + + // Print out the menu of options + println(""" ${GREEN}Please select one of the options below by typing a number. 1. Completion (create, sync) 2. Completion (stream, sync) @@ -40,111 +36,111 @@ object KotlinTest { 6. Chat (stream, sync) 7. Chat (create, async) 8. Chat (stream, async) - """.trimIndent()) - - when (scanner.nextLine().trim()) { - "1" -> doCompletion(stream = false, async = false) - "2" -> doCompletion(stream = true, async = false) - "3" -> doCompletion(stream = false, async = true) - "4" -> doCompletion(stream = true, async = true) - "5" -> doChat(stream = false, async = false) - "6" -> doChat(stream = true, async = false) - "7" -> doChat(stream = false, async = true) - "8" -> doChat(stream = true, async = true) - else -> System.err.println("Invalid option") - } + """.trimIndent() + ) + + when (scanner.nextLine().trim()) { + "1" -> doCompletion(stream = false, async = false) + "2" -> doCompletion(stream = true, async = false) + "3" -> doCompletion(stream = false, async = true) + "4" -> doCompletion(stream = true, async = true) + "5" -> doChat(stream = false, async = false) + "6" -> doChat(stream = true, async = false) + "7" -> doChat(stream = false, async = true) + "8" -> doChat(stream = true, async = true) + else -> System.err.println("Invalid option") } - - @Throws(OpenAIError::class) - fun doCompletion(stream: Boolean, async: Boolean) { - val scan = Scanner(System.`in`) - println(YELLOW + "Enter completion: ") - val input = scan.nextLine() - - // CompletionRequest contains the data we sent to the OpenAI API. We use - // 128 tokens, so we have a bit of a delay before the response (for testing). - val request = CompletionRequest.builder() - .model("davinci") - .prompt(input) - .maxTokens(128).build() - - // Loads the API key from the .env file in the root directory. - val key = Dotenv.load()["OPENAI_TOKEN"] - val openai = OpenAI(key) - println(RESET + "Generating Response" + PURPLE) - if (stream) { - if (async) { - openai.streamCompletionAsync(request, { print(it[0].text) }) - println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") - } else { - openai.streamCompletion(request, { print(it[0].text) }) - } +} + +@Throws(OpenAIError::class) +fun doCompletion(stream: Boolean, async: Boolean) { + val scan = Scanner(System.`in`) + println(YELLOW + "Enter completion: ") + val input = scan.nextLine() + + // CompletionRequest contains the data we sent to the OpenAI API. We use + // 128 tokens, so we have a bit of a delay before the response (for testing). 
+ val request = CompletionRequest.builder() + .model("davinci") + .prompt(input) + .maxTokens(128).build() + + // Loads the API key from the .env file in the root directory. + val key = dotenv()["OPENAI_TOKEN"] + val openai = OpenAI(key) + println(RESET + "Generating Response" + PURPLE) + if (stream) { + if (async) { + openai.streamCompletionAsync(request, { print(it[0].text) }) + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") } else { - if (async) { - openai.createCompletionAsync(request, { println(it[0].text) }) - println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") - } else { - println(openai.createCompletion(request)[0].text) - } + openai.streamCompletion(request, { print(it[0].text) }) + } + } else { + if (async) { + openai.createCompletionAsync(request, { println(it[0].text) }) + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") + } else { + println(openai.createCompletion(request)[0].text) } } +} - @Throws(OpenAIError::class) - fun doChat(stream: Boolean, async: Boolean) { - val scan = Scanner(System.`in`) +@Throws(OpenAIError::class) +fun doChat(stream: Boolean, async: Boolean) { + val scan = Scanner(System.`in`) - // This is the prompt that the bot will refer back to for every message. - val prompt = "You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database.".toSystemMessage() + // This is the prompt that the bot will refer back to for every message. + val prompt = "You are a customer support chat-bot. Write brief summaries of the user's questions so that agents can easily find the answer in a database.".toSystemMessage() - // Use a mutable (modifiable) list! Always! You should be reusing the - // ChatRequest variable, so in order for a conversation to continue - // you need to be able to modify the list. - val messages: MutableList = ArrayList(listOf(prompt)) + // Use a mutable (modifiable) list! Always! You should be reusing the + // ChatRequest variable, so in order for a conversation to continue + // you need to be able to modify the list. + val messages: MutableList = ArrayList(listOf(prompt)) - // ChatRequest is the request we send to OpenAI API. You can modify the - // model, temperature, maxTokens, etc. This should be saved, so you can - // reuse it for a conversation. - val request = ChatRequest(model="gpt-3.5-turbo", messages=messages) + // ChatRequest is the request we send to OpenAI API. You can modify the + // model, temperature, maxTokens, etc. This should be saved, so you can + // reuse it for a conversation. + val request = ChatRequest(model = "gpt-3.5-turbo", messages = messages) - // Loads the API key from the .env file in the root directory. - val key = dotenv()["OPENAI_TOKEN"] - val openai = OpenAI(key) + // Loads the API key from the .env file in the root directory. 
+ val key = dotenv()["OPENAI_TOKEN"] + val openai = OpenAI(key) - // The conversation lasts until the user quits the program - while (true) { + // The conversation lasts until the user quits the program + while (true) { - // Prompt the user to enter a response - println("\n${YELLOW}Enter text below:\n") - val input = scan.nextLine() + // Prompt the user to enter a response + println("\n${YELLOW}Enter text below:\n") + val input = scan.nextLine() - // Add the newest user message to the conversation - messages.add(input.toUserMessage()) - println(RESET + "Generating Response" + PURPLE) - if (stream) { - if (async) { - openai.streamChatCompletionAsync(request, { response: ChatResponseChunk -> - print(response[0].delta) - if (response[0].isFinished()) messages.add(response[0].message) - }) - println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") - } else { - openai.streamChatCompletion(request, { response: ChatResponseChunk -> - print(response[0].delta) - if (response[0].isFinished()) messages.add(response[0].message) - }) - } + // Add the newest user message to the conversation + messages.add(input.toUserMessage()) + println(RESET + "Generating Response" + PURPLE) + if (stream) { + if (async) { + openai.streamChatCompletionAsync(request, { response: ChatResponseChunk -> + print(response[0].delta) + if (response[0].isFinished()) messages.add(response[0].message) + }) + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") } else { - if (async) { - openai.createChatCompletionAsync(request, { response: ChatResponse -> - println(response[0].message.content) - messages.add(response[0].message) - }) - println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") - } else { - val response = openai.createChatCompletion(request) + openai.streamChatCompletion(request, { response: ChatResponseChunk -> + print(response[0].delta) + if (response[0].isFinished()) messages.add(response[0].message) + }) + } + } else { + if (async) { + openai.createChatCompletionAsync(request, { response: ChatResponse -> println(response[0].message.content) messages.add(response[0].message) - } + }) + println("$CYAN !!! Code has finished executing. Wait for async code to complete.$PURPLE") + } else { + val response = openai.createChatCompletion(request) + println(response[0].message.content) + messages.add(response[0].message) } } }