Skip to content

Commit

Permalink
add logging and test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
CJCrafter committed Nov 12, 2023
1 parent acc69b8 commit c62433a
Show file tree
Hide file tree
Showing 12 changed files with 334 additions and 25 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ log/
target/

# ChatGPT-Java-API Specific Ignore
.env
.env
debug.log
4 changes: 4 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ dependencies {
implementation("com.fasterxml.jackson.core:jackson-annotations:2.15.3")
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.15.3")

implementation("org.slf4j:slf4j-api:2.0.9")

implementation("org.jetbrains:annotations:24.0.1")

testImplementation("io.github.cdimascio:dotenv-kotlin:6.4.1")
testImplementation("org.junit.jupiter:junit-jupiter:5.9.2")
testImplementation("com.squareup.okhttp3:okhttp:4.9.2")
testImplementation("com.squareup.okhttp3:mockwebserver:4.9.2")
}

kotlin {
Expand Down
2 changes: 2 additions & 0 deletions examples/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ dependencies {
implementation("com.fasterxml.jackson.core:jackson-databind:2.15.3")
implementation("io.github.cdimascio:dotenv-kotlin:6.4.1")

implementation("ch.qos.logback:logback-classic:1.4.11")

// https://mvnrepository.com/artifact/org.mariuszgromada.math/MathParser.org-mXparser
// Used for tool tests
implementation("org.mariuszgromada.math:MathParser.org-mXparser:5.2.1")
Expand Down
13 changes: 13 additions & 0 deletions examples/src/main/resources/logback.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>debug.log</file>
<append>false</append>
<encoder>
<pattern>%date %level [%thread] %logger{10} %msg%n</pattern>
</encoder>
</appender>

<root level="DEBUG">
<appender-ref ref="FILE"/>
</root>
</configuration>
7 changes: 3 additions & 4 deletions src/main/kotlin/com/cjcrafter/openai/AzureOpenAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,23 @@ import org.jetbrains.annotations.ApiStatus
*
* This class constructs url in the form of: https://<azureBaseUrl>/openai/deployments/<modelName>/<endpoint>?api-version=<apiVersion>
*
* @property azureBaseUrl The base URL for the Azure OpenAI API. Usually https://<your_resource_group>.openai.azure.com
* @property apiVersion The API version to use. Defaults to 2023-03-15-preview.
* @property modelName The model name to use. This is the name of the model deployed to Azure.
*/
class AzureOpenAI @ApiStatus.Internal constructor(
apiKey: String,
organization: String? = null,
client: OkHttpClient = OkHttpClient(),
private val azureBaseUrl: String = "",
baseUrl: String = "https://api.openai.com",
private val apiVersion: String = "2023-03-15-preview",
private val modelName: String = ""
) : OpenAIImpl(apiKey, organization, client) {
) : OpenAIImpl(apiKey, organization, client, baseUrl) {

override fun buildRequest(request: Any, endpoint: String): Request {
val json = objectMapper.writeValueAsString(request)
val body: RequestBody = json.toRequestBody(mediaType)
return Request.Builder()
.url("$azureBaseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
.url("$baseUrl/openai/deployments/$modelName/$endpoint?api-version=$apiVersion")
.addHeader("Content-Type", "application/json")
.addHeader("api-key", apiKey)
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
Expand Down
43 changes: 25 additions & 18 deletions src/main/kotlin/com/cjcrafter/openai/OpenAI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.cjcrafter.openai.completions.CompletionRequest
import com.cjcrafter.openai.completions.CompletionResponse
import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.util.OpenAIDslMarker
import com.fasterxml.jackson.annotation.JsonAutoDetect
import com.fasterxml.jackson.annotation.JsonInclude
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.ObjectMapper
Expand All @@ -14,6 +15,7 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import okhttp3.OkHttpClient
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.annotations.Contract
import org.slf4j.LoggerFactory

interface OpenAI {

Expand Down Expand Up @@ -91,46 +93,49 @@ interface OpenAI {
protected var apiKey: String? = null
protected var organization: String? = null
protected var client: OkHttpClient = OkHttpClient()
protected var baseUrl: String = "https://api.openai.com"

fun apiKey(apiKey: String) = apply { this.apiKey = apiKey }
fun organization(organization: String?) = apply { this.organization = organization }
fun client(client: OkHttpClient) = apply { this.client = client }
fun baseUrl(baseUrl: String) = apply { this.baseUrl = baseUrl }

@Contract(pure = true)
open fun build(): OpenAI {
return OpenAIImpl(
apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
organization,
client
apiKey = apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
organization = organization,
client = client,
baseUrl = baseUrl,
)
}
}

@OpenAIDslMarker
class AzureBuilder internal constructor(): Builder() {
private var azureBaseUrl: String? = null
private var apiVersion: String? = null
private var modelName: String? = null

fun azureBaseUrl(azureBaseUrl: String) = apply { this.azureBaseUrl = azureBaseUrl }
fun apiVersion(apiVersion: String) = apply { this.apiVersion = apiVersion }
fun modelName(modelName: String) = apply { this.modelName = modelName }

@Contract(pure = true)
override fun build(): OpenAI {
return AzureOpenAI(
apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
organization,
client,
azureBaseUrl ?: throw IllegalStateException("azureBaseUrl must be defined for azure"),
apiVersion ?: throw IllegalStateException("apiVersion must be defined for azure"),
modelName ?: throw IllegalStateException("modelName must be defined for azure")
apiKey = apiKey ?: throw IllegalStateException("apiKey must be defined to use OpenAI"),
organization = organization,
client = client,
baseUrl = if (baseUrl == "https://api.openai.com") throw IllegalStateException("baseUrl must be set to an azure endpoint") else baseUrl,
apiVersion = apiVersion ?: throw IllegalStateException("apiVersion must be defined for azure"),
modelName = modelName ?: throw IllegalStateException("modelName must be defined for azure")
)
}
}

companion object {

internal val logger = LoggerFactory.getLogger(OpenAI::class.java)

/**
* Instantiates a builder for a default OpenAI instance. For Azure's
* OpenAI, use [azureBuilder] instead.
Expand All @@ -155,6 +160,14 @@ interface OpenAI {
setSerializationInclusion(JsonInclude.Include.NON_NULL)
configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

// By default, Jackson can serialize fields AND getters. We just want fields.
setVisibility(serializationConfig.getDefaultVisibilityChecker()
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
.withCreatorVisibility(JsonAutoDetect.Visibility.NONE)
)

// Register modules with custom serializers/deserializers
val module = SimpleModule().apply {
addSerializer(ToolChoice::class.java, ToolChoice.serializer())
Expand All @@ -180,10 +193,4 @@ interface OpenAI {
consumer(chunk)
}
}
}

@Contract(pure = true)
fun openAI(init: OpenAI.Builder.() -> Unit) = OpenAI.builder().apply(init).build()

@Contract(pure = true)
fun azureOpenAI(init: OpenAI.AzureBuilder.() -> Unit) = OpenAI.azureBuilder().apply(init).build()
}
15 changes: 15 additions & 0 deletions src/main/kotlin/com/cjcrafter/openai/OpenAIDsl.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.cjcrafter.openai

import org.jetbrains.annotations.Contract

/**
* Builds an [OpenAI] instance using the default implementation.
*/
@Contract(pure = true)
fun openAI(init: OpenAI.Builder.() -> Unit) = OpenAI.builder().apply(init).build()

/**
* Builds an [OpenAI] instance using the Azure implementation.
*/
@Contract(pure = true)
fun azureOpenAI(init: OpenAI.AzureBuilder.() -> Unit) = OpenAI.azureBuilder().apply(init).build()
9 changes: 7 additions & 2 deletions src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import java.io.IOException
open class OpenAIImpl @ApiStatus.Internal constructor(
protected val apiKey: String,
protected val organization: String? = null,
private val client: OkHttpClient = OkHttpClient()
protected val client: OkHttpClient = OkHttpClient(),
protected val baseUrl: String = "https://api.openai.com",
): OpenAI {
protected val mediaType = "application/json; charset=utf-8".toMediaType()
protected val objectMapper = OpenAI.createObjectMapper()
Expand All @@ -25,7 +26,7 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
val json = objectMapper.writeValueAsString(request)
val body: RequestBody = json.toRequestBody(mediaType)
return Request.Builder()
.url("https://api.openai.com/$endpoint")
.url("$baseUrl/$endpoint")
.addHeader("Content-Type", "application/json")
.addHeader("Authorization", "Bearer $apiKey")
.apply { if (organization != null) addHeader("OpenAI-Organization", organization) }
Expand All @@ -43,6 +44,7 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
?: throw IOException("Response body is null")
val responseStr = jsonReader.readText()
OpenAI.logger.debug(responseStr)
return objectMapper.readValue(responseStr, responseType)
}

Expand Down Expand Up @@ -72,6 +74,8 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
var line: String?
do {
line = reader.readLine()
OpenAI.logger.debug(line)

if (line == "data: [DONE]") {
reader.close()
return null
Expand All @@ -86,6 +90,7 @@ open class OpenAIImpl @ApiStatus.Internal constructor(

override fun next(): T {
val line = nextLine ?: throw NoSuchElementException("No more lines")

currentResponse = if (currentResponse == null) {
objectMapper.readValue(line, responseType)
} else {
Expand Down
111 changes: 111 additions & 0 deletions src/test/kotlin/com/cjcrafter/openai/chat/ChatRequestTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package com.cjcrafter.openai.chat

import com.cjcrafter.openai.OpenAI
import com.cjcrafter.openai.chat.ChatMessage.Companion.toSystemMessage
import org.intellij.lang.annotations.Language
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.Arguments
import org.junit.jupiter.params.provider.MethodSource
import java.util.stream.Stream

class ChatRequestTest {

@ParameterizedTest
@MethodSource("provide_serialize")
fun `test deserialize to json`(obj: Any, json: String) {
val objectMapper = OpenAI.createObjectMapper()
val expected = objectMapper.readTree(json)
val actual = objectMapper.readTree(objectMapper.writeValueAsString(obj))
assertEquals(expected, actual)
}

@ParameterizedTest
@MethodSource("provide_serialize")
fun `test serialize from json`(expected: Any, json: String) {
val objectMapper = OpenAI.createObjectMapper()
val actual = objectMapper.readValue(json, expected::class.java)
assertEquals(expected, actual)
}

companion object {
@JvmStatic
fun provide_serialize(): Stream<Arguments> {
return buildList<Arguments> {

@Language("JSON")
var json = """
{
"messages": [
{
"role": "system",
"content": "Be as helpful as possible"
}
],
"model": "gpt-3.5-turbo"
}
""".trimIndent()
add(Arguments.of(
ChatRequest.builder()
.model("gpt-3.5-turbo")
.messages(mutableListOf("Be as helpful as possible".toSystemMessage()))
.build(),
json
))

json = """
{
"messages": [
{
"role": "system",
"content": "Be as helpful as possible"
},
{
"role": "user",
"content": "What is 2 + 2?"
}
],
"model": "gpt-3.5-turbo",
"tools": [
{
"type": "function",
"function": {
"name": "solve_math_problem",
"parameters": {
"type": "object",
"properties": {
"equation": {
"type": "string",
"description": "The math problem for you to solve"
}
},
"required": [
"equation"
]
},
"description": "Returns the result of a math problem as a double"
}
}
]
}
""".trimIndent()
add(Arguments.of(
chatRequest {
model("gpt-3.5-turbo")
messages(mutableListOf(
ChatMessage(ChatUser.SYSTEM, "Be as helpful as possible"),
ChatMessage(ChatUser.USER, "What is 2 + 2?")
))
function {
name("solve_math_problem")
description("Returns the result of a math problem as a double")
addStringParameter("equation", "The math problem for you to solve", true)
}
},
json
))

}.stream()
}
}
}
Loading

0 comments on commit c62433a

Please sign in to comment.