Skip to content

Commit

Permalink
initial support for message batches (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil authored Oct 17, 2024
1 parent 1d29673 commit b59239f
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 55 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ kotlin {

commonMain {
dependencies {
implementation(libs.kotlinx.datetime)
implementation(libs.ktor.client.core)
implementation(libs.ktor.client.content.negotiation)
implementation(libs.ktor.client.logging)
Expand Down
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ javaTarget = "17"

kotlin = "2.0.21"
kotlinxCoroutines = "1.9.0"
kotlinxDatetime = "0.6.1"
ktor = "3.0.0"
kotest = "6.0.0.M1"

Expand All @@ -19,6 +20,7 @@ publishPlugin = "2.0.0"
[libraries]
kotlin-test = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlinxCoroutines" }
kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinxDatetime" }

# logging libs
#kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "kotlinLogging" }
Expand Down
20 changes: 9 additions & 11 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.event.Event
import com.xemantic.anthropic.message.Error
import com.xemantic.anthropic.message.ErrorResponse
import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.tool.UsableTool
Expand All @@ -24,7 +21,6 @@ import io.ktor.http.ContentType
import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
Expand Down Expand Up @@ -187,24 +183,26 @@ class Anthropic internal constructor(
toolEntryMap = toolEntryMap
).apply(block).build()

val response = client.post("/v1/messages") {
val apiResponse = client.post("/v1/messages") {
contentType(ContentType.Application.Json)
setBody(request)
}
if (response.status.isSuccess()) {
return response.body<MessageResponse>().apply {
val response = apiResponse.body<Response>()
when (response) {
is MessageResponse -> response.apply {
content.filterIsInstance<ToolUse>()
.forEach { toolUse ->
val entry = toolEntryMap[toolUse.name]!!
toolUse.toolEntry = entry
}
}
} else {
throw AnthropicException(
error = response.body<ErrorResponse>().error,
httpStatusCode = response.status
is ErrorResponse -> throw AnthropicException(
error = response.error,
httpStatusCode = apiResponse.status
)
else -> throw RuntimeException("Unsupported response: $response") // should never happen
}
return response
}

fun stream(
Expand Down
73 changes: 73 additions & 0 deletions src/commonMain/kotlin/Responses.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.batch.ProcessingStatus
import com.xemantic.anthropic.batch.RequestCounts
import com.xemantic.anthropic.message.Content
import com.xemantic.anthropic.message.Message
import com.xemantic.anthropic.message.Role
import com.xemantic.anthropic.message.StopReason
import com.xemantic.anthropic.message.Usage
import kotlinx.datetime.LocalDateTime
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonClassDiscriminator

@Serializable
@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
sealed class Response(
val type: String
)

@Serializable
@SerialName("error")
data class ErrorResponse(
val error: Error
) : Response(type = "error")

@Serializable
@SerialName("message")
data class MessageResponse(
val id: String,
val role: Role,
val content: List<Content>, // limited to Text and ToolUse
val model: String,
@SerialName("stop_reason")
val stopReason: StopReason?,
@SerialName("stop_sequence")
val stopSequence: String?,
val usage: Usage
) : Response(type = "message") {

fun asMessage(): Message = Message {
role = Role.ASSISTANT
content += this@MessageResponse.content
}

}

@Serializable
@SerialName("message_batch")
data class MessageBatchResponse(
val id: String,
@SerialName("processing_status")
val processingStatus: ProcessingStatus,
@SerialName("request_counts")
val requestCounts: RequestCounts,
@SerialName("ended_at")
val endedAt: LocalDateTime?,
@SerialName("created_at")
val createdAt: LocalDateTime,
@SerialName("expires_at")
val expiresAt: LocalDateTime,
@SerialName("cancel_initiated_at")
val cancelInitiatedAt: LocalDateTime?,
@SerialName("results_url")
val resultsUrl: String?
) : Response(type = "message_batch") {}

@Serializable
data class Error(
val type: String, val message: String
)
47 changes: 47 additions & 0 deletions src/commonMain/kotlin/batch/Batches.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.xemantic.anthropic.batch

import com.xemantic.anthropic.message.Message
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
data class MessageBatchRequest(
val requests: List<Request>
)

@Serializable
data class Request(
@SerialName("custom_id")
val customId: String,
val params: Params
) {

@Serializable
data class Params(
val model: String,
val maxTokens: Int,
val messages: List<Message>
)

}

@Serializable
data class RequestCounts(
val processing: Int,
val succeeded: Int,
val errored: Int,
val canceled: Int,
val expired: Int
)

/**
* Processing status of the Message Batch.
*/
enum class ProcessingStatus {
@SerialName("in_progress")
IN_PROGRESS,
@SerialName("canceling")
CANCELING,
@SerialName("ended")
ENDED
}
2 changes: 1 addition & 1 deletion src/commonMain/kotlin/event/Events.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.xemantic.anthropic.event

import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.MessageResponse
import com.xemantic.anthropic.message.StopReason
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
Expand Down
43 changes: 4 additions & 39 deletions src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -138,44 +138,6 @@ fun MessageRequest(
return builder.build()
}

@Serializable
data class MessageResponse(
val id: String,
val type: Type,
val role: Role,
val content: List<Content>, // limited to Text and ToolUse
val model: String,
@SerialName("stop_reason")
val stopReason: StopReason?,
@SerialName("stop_sequence")
val stopSequence: String?,
val usage: Usage
) {

enum class Type {
@SerialName("message")
MESSAGE
}

fun asMessage(): Message = Message {
role = Role.ASSISTANT
content += this@MessageResponse.content
}

}


@Serializable
data class ErrorResponse(
val type: String,
val error: Error
)

@Serializable
data class Error(
val type: String, val message: String
)

@Serializable
data class Message(
val role: Role,
Expand Down Expand Up @@ -384,7 +346,10 @@ data class CacheControl(
@Serializable
@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
sealed class ToolChoice {
sealed class ToolChoice(
@SerialName("disable_parallel_tool_use")
val disableParallelToolUse: Boolean = false
) {

@Serializable
@SerialName("auto")
Expand Down
3 changes: 1 addition & 2 deletions src/commonTest/kotlin/AnthropicTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import com.xemantic.anthropic.event.ContentBlockDeltaEvent
import com.xemantic.anthropic.event.Delta.TextDelta
import com.xemantic.anthropic.message.Image
import com.xemantic.anthropic.message.Message
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.Role
import com.xemantic.anthropic.message.StopReason
import com.xemantic.anthropic.message.Text
Expand Down Expand Up @@ -44,7 +43,7 @@ class AnthropicTest {

// then
assertSoftly(response) {
type shouldBe MessageResponse.Type.MESSAGE
type shouldBe "message"
role shouldBe Role.ASSISTANT
model shouldBe "claude-3-5-sonnet-20240620"
stopReason shouldBe StopReason.END_TURN
Expand Down
3 changes: 2 additions & 1 deletion src/commonTest/kotlin/message/MessageResponseTest.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xemantic.anthropic.message

import com.xemantic.anthropic.MessageResponse
import com.xemantic.anthropic.test.testJson
import io.kotest.assertions.assertSoftly
import io.kotest.matchers.shouldBe
Expand Down Expand Up @@ -44,7 +45,7 @@ class MessageResponseTest {
val response = testJson.decodeFromString<MessageResponse>(jsonResponse)
assertSoftly(response) {
id shouldBe "msg_01PspkNzNG3nrf5upeTsmWLF"
type shouldBe MessageResponse.Type.MESSAGE
type shouldBe "message"
role shouldBe Role.ASSISTANT
model shouldBe "claude-3-5-sonnet-20240620"
content.size shouldBe 1
Expand Down
1 change: 0 additions & 1 deletion src/jvmMain/kotlin/JvmAnthropic.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.MessageResponse
import kotlinx.coroutines.runBlocking
import java.util.function.Consumer

Expand Down
2 changes: 2 additions & 0 deletions src/jvmTest/kotlin/StructuredOutputTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldStartWith
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.Serializable
import kotlin.test.Ignore
import kotlin.test.Test

/**
Expand Down Expand Up @@ -50,6 +51,7 @@ data class Asset(
class StructuredOutputTest {

@Test
@Ignore // to be moved to anthropic-sdk-kotlin-demo soon
fun shouldDecodeStructuredOutputFromReportImage() = runTest {
val client = Anthropic {
tool<DisclosureReport>()
Expand Down

0 comments on commit b59239f

Please sign in to comment.