From 6c01e9fa3e52d9ad034a64e9609515706a5d3779 Mon Sep 17 00:00:00 2001 From: Kazik Pogoda Date: Wed, 9 Oct 2024 17:15:04 +0200 Subject: [PATCH] better tools support, initial attempt WIP --- build.gradle.kts | 46 ++-- src/commonMain/kotlin/Anthropic.kt | 41 +++- src/commonMain/kotlin/event/Events.kt | 26 +- src/commonMain/kotlin/message/Messages.kt | 225 +++++++++++++++--- .../kotlin/schema/JsonSchemaGenerator.kt | 6 +- src/commonMain/kotlin/tool/Tools.kt | 62 +++++ src/commonTest/kotlin/AnthropicTest.kt | 102 +++----- src/commonTest/kotlin/AnthropicTestTools.kt | 65 +++++ src/commonTest/kotlin/message/MessagesTest.kt | 34 +++ 9 files changed, 473 insertions(+), 134 deletions(-) create mode 100644 src/commonMain/kotlin/tool/Tools.kt create mode 100644 src/commonTest/kotlin/AnthropicTestTools.kt diff --git a/build.gradle.kts b/build.gradle.kts index 2d87906..cb25fda 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -44,10 +44,22 @@ repositories { } kotlin { + //explicitApi() // check with serialization? + jvm { + testRuns["test"].executionTask.configure { + useJUnitPlatform() + } + // set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/ + compilerOptions { + apiVersion = kotlinTarget + languageVersion = kotlinTarget + jvmTarget = JvmTarget.fromTarget(javaTarget) + freeCompilerArgs.add("-Xjdk-release=$javaTarget") + progressiveMode = true + } + } - jvm {} - - linuxX64() +// linuxX64() sourceSets { @@ -70,6 +82,7 @@ kotlin { jvmTest { dependencies { + implementation(kotlin("test-junit5")) runtimeOnly(libs.log4j.slf4j2) runtimeOnly(libs.log4j.core) runtimeOnly(libs.jackson.databind) @@ -78,28 +91,17 @@ kotlin { } } - nativeTest { - dependencies { - // on mac/ios it should be rather Darwin - implementation(libs.ktor.client.curl) - } - } +// nativeTest { +// dependencies { +// // on mac/ios it should be rather Darwin +// implementation(libs.ktor.client.curl) +// } +// } } } -// set up according to https://jakewharton.com/gradle-toolchains-are-rarely-a-good-idea/ -tasks.withType { - compilerOptions { - apiVersion = kotlinTarget - languageVersion = kotlinTarget - jvmTarget = JvmTarget.fromTarget(javaTarget) - freeCompilerArgs.add("-Xjdk-release=$javaTarget") - progressiveMode = true - } -} - fun isNonStable(version: String): Boolean { val stableKeyword = listOf("RELEASE", "FINAL", "GA").any { version.uppercase().contains(it) } val regex = "^[0-9,.v-]+(-r)?$".toRegex() @@ -113,7 +115,7 @@ tasks.withType { } } -tasks.withType() { +tasks.withType { testLogging { events( TestLogEvent.PASSED, @@ -123,6 +125,7 @@ tasks.withType() { showStackTraces = true exceptionFormat = TestExceptionFormat.FULL } + enabled = true } @Suppress("OPT_IN_USAGE") @@ -130,6 +133,7 @@ powerAssert { functions = listOf( "kotlin.assert", "kotlin.test.assertTrue", + "kotlin.test.assertFalse", "kotlin.test.assertEquals", "kotlin.test.assertNull" ) diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index b38f0a4..7c537be 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -1,10 +1,12 @@ package com.xemantic.anthropic import com.xemantic.anthropic.event.Event +import com.xemantic.anthropic.event.Usage import com.xemantic.anthropic.message.Error import com.xemantic.anthropic.message.ErrorResponse import com.xemantic.anthropic.message.MessageRequest import com.xemantic.anthropic.message.MessageResponse +import com.xemantic.anthropic.message.UsableTool import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.plugins.contentnegotiation.ContentNegotiation @@ -26,7 +28,14 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map +import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.json.Json +import kotlinx.serialization.modules.SerializersModule +import kotlinx.serialization.modules.polymorphic +import kotlinx.serialization.serializer +import kotlin.reflect.KClass +import kotlin.reflect.KType +import kotlin.reflect.typeOf const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/" @@ -45,6 +54,11 @@ val anthropicJson: Json = Json { allowSpecialFloatingPointValues = true explicitNulls = false encodeDefaults = true +// serializersModule = SerializersModule { +// //contextual(UsableTool::class, UsableToolSerializer::class) +//// polymorphic(UsableTool::class) { +//// } +// } } fun Anthropic( @@ -60,7 +74,8 @@ fun Anthropic( anthropicBeta = config.anthropicBeta, apiBase = config.apiBase, defaultModel = defaultModel, - directBrowserAccess = config.directBrowserAccess + directBrowserAccess = config.directBrowserAccess, + context = config.context ) } @@ -70,7 +85,9 @@ class Anthropic internal constructor( val anthropicBeta: String?, val apiBase: String, val defaultModel: String, - val directBrowserAccess: Boolean + val directBrowserAccess: Boolean, + val tools: MutableList> = mutableListOf>(), + val context: Context? ) { class Config { @@ -80,8 +97,26 @@ class Anthropic internal constructor( var apiBase: String = ANTHROPIC_API_BASE var defaultModel: String? = null var directBrowserAccess: Boolean = false + var tools: MutableList>? = null + var context: Context? = null } + interface Context { + + fun service(type: KType): T + + } + + companion object { + val EMPTY_CONTEXT: Context = object : Context { + override fun service(type: KType): T { + throw UnsupportedOperationException("No services available") + } + } + } + + inline fun Context.service(): T = service(typeOf()) + private val client = HttpClient { install(ContentNegotiation) { json(anthropicJson) @@ -157,5 +192,3 @@ class Anthropic internal constructor( } -inline fun anthropicTypeOf(): String = - T::class.qualifiedName!!.replace('.', '_') diff --git a/src/commonMain/kotlin/event/Events.kt b/src/commonMain/kotlin/event/Events.kt index b9d3da8..7c5fa32 100644 --- a/src/commonMain/kotlin/event/Events.kt +++ b/src/commonMain/kotlin/event/Events.kt @@ -7,6 +7,8 @@ import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonClassDiscriminator +// reference https://docs.spring.io/spring-ai/reference/_images/anthropic-claude3-events-model.jpg + @Serializable @JsonClassDiscriminator("type") @OptIn(ExperimentalSerializationApi::class) @@ -14,14 +16,14 @@ sealed class Event @Serializable @SerialName("message_start") -data class MessageStart( +data class MessageStartEvent( val message: MessageResponse ) : Event() @Serializable @SerialName("message_delta") -data class MessageDelta( - val delta: MessageDelta.Delta, +data class MessageDeltaEvent( + val delta: Delta, val usage: Usage ) : Event() { @@ -37,13 +39,15 @@ data class MessageDelta( @Serializable @SerialName("message_stop") -class MessageStop : Event() { +class MessageStopEvent : Event() { override fun toString(): String = "MessageStop" } +// TODO error event is missing, should we rename all of these to events? + @Serializable @SerialName("content_block_start") -data class ContentBlockStart( +data class ContentBlockStartEvent( val index: Int, @SerialName("content_block") val contentBlock: ContentBlock @@ -51,7 +55,7 @@ data class ContentBlockStart( @Serializable @SerialName("content_block_stop") -data class ContentBlockStop( +data class ContentBlockStopEvent( val index: Int ) : Event() @@ -66,17 +70,23 @@ sealed class ContentBlock { val text: String ) : ContentBlock() + @Serializable + @SerialName("tool_use") + class ToolUse( + val text: String // TODO tool_id + ) : ContentBlock() + // TODO missing tool_use } @Serializable @SerialName("ping") -class Ping: Event() { +class PingEvent: Event() { override fun toString(): String = "Ping" } @Serializable @SerialName("content_block_delta") -data class ContentBlockDelta( +data class ContentBlockDeltaEvent( val index: Int, val delta: Delta ) : Event() diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 6044cfb..4ff1b7e 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -1,16 +1,24 @@ package com.xemantic.anthropic.message -import com.xemantic.anthropic.anthropicJson -import com.xemantic.anthropic.anthropicTypeOf +import com.xemantic.anthropic.Anthropic import com.xemantic.anthropic.schema.JsonSchema -import com.xemantic.anthropic.schema.jsonSchemaOf import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder import kotlinx.serialization.json.JsonClassDiscriminator -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.JsonDecoder +import kotlinx.serialization.serializerOrNull import kotlin.collections.mutableListOf +import kotlin.reflect.KClass enum class Role { @SerialName("user") @@ -35,7 +43,7 @@ data class MessageRequest( @SerialName("stop_sequences") val stopSequences: List?, val stream: Boolean?, - val system: List?, + val system: List?, val temperature: Double?, @SerialName("tool_choice") val toolChoice: ToolChoice?, @@ -49,12 +57,12 @@ data class MessageRequest( ) { var model: String? = null var maxTokens = 1024 - val messages = mutableListOf() + var messages: List = mutableListOf() var metadata = null val stopSequences = mutableListOf() var stream: Boolean? = null internal set - val systemTexts = mutableListOf() + var system: List? = null var temperature: Double? = null var toolChoice: ToolChoice? = null var tools: List? = null @@ -77,14 +85,11 @@ data class MessageRequest( this.stopSequences += stopSequences.toList() } - var system: String? - get() = if (systemTexts.isEmpty()) null else systemTexts[0].text - set(value) { - systemTexts.clear() - if (value != null) { - systemTexts.add(Text(text = value)) - } - } + fun system( + text: String + ) { + system = listOf(System(text = text)) + } fun build(): MessageRequest = MessageRequest( model = if (model != null) model!! else defaultApiModel, @@ -93,7 +98,7 @@ data class MessageRequest( metadata = metadata, stopSequences = stopSequences.toNullIfEmpty(), stream = if (stream != null) stream else null, - system = systemTexts.toNullIfEmpty(), + system = system, temperature = temperature, toolChoice = toolChoice, tools = tools, @@ -131,8 +136,15 @@ data class MessageResponse( @SerialName("message") MESSAGE } + + fun asMessage() = Message { + role = Role.ASSISTANT + content += this.content + } + } + @Serializable data class ErrorResponse( val type: String, @@ -180,30 +192,37 @@ fun Message(block: Message.Builder.() -> Unit): Message { return builder.build() } +@Serializable +data class System( + @SerialName("cache_control") + val cacheControl: CacheControl? = null, + val type: Type = Type.TEXT, + val text: String? = null, +) { + + enum class Type { + @SerialName("text") + TEXT + } + +} + @Serializable data class Tool( val name: String, val description: String, @SerialName("input_schema") val inputSchema: JsonSchema, + @SerialName("cache_control") val cacheControl: CacheControl? ) -inline fun Tool( - description: String, - cacheControl: CacheControl? = null -): Tool = Tool( - name = anthropicTypeOf(), - description = description, - inputSchema = jsonSchemaOf(), - cacheControl = cacheControl -) - @Serializable @JsonClassDiscriminator("type") @OptIn(ExperimentalSerializationApi::class) sealed class Content { + @SerialName("cache_control") abstract val cacheControl: CacheControl? } @@ -212,6 +231,7 @@ sealed class Content { @SerialName("text") data class Text( val text: String, + @SerialName("cache_control") override val cacheControl: CacheControl? = null, ) : Content() @@ -219,6 +239,7 @@ data class Text( @SerialName("image") data class Image( val source: Source, + @SerialName("cache_control") override val cacheControl: CacheControl? = null ) : Content() { @@ -250,38 +271,78 @@ data class Image( } -@Serializable @SerialName("tool_use") +@Serializable data class ToolUse( + @SerialName("cache_control") override val cacheControl: CacheControl? = null, val id: String, val name: String, - val input: JsonObject + val input: UsableTool ) : Content() { - inline fun input(): T = - anthropicJson.decodeFromJsonElement(input) +// inline fun input(): T = +// anthropicJson.decodeFromJsonElement(input) + + fun use( + context: Anthropic.Context = Anthropic.EMPTY_CONTEXT + ): ToolResult = input.use( + toolUseId = id, + context + ) + +} + +@JsonClassDiscriminator("type") +@OptIn(ExperimentalSerializationApi::class) +//@Serializable(with = UsableToolSerializer::class) +interface UsableTool { + + fun use( + toolUseId: String, + context: Anthropic.Context + ): ToolResult + +} + +interface SimpleUsableTool : UsableTool { + + override fun use( + toolUseId: String, + context: Anthropic.Context + ): ToolResult = use(toolUseId) + + fun use(toolUseId: String): ToolResult } @Serializable @SerialName("tool_result") data class ToolResult( - override val cacheControl: CacheControl? = null, @SerialName("tool_use_id") val toolUseId: String, + val content: List, // TODO only Text, Image allowed here, should be accessible in gthe builder @SerialName("is_error") val isError: Boolean = false, - val content: List + @SerialName("cache_control") + override val cacheControl: CacheControl? = null ) : Content() +fun ToolResult( + toolUseId: String, + text: String +): ToolResult = ToolResult( + toolUseId, + content = listOf(Text(text)) +) + @Serializable data class CacheControl( val type: Type ) { - @SerialName("ephemeral") enum class Type { + @SerialName("ephemeral") EPHEMERAL } @@ -330,3 +391,99 @@ data class Usage( @SerialName("output_tokens") val outputTokens: Int ) + + +interface CacheableBuilder { + + var cacheControl: CacheControl? + + var cache: Boolean + get() = cacheControl != null + set(value) { + if (value) { + cacheControl = CacheControl(type = CacheControl.Type.EPHEMERAL) + } else { + cacheControl = null + } + } + +} + +//class UsableToolSerializer : JsonContentPolymorphicSerializer2(UsableTool::class) { +// +//// override val descriptor: SerialDescriptor = buildClassSerialDescriptor("UsableTool") { +//// element("type") +//// element("data") +//// } +//// +//// override fun serialize( +//// encoder: Encoder, +//// value: UsableTool +//// ) { +////// val polymorphic: SerializationStrategy = serializersModule.getPolymorphic(UsableTool::class, "foo") +//// PolymorphicSerializer(UsableTool::class) +////// encoder.encodeString(value) +////// encoder.encodeString(value.name) +////// polymorphic.seri +////// encoder.encodeSerializableValue(polymorphic) +//// } +//// +//// override fun deserialize(decoder: Decoder): UsableTool { +//// require(decoder is JsonDecoder) { "This serializer can be used only with Json format" } +//// val name = decoder.decodeString() +//// val polymorphic = decoder.serializersModule.getPolymorphic(UsableTool::class, name) +//// val id = decoder.decodeString() +//// return DummyUsableTool() +//// } +// +// override fun selectDeserializer(element: JsonElement): DeserializationStrategy { +// println(element) +// TODO("dupa dupa Not yet implemented") +// } +// +//} + + +//class UsableToolSerializer : JsonContentPolymorphicSerializer2( +// UsableTool::class +//) + +@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +open class JsonContentPolymorphicSerializer2(private val baseClass: KClass) : KSerializer { + /** + * A descriptor for this set of content-based serializers. + * By default, it uses the name composed of [baseClass] simple name, + * kind is set to [PolymorphicKind.SEALED] and contains 0 elements. + * + * However, this descriptor can be overridden to achieve better representation of custom transformed JSON shape + * for schema generating/introspection purposes. + */ + override val descriptor: SerialDescriptor = + buildSerialDescriptor("JsonContentPolymorphicSerializer<${baseClass.simpleName}>", PolymorphicKind.SEALED) + + final override fun serialize(encoder: Encoder, value: T) { + val actualSerializer = + encoder.serializersModule.getPolymorphic(baseClass, value) + ?: value::class.serializerOrNull() + ?: throw SerializationException("fiu fiu") + @Suppress("UNCHECKED_CAST") + (actualSerializer as KSerializer).serialize(encoder, value) + } + + final override fun deserialize(decoder: Decoder): T { + val input = decoder.asJsonDecoder() + input.json.serializersModule.getPolymorphic(UsableTool::class, "foo") + val tree = input.decodeJsonElement() + + @Suppress("UNCHECKED_CAST") + val actualSerializer = String.serializer() as KSerializer + return input.json.decodeFromJsonElement(actualSerializer, tree) + } + +} + +internal fun Decoder.asJsonDecoder(): JsonDecoder = this as? JsonDecoder + ?: throw IllegalStateException( + "This serializer can be used only with Json format." + + "Expected Decoder to be JsonDecoder, got ${this::class}" + ) \ No newline at end of file diff --git a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt index ef71a22..5e1968e 100644 --- a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt +++ b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt @@ -3,9 +3,11 @@ package com.xemantic.anthropic.schema import kotlinx.serialization.* import kotlinx.serialization.descriptors.* import kotlin.collections.set +import kotlin.reflect.KClass -inline fun jsonSchemaOf(): JsonSchema = generateSchema( - serializer().descriptor +@OptIn(InternalSerializationApi::class) +fun KClass<*>.toJsonSchema(): JsonSchema = generateSchema( + serializer().descriptor ) @OptIn(ExperimentalSerializationApi::class) diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt new file mode 100644 index 0000000..5b26a8c --- /dev/null +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -0,0 +1,62 @@ +package com.xemantic.anthropic.tool + +import com.xemantic.anthropic.message.CacheControl +import com.xemantic.anthropic.message.Tool +import com.xemantic.anthropic.message.UsableTool +import com.xemantic.anthropic.schema.toJsonSchema +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.modules.SerializersModule +import kotlinx.serialization.modules.polymorphic +import kotlinx.serialization.serializer +import kotlin.reflect.KClass + +annotation class Description( + val value: String +) + +@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +fun KClass.verify() { + // TODO how to get class serializer correctly? + checkNotNull(serializer()) { + "Invalid tool definition, not serializer for class ${this@verify}" + } + checkNotNull(serializer().descriptor.annotations.filterIsInstance().firstOrNull()) { + "Not @Description annotation specified for the tool" + } +} + +@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +fun KClass.instance( + cacheControl: CacheControl? = null +): Tool { + val descriptor = serializer().descriptor + val description = descriptor.annotations.filterIsInstance().firstOrNull()!!.value + return Tool( + name = descriptor.serialName, + description = description, + inputSchema = toJsonSchema(), + cacheControl = cacheControl + ) +} + +//inline fun anthropicTypeOf(): String = +// T::class.qualifiedName!!.replace('.', '_') + + +@OptIn(InternalSerializationApi::class) +fun List>.toSerializersModule(): SerializersModule = SerializersModule { + polymorphic(UsableTool::class) { + forEach { subclass(it, it.serializer()) } + } +} + +inline fun Tool( + description: String, + cacheControl: CacheControl? = null +): Tool = Tool( + name = anthropicTypeOf(), + description = description, + inputSchema = jsonSchemaOf(), + cacheControl = cacheControl +) diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index 605b231..f6a1ca2 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -1,6 +1,6 @@ package com.xemantic.anthropic -import com.xemantic.anthropic.event.ContentBlockDelta +import com.xemantic.anthropic.event.ContentBlockDeltaEvent import com.xemantic.anthropic.event.Delta.TextDelta import com.xemantic.anthropic.message.Image import com.xemantic.anthropic.message.Message @@ -9,15 +9,14 @@ import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text import com.xemantic.anthropic.message.Tool -import com.xemantic.anthropic.message.ToolChoice import com.xemantic.anthropic.message.ToolUse import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest -import kotlinx.serialization.Serializable import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFalse import kotlin.test.assertNull import kotlin.test.assertTrue @@ -92,7 +91,7 @@ class AnthropicTest { +"Say: 'The quick brown fox jumps over the lazy dog'" } } - .filterIsInstance() + .filterIsInstance() .map { (it.delta as TextDelta).text } .toList() .joinToString(separator = "") @@ -101,71 +100,40 @@ class AnthropicTest { assertTrue(response == "The quick brown fox jumps over the lazy dog.") } - // given - @Serializable - data class Calculator( - val operation: Operation, - val a: Double, - val b: Double - ) { - - @Suppress("unused") // it is used, but by Anthropic, so we skip the warning - enum class Operation( - val calculate: (a: Double, b: Double) -> Double - ) { - ADD({ a, b -> a + b }), - SUBTRACT({ a, b -> a - b }), - MULTIPLY({ a, b -> a * b }), - DIVIDE({ a, b -> a / b }) - } - - fun calculate() = operation.calculate(a, b) - - } - @Test fun shouldUseCalculatorTool() = runTest { // given val client = Anthropic() - val calculatorTool = Tool( - description = "Perform basic arithmetic operations", - ) - - // when - val response = client.messages.create { - +Message { - +"What's 15 multiplied by 7?" - } - tools = listOf(calculatorTool) - toolChoice = ToolChoice.Any() - } - - // then - response.apply { - assertTrue(content.size == 1) - assertTrue(content[0] is ToolUse) - val toolUse = content[0] as ToolUse - assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator") - val calculator = toolUse.input() - val result = calculator.calculate() - assertTrue(result == 15.0 * 7.0) - } - } - - @Serializable - data class Fibonacci(val n: Int) - - tailrec fun fibonacci( - n: Int, a: Int = 0, b: Int = 1 - ): Int = when (n) { - 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b) + client.tools += Calculator::class + +// // when +// val response = client.messages.create { +// +Message { +// +"What's 15 multiplied by 7?" +// } +// tools = listOf(calculatorTool) +// } +// +// // then +// response.apply { +// assertTrue(content.size == 1) +// assertTrue(content[0] is ToolUse) +// val toolUse = content[0] as ToolUse +// assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator") +// val result = toolUse.use() +// assertTrue(result.toolUseId == toolUse.id) +// assertFalse(result.isError) +// assertTrue(result.content == listOf(Text(text = "${15.0 * 7.0}"))) +// } } @Test fun shouldUseCalculatorToolForFibonacci() = runTest { // given - val client = Anthropic() - val fibonacciTool = Tool( + val client = Anthropic { + tools = listOf(FibonacciTool::class) + } + val fibonacciTool = Tool( description = "Calculates fibonacci number of a given n", ) @@ -173,7 +141,6 @@ class AnthropicTest { val response = client.messages.create { +Message { +"What's fibonacci number 42" } tools = listOf(fibonacciTool) - toolChoice = ToolChoice.Any() } // then @@ -182,11 +149,16 @@ class AnthropicTest { assertTrue(content[0] is ToolUse) val toolUse = content[0] as ToolUse assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Fibonacci") - val n = toolUse.input().n - assertTrue(n == 42) - val fibonacciNumber = fibonacci(n) // doing the job for Anthropic - assertTrue(fibonacciNumber == 267914296) + val result = toolUse.use() + assertTrue(result.toolUseId == toolUse.id) + assertFalse(result.isError) + assertTrue(result.content == listOf(Text(text = "267914296"))) } } + @Test + fun shouldUseAnnotations() { + println(FibonacciTool::class.annotations) + } + } diff --git a/src/commonTest/kotlin/AnthropicTestTools.kt b/src/commonTest/kotlin/AnthropicTestTools.kt new file mode 100644 index 0000000..c6bf573 --- /dev/null +++ b/src/commonTest/kotlin/AnthropicTestTools.kt @@ -0,0 +1,65 @@ +package com.xemantic.anthropic + +import com.xemantic.anthropic.message.SimpleUsableTool +import com.xemantic.anthropic.message.Text +import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.message.UsableTool +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.modules.SerializersModule +import kotlinx.serialization.modules.polymorphic +import kotlinx.serialization.modules.subclass + +@Serializable +@SerialName("fibonacci") +@Description("Calculate Fibonacci number n") +data class FibonacciTool(val n: Int): SimpleUsableTool { + + tailrec fun fibonacci( + n: Int, a: Int = 0, b: Int = 1 + ): Int = when (n) { + 0 -> a; 1 -> b; else -> fibonacci(n - 1, b, a + b) + } + + override fun use( + toolUseId: String, + ) = ToolResult( + toolUseId, + content = listOf(Text(text = "${fibonacci(n)}")) + ) + +} + +@Serializable +@SerialName("com_xemantic_anthropic_AnthropicTest_Calculator") +@Description("Calculates the arithmetic outcome of an operation when given the arguments a and b") +data class Calculator( + val operation: Operation, + val a: Double, + val b: Double +): SimpleUsableTool { + + @Suppress("unused") // it is used, but by Anthropic, so we skip the warning + enum class Operation( + val calculate: (a: Double, b: Double) -> Double + ) { + ADD({ a, b -> a + b }), + SUBTRACT({ a, b -> a - b }), + MULTIPLY({ a, b -> a * b }), + DIVIDE({ a, b -> a / b }) + } + + override fun use(toolUseId: String) = ToolResult( + toolUseId, + operation.calculate(a, b).toString() + ) + +} + +// TODO this can be constructed on fly +val testToolsSerializersModule = SerializersModule { + polymorphic(UsableTool::class) { + subclass(Calculator::class) + subclass(FibonacciTool::class) + } +} diff --git a/src/commonTest/kotlin/message/MessagesTest.kt b/src/commonTest/kotlin/message/MessagesTest.kt index e5e9444..67a9410 100644 --- a/src/commonTest/kotlin/message/MessagesTest.kt +++ b/src/commonTest/kotlin/message/MessagesTest.kt @@ -1,6 +1,7 @@ package com.xemantic.anthropic.message import com.xemantic.anthropic.anthropicJson +import com.xemantic.anthropic.testToolsSerializersModule import io.kotest.assertions.json.shouldEqualJson import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString @@ -19,6 +20,7 @@ class MessagesTest { prettyPrint = true @OptIn(ExperimentalSerializationApi::class) prettyPrintIndent = " " + serializersModule = testToolsSerializersModule } @Test @@ -55,4 +57,36 @@ class MessagesTest { """.trimIndent() } + @Test + fun shouldDeserializeToolUseRequest() { + val request = """ + { + "id": "msg_01PspkNzNG3nrf5upeTsmWLF", + "type": "message", + "role": "assistant", + "model": "claude-3-opus-20240229", + "content": [ + { + "type": "tool_use", + "id": "toolu_01YHJK38TBKCRPn7zfjxcKHx", + "name": "com_xemantic_anthropic_AnthropicTest_Calculator", + "input": { + "operation": "MULTIPLY", + "a": 15, + "b": 7 + } + } + ], + "stop_reason": "tool_use", + "stop_sequence": null, + "usage": { + "input_tokens": 419, + "output_tokens": 86 + } + } + """.trimIndent() + + val response = json.decodeFromString(request) + } + }