diff --git a/build.gradle.kts b/build.gradle.kts index 498415f..dd3a347 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -70,6 +70,7 @@ kotlin { commonMain { dependencies { + implementation(libs.kotlinx.datetime) implementation(libs.ktor.client.core) implementation(libs.ktor.client.content.negotiation) implementation(libs.ktor.client.logging) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 015165c..fb90453 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -4,6 +4,7 @@ javaTarget = "17" kotlin = "2.0.21" kotlinxCoroutines = "1.9.0" +kotlinxDatetime = "0.6.1" ktor = "3.0.0" kotest = "6.0.0.M1" @@ -19,6 +20,7 @@ publishPlugin = "2.0.0" [libraries] kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" } kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" } +kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinxDatetime" } # logging libs #kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlinLogging" } diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index 031da92..296935d 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -1,10 +1,7 @@ package com.xemantic.anthropic import com.xemantic.anthropic.event.Event -import com.xemantic.anthropic.message.Error -import com.xemantic.anthropic.message.ErrorResponse import com.xemantic.anthropic.message.MessageRequest -import com.xemantic.anthropic.message.MessageResponse import com.xemantic.anthropic.message.Tool import com.xemantic.anthropic.message.ToolUse import com.xemantic.anthropic.tool.UsableTool @@ -24,7 +21,6 @@ import io.ktor.http.ContentType import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.http.contentType -import io.ktor.http.isSuccess import io.ktor.serialization.kotlinx.json.json import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter @@ -187,24 +183,26 @@ class Anthropic internal constructor( toolEntryMap = toolEntryMap ).apply(block).build() - val response = client.post("/v1/messages") { + val apiResponse = client.post("/v1/messages") { contentType(ContentType.Application.Json) setBody(request) } - if (response.status.isSuccess()) { - return response.body().apply { + val response = apiResponse.body() + when (response) { + is MessageResponse -> response.apply { content.filterIsInstance() .forEach { toolUse -> val entry = toolEntryMap[toolUse.name]!! toolUse.toolEntry = entry } } - } else { - throw AnthropicException( - error = response.body().error, - httpStatusCode = response.status + is ErrorResponse -> throw AnthropicException( + error = response.error, + httpStatusCode = apiResponse.status ) + else -> throw RuntimeException("Unsupported response: $response") // should never happen } + return response } fun stream( diff --git a/src/commonMain/kotlin/Responses.kt b/src/commonMain/kotlin/Responses.kt new file mode 100644 index 0000000..ca0b7a5 --- /dev/null +++ b/src/commonMain/kotlin/Responses.kt @@ -0,0 +1,73 @@ +package com.xemantic.anthropic + +import com.xemantic.anthropic.batch.ProcessingStatus +import com.xemantic.anthropic.batch.RequestCounts +import com.xemantic.anthropic.message.Content +import com.xemantic.anthropic.message.Message +import com.xemantic.anthropic.message.Role +import com.xemantic.anthropic.message.StopReason +import com.xemantic.anthropic.message.Usage +import kotlinx.datetime.LocalDateTime +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonClassDiscriminator + +@Serializable +@JsonClassDiscriminator("type") +@OptIn(ExperimentalSerializationApi::class) +sealed class Response( + val type: String +) + +@Serializable +@SerialName("error") +data class ErrorResponse( + val error: Error +) : Response(type = "error") + +@Serializable +@SerialName("message") +data class MessageResponse( + val id: String, + val role: Role, + val content: List, // limited to Text and ToolUse + val model: String, + @SerialName("stop_reason") + val stopReason: StopReason?, + @SerialName("stop_sequence") + val stopSequence: String?, + val usage: Usage +) : Response(type = "message") { + + fun asMessage(): Message = Message { + role = Role.ASSISTANT + content += this@MessageResponse.content + } + +} + +@Serializable +@SerialName("message_batch") +data class MessageBatchResponse( + val id: String, + @SerialName("processing_status") + val processingStatus: ProcessingStatus, + @SerialName("request_counts") + val requestCounts: RequestCounts, + @SerialName("ended_at") + val endedAt: LocalDateTime?, + @SerialName("created_at") + val createdAt: LocalDateTime, + @SerialName("expires_at") + val expiresAt: LocalDateTime, + @SerialName("cancel_initiated_at") + val cancelInitiatedAt: LocalDateTime?, + @SerialName("results_url") + val resultsUrl: String? +) : Response(type = "message_batch") {} + +@Serializable +data class Error( + val type: String, val message: String +) diff --git a/src/commonMain/kotlin/batch/Batches.kt b/src/commonMain/kotlin/batch/Batches.kt new file mode 100644 index 0000000..955f904 --- /dev/null +++ b/src/commonMain/kotlin/batch/Batches.kt @@ -0,0 +1,47 @@ +package com.xemantic.anthropic.batch + +import com.xemantic.anthropic.message.Message +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +@Serializable +data class MessageBatchRequest( + val requests: List +) + +@Serializable +data class Request( + @SerialName("custom_id") + val customId: String, + val params: Params +) { + + @Serializable + data class Params( + val model: String, + val maxTokens: Int, + val messages: List + ) + +} + +@Serializable +data class RequestCounts( + val processing: Int, + val succeeded: Int, + val errored: Int, + val canceled: Int, + val expired: Int +) + +/** + * Processing status of the Message Batch. + */ +enum class ProcessingStatus { + @SerialName("in_progress") + IN_PROGRESS, + @SerialName("canceling") + CANCELING, + @SerialName("ended") + ENDED +} diff --git a/src/commonMain/kotlin/event/Events.kt b/src/commonMain/kotlin/event/Events.kt index 7c5fa32..cef2430 100644 --- a/src/commonMain/kotlin/event/Events.kt +++ b/src/commonMain/kotlin/event/Events.kt @@ -1,6 +1,6 @@ package com.xemantic.anthropic.event -import com.xemantic.anthropic.message.MessageResponse +import com.xemantic.anthropic.MessageResponse import com.xemantic.anthropic.message.StopReason import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialName diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 7777897..ac13635 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -138,44 +138,6 @@ fun MessageRequest( return builder.build() } -@Serializable -data class MessageResponse( - val id: String, - val type: Type, - val role: Role, - val content: List, // limited to Text and ToolUse - val model: String, - @SerialName("stop_reason") - val stopReason: StopReason?, - @SerialName("stop_sequence") - val stopSequence: String?, - val usage: Usage -) { - - enum class Type { - @SerialName("message") - MESSAGE - } - - fun asMessage(): Message = Message { - role = Role.ASSISTANT - content += this@MessageResponse.content - } - -} - - -@Serializable -data class ErrorResponse( - val type: String, - val error: Error -) - -@Serializable -data class Error( - val type: String, val message: String -) - @Serializable data class Message( val role: Role, @@ -384,7 +346,10 @@ data class CacheControl( @Serializable @JsonClassDiscriminator("type") @OptIn(ExperimentalSerializationApi::class) -sealed class ToolChoice { +sealed class ToolChoice( + @SerialName("disable_parallel_tool_use") + val disableParallelToolUse: Boolean = false +) { @Serializable @SerialName("auto") diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index 5dfb438..85c3935 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -4,7 +4,6 @@ import com.xemantic.anthropic.event.ContentBlockDeltaEvent import com.xemantic.anthropic.event.Delta.TextDelta import com.xemantic.anthropic.message.Image import com.xemantic.anthropic.message.Message -import com.xemantic.anthropic.message.MessageResponse import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text @@ -44,7 +43,7 @@ class AnthropicTest { // then assertSoftly(response) { - type shouldBe MessageResponse.Type.MESSAGE + type shouldBe "message" role shouldBe Role.ASSISTANT model shouldBe "claude-3-5-sonnet-20240620" stopReason shouldBe StopReason.END_TURN diff --git a/src/commonTest/kotlin/message/MessageResponseTest.kt b/src/commonTest/kotlin/message/MessageResponseTest.kt index d26db8d..c1566e7 100644 --- a/src/commonTest/kotlin/message/MessageResponseTest.kt +++ b/src/commonTest/kotlin/message/MessageResponseTest.kt @@ -1,5 +1,6 @@ package com.xemantic.anthropic.message +import com.xemantic.anthropic.MessageResponse import com.xemantic.anthropic.test.testJson import io.kotest.assertions.assertSoftly import io.kotest.matchers.shouldBe @@ -44,7 +45,7 @@ class MessageResponseTest { val response = testJson.decodeFromString(jsonResponse) assertSoftly(response) { id shouldBe "msg_01PspkNzNG3nrf5upeTsmWLF" - type shouldBe MessageResponse.Type.MESSAGE + type shouldBe "message" role shouldBe Role.ASSISTANT model shouldBe "claude-3-5-sonnet-20240620" content.size shouldBe 1 diff --git a/src/jvmMain/kotlin/JvmAnthropic.kt b/src/jvmMain/kotlin/JvmAnthropic.kt index 036434c..ba07033 100644 --- a/src/jvmMain/kotlin/JvmAnthropic.kt +++ b/src/jvmMain/kotlin/JvmAnthropic.kt @@ -1,7 +1,6 @@ package com.xemantic.anthropic import com.xemantic.anthropic.message.MessageRequest -import com.xemantic.anthropic.message.MessageResponse import kotlinx.coroutines.runBlocking import java.util.function.Consumer diff --git a/src/jvmTest/kotlin/StructuredOutputTest.kt b/src/jvmTest/kotlin/StructuredOutputTest.kt index e306dc9..ab091b3 100644 --- a/src/jvmTest/kotlin/StructuredOutputTest.kt +++ b/src/jvmTest/kotlin/StructuredOutputTest.kt @@ -9,6 +9,7 @@ import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldStartWith import kotlinx.coroutines.test.runTest import kotlinx.serialization.Serializable +import kotlin.test.Ignore import kotlin.test.Test /** @@ -50,6 +51,7 @@ data class Asset( class StructuredOutputTest { @Test + @Ignore // to be moved to anthropic-sdk-kotlin-demo soon fun shouldDecodeStructuredOutputFromReportImage() = runTest { val client = Anthropic { tool()