From 6872a5acf538273ad01a258685a64fad83833a5e Mon Sep 17 00:00:00 2001 From: Kazik Pogoda Date: Thu, 10 Oct 2024 17:08:59 +0200 Subject: [PATCH] better tools support, initial attempt WIP --- build.gradle.kts | 7 +- gradle/libs.versions.toml | 5 +- src/commonMain/kotlin/Anthropic.kt | 48 ++++++- src/commonMain/kotlin/message/Messages.kt | 18 ++- src/commonMain/kotlin/tool/Tools.kt | 87 +++++++++---- src/commonTest/kotlin/AnthropicTest.kt | 124 +++++++++++++++++-- src/commonTest/kotlin/AnthropicTestTools.kt | 50 ++++++-- src/commonTest/kotlin/test/AnthropicTest.kt | 16 +++ src/commonTest/kotlin/tool/UsableToolTest.kt | 74 +++++++++++ 9 files changed, 373 insertions(+), 56 deletions(-) create mode 100644 src/commonTest/kotlin/test/AnthropicTest.kt create mode 100644 src/commonTest/kotlin/tool/UsableToolTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index cb25fda..171cc36 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -76,6 +76,7 @@ kotlin { dependencies { implementation(libs.kotlin.test) implementation(libs.kotlinx.coroutines.test) + implementation(libs.kotest.assertions.core) implementation(libs.kotest.assertions.json) } } @@ -131,11 +132,7 @@ tasks.withType { @Suppress("OPT_IN_USAGE") powerAssert { functions = listOf( - "kotlin.assert", - "kotlin.test.assertTrue", - "kotlin.test.assertFalse", - "kotlin.test.assertEquals", - "kotlin.test.assertNull" + "com.xemantic.anthropic.test.shouldBe" ) includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest") } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 173847e..c278730 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -2,9 +2,9 @@ kotlinTarget = "2.0" javaTarget = "17" -kotlin = "2.0.20" +kotlin = "2.0.21" kotlinxCoroutines = "1.9.0" -ktor = "3.0.0-rc-2" +ktor = "3.0.0" kotest = "5.9.1" log4j = "2.24.1" @@ -30,6 +30,7 @@ ktor-serialization-kotlinx-json = { module = "io.ktor:ktor-serialization-kotlinx ktor-client-java = { module = "io.ktor:ktor-client-java", version.ref = "ktor" } ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" } +kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" } [plugins] diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index 0c2adea..d58966b 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -5,6 +5,8 @@ import com.xemantic.anthropic.message.Error import com.xemantic.anthropic.message.ErrorResponse import com.xemantic.anthropic.message.MessageRequest import com.xemantic.anthropic.message.MessageResponse +import com.xemantic.anthropic.message.Tool +import com.xemantic.anthropic.message.ToolUse import com.xemantic.anthropic.tool.UsableTool import io.ktor.client.HttpClient import io.ktor.client.call.body @@ -27,7 +29,9 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map +import kotlinx.serialization.KSerializer import kotlinx.serialization.json.Json +import kotlinx.serialization.serializer import kotlin.reflect.KClass const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/" @@ -67,9 +71,10 @@ fun Anthropic( anthropicBeta = config.anthropicBeta, apiBase = config.apiBase, defaultModel = defaultModel, - directBrowserAccess = config.directBrowserAccess, + directBrowserAccess = config.directBrowserAccess + ).apply { usableTools = config.usableTools - ) + } } class Anthropic internal constructor( @@ -78,8 +83,7 @@ class Anthropic internal constructor( val anthropicBeta: String?, val apiBase: String, val defaultModel: String, - val directBrowserAccess: Boolean, - val usableTools: MutableList> = mutableListOf() + val directBrowserAccess: Boolean ) { class Config { @@ -89,7 +93,36 @@ class Anthropic internal constructor( var apiBase: String = ANTHROPIC_API_BASE var defaultModel: String? = null var directBrowserAccess: Boolean = false - var usableTools: MutableList> = mutableListOf() + var usableTools: List> = emptyList() + + inline fun tool( + block: T.() -> Unit = {} + ) { + usableTools += T::class + } + } + + private class ToolEntry( + val tool: Tool, + ) + + private var toolSerializerMap = mapOf>() + + var usableTools: List> = emptyList() + get() = field + set(value) { + value.validate() + field = value + } + + inline fun tool() { + usableTools += T::class + } + + fun List>.validate() { + forEach { tool -> + //tool.serializer() + } } private val client = HttpClient { @@ -126,7 +159,10 @@ class Anthropic internal constructor( setBody(request) } if (response.status.isSuccess()) { - return response.body() + return response.body().apply { + content.filterIsInstance() + .forEach { it.toolSerializerMap = toolSerializerMap } + } } else { throw AnthropicException( error = response.body().error, diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 6d44d73..d9868c4 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -1,5 +1,6 @@ package com.xemantic.anthropic.message +import com.xemantic.anthropic.anthropicJson import com.xemantic.anthropic.schema.JsonSchema import com.xemantic.anthropic.tool.UsableTool import kotlinx.serialization.* @@ -64,6 +65,10 @@ data class MessageRequest( val topK: Int? = null val topP: Int? = null + fun useTools() { + //too + } + fun tools(vararg classes: KClass) { // TODO it needs access to Anthropic, therefore either needs a constructor parameter, or needs to be inner class } @@ -278,7 +283,18 @@ data class ToolUse( val id: String, val name: String, val input: JsonObject -) : Content() +) : Content() { + + @Transient + internal lateinit var toolSerializerMap: Map> + + fun use(): ToolResult { + val serializer = toolSerializerMap[name]!! + val tool = anthropicJson.decodeFromJsonElement(serializer, input) + return tool.use(toolUseId = id) + } + +} @Serializable @SerialName("tool_result") diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt index a546dff..aca6f39 100644 --- a/src/commonMain/kotlin/tool/Tools.kt +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -5,23 +5,27 @@ import com.xemantic.anthropic.message.CacheControl import com.xemantic.anthropic.message.Tool import com.xemantic.anthropic.message.ToolResult import com.xemantic.anthropic.message.ToolUse +import com.xemantic.anthropic.schema.jsonSchemaOf import com.xemantic.anthropic.schema.toJsonSchema import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.KSerializer -import kotlinx.serialization.json.JsonClassDiscriminator +import kotlinx.serialization.MetaSerializable +import kotlinx.serialization.SerializationException import kotlinx.serialization.modules.SerializersModule import kotlinx.serialization.modules.polymorphic import kotlinx.serialization.serializer import kotlin.reflect.KClass -annotation class Description( - val value: String +@OptIn(ExperimentalSerializationApi::class) +@MetaSerializable +@Target(AnnotationTarget.CLASS) +annotation class SerializableTool( + val name: String, + val description: String ) -@JsonClassDiscriminator("name") @OptIn(ExperimentalSerializationApi::class) -//@Serializable(with = UsableToolSerializer::class) interface UsableTool { fun use( @@ -30,35 +34,64 @@ interface UsableTool { } -class ToolContentSerializer() { - -} +fun Tool.cacheControl( + cacheControl: CacheControl? = null +): Tool = if (cacheControl == null) this else Tool( + name, + description, + inputSchema, + cacheControl +) -@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) -fun KClass.verify() { - // TODO how to get class serializer correctly? - checkNotNull(serializer()) { - "Invalid tool definition, not serializer for class ${this@verify}" +@OptIn(ExperimentalSerializationApi::class) +inline fun toolOf(): Tool { + val serializer = try { + serializer() + } catch (e :SerializationException) { + throw SerializationException("The class ${T::class.qualifiedName} must be annotated with @SerializableTool", e) } - checkNotNull(serializer().descriptor.annotations.filterIsInstance().firstOrNull()) { - "Not @Description annotation specified for the tool" + val description = checkNotNull( + serializer + .descriptor + .annotations + .filterIsInstance() + .firstOrNull() + ) { + "No @Description annotation found for ${T::class.qualifiedName}" } -} - -@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) -fun KClass.instance( - cacheControl: CacheControl? = null -): Tool { - val descriptor = serializer().descriptor - val description = descriptor.annotations.filterIsInstance().firstOrNull()!!.value return Tool( - name = descriptor.serialName, - description = description, - inputSchema = toJsonSchema(), - cacheControl = cacheControl + name = description.name, + description = description.description, + inputSchema = jsonSchemaOf(), + cacheControl = null ) } +//@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +//fun KClass.verify() { +// // TODO how to get class serializer correctly? +// checkNotNull(serializer()) { +// "Invalid tool definition, not serializer for class ${this@verify}" +// } +// checkNotNull(serializer().descriptor.annotations.filterIsInstance().firstOrNull()) { +// "Not @Description annotation specified for the tool" +// } +//} + +//@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +//fun KClass.instance( +// cacheControl: CacheControl? = null +//): Tool { +// val descriptor = serializer().descriptor +// val description = descriptor.annotations.filterIsInstance().firstOrNull()!!.value +// return Tool( +// name = descriptor.serialName, +// description = description, +// inputSchema = toJsonSchema(), +// cacheControl = cacheControl +// ) +//} + //inline fun anthropicTypeOf(): String = // T::class.qualifiedName!!.replace('.', '_') diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index aac76b5..1927e4d 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -9,7 +9,8 @@ import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text import com.xemantic.anthropic.message.ToolUse -import com.xemantic.anthropic.tool.use +import com.xemantic.anthropic.test.then +import com.xemantic.anthropic.test.shouldBe import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList @@ -128,15 +129,16 @@ class AnthropicTest { } @Test - fun shouldUseCalculatorToolForFibonacci() = runTest { + fun shouldUseFibonacciTool() = runTest { // given - val client = Anthropic() - client.usableTools += FibonacciTool::class + val client = Anthropic { + tool() + } // when val response = client.messages.create { +Message { +"What's fibonacci number 42" } - tools(FibonacciTool::class) + useTools() } // then @@ -153,8 +155,116 @@ class AnthropicTest { } @Test - fun shouldUseAnnotations() { - println(FibonacciTool::class.annotations) + fun shouldUse2ToolsInSequence() = runTest { + // given + val client = Anthropic { + tool() + tool() + } + + // when + val conversation = mutableListOf() + conversation += Message { + +"Calculate Fibonacci number 42 and then divide it by 42" + } + val response1 = client.messages.create { + messages = conversation + useTools() + } + + // then + val fibonacciResult = with(response1) { + assertTrue(content.size == 1) + assertTrue(content[0] is ToolUse) + val toolUse = content[0] as ToolUse + assertTrue(toolUse.name == "fibonacci") + val result = toolUse.use() + assertTrue(result.toolUseId == toolUse.id) + assertFalse(result.isError) + assertTrue(result.content == listOf(Text(text = "267914296"))) + result + } + + // when + conversation += Message { + +fibonacciResult + } + val response2 = client.messages.create { + messages = conversation + useTools() + } + // then + val calculatorResult = with(response2) { + assertTrue(content.size == 1) + assertTrue(content[0] is ToolUse) + val toolUse = content[0] as ToolUse + assertTrue(toolUse.name == "calculator") + val result = toolUse.use() + assertTrue(result.toolUseId == toolUse.id) + assertFalse(result.isError) + assertTrue(result.content == listOf(Text(text = "267914296"))) + result + } + + // when + conversation += Message { +calculatorResult } + val response3 = client.messages.create { + messages = conversation + useTools() + } + with(response3) { + assertTrue(content.size == 1) + assertTrue(content[0] is Text) + val text = content[0] as Text + assertTrue(text.text.contains("6378911.8")) + } + } + + @Test + fun shouldUseToolWithDependencies() = runTest { + // given + val testDb = TestDatabase() + val client = Anthropic { + tool { + database = testDb + } + } + + // when + val conversation = mutableListOf() + conversation += Message { +"List data in CUSTOMER table" } + val response1 = client.messages.create { + messages = conversation + useTools() // TODO it should be a single tool + } + + then(response1) { + stopReason shouldBe StopReason.TOOL_USE + content.size shouldBe 1 + content[0] shouldBe ToolUse + val toolUse = content[0] as ToolUse + assertTrue(toolUse.name == "fibonacci") + } + val toolUse = response1.content[0] as ToolUse + val result = toolUse.use() + then(result) { + toolUseId shouldBe toolUse + isError shouldBe false + content shouldBe listOf(Text(text = "267914296")) + } + + // when + conversation += Message { +result } + val response2 = client.messages.create { + messages = conversation + } + + then(response2) { + content.size shouldBe 1 + content[0] is Text + val text = content[0] as Text + text.text.contains("6378911.8") + } } } diff --git a/src/commonTest/kotlin/AnthropicTestTools.kt b/src/commonTest/kotlin/AnthropicTestTools.kt index 5b03328..d00b6b5 100644 --- a/src/commonTest/kotlin/AnthropicTestTools.kt +++ b/src/commonTest/kotlin/AnthropicTestTools.kt @@ -1,14 +1,14 @@ package com.xemantic.anthropic import com.xemantic.anthropic.message.ToolResult -import com.xemantic.anthropic.tool.Description +import com.xemantic.anthropic.tool.SerializableTool import com.xemantic.anthropic.tool.UsableTool import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -@Serializable -@SerialName("fibonacci") -@Description("Calculate Fibonacci number n") +@SerializableTool( + name = "FibonacciTool", + description = "Calculate Fibonacci number n" +) data class FibonacciTool(val n: Int): UsableTool { tailrec fun fibonacci( @@ -23,9 +23,10 @@ data class FibonacciTool(val n: Int): UsableTool { } -@Serializable -@SerialName("calculator") -@Description("Calculates the arithmetic outcome of an operation when given the arguments a and b") +@SerializableTool( + name = "Calculator", + description = "Calculates the arithmetic outcome of an operation when given the arguments a and b" +) data class Calculator( val operation: Operation, val a: Double, @@ -48,3 +49,36 @@ data class Calculator( ) } + +interface Database { + fun execute(query: String): List +} + +class TestDatabase : Database { + var executedQuery: String? = null + override fun execute( + query: String + ): List { + executedQuery = query + return listOf("foo", "bar", "buzz") + } +} + +@SerializableTool( + name = "DatabaseQuery", + description = "Executes database query" +) +data class DatabaseQuery( + val query: String +) : UsableTool { + + lateinit var database: Database + + override fun use( + toolUseId: String + ) = ToolResult( + toolUseId, + text = database.execute(query).joinToString() + ) + +} diff --git a/src/commonTest/kotlin/test/AnthropicTest.kt b/src/commonTest/kotlin/test/AnthropicTest.kt new file mode 100644 index 0000000..a698d5e --- /dev/null +++ b/src/commonTest/kotlin/test/AnthropicTest.kt @@ -0,0 +1,16 @@ +package com.xemantic.anthropic.test + +import kotlin.reflect.KClass +import kotlin.test.assertTrue + +fun then(value: T, block: T.() -> Unit) { + block(value) +} + +infix fun T.shouldBe(expected: T): Unit = assert(expected == this) + +fun T.shouldBe(expected: T, message: () -> String): Unit = assert(this == expected, message) + +//infix fun T.shouldBe(expected: KClass) { +// assert(this is expected) +//} diff --git a/src/commonTest/kotlin/tool/UsableToolTest.kt b/src/commonTest/kotlin/tool/UsableToolTest.kt new file mode 100644 index 0000000..b1b4d67 --- /dev/null +++ b/src/commonTest/kotlin/tool/UsableToolTest.kt @@ -0,0 +1,74 @@ +package com.xemantic.anthropic.tool + +import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.schema.JsonSchema +import com.xemantic.anthropic.schema.JsonSchemaProperty +import com.xemantic.anthropic.test.then +import com.xemantic.anthropic.test.shouldBe +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlin.test.Test +import kotlin.test.assertFailsWith + +class UsableToolTest { + + @SerializableTool( + name = "TestTool", + description = "Test tool receiving a message and outputting it back" + ) + class TestTool( + val message: String + ) : UsableTool { + override fun use(toolUseId: String) = ToolResult(toolUseId, message) + } + + @Test + fun shouldCreateToolFromUsableTool() { + // when + val tool = toolOf() + + then(tool) { + name shouldBe "TestTool" + description shouldBe "Test tool receiving a message and outputting it back" + inputSchema shouldBe JsonSchema( + properties = mapOf("message" to JsonSchemaProperty.STRING), + required = listOf("message") + ) + cacheControl shouldBe null + } + } + + class NoAnnotationTool : UsableTool { + override fun use(toolUseId: String) = ToolResult(toolUseId, "nothing") + } + + @Test + fun shouldFailToCreateToolWithoutSerializableToolAnnotation() { + assertFailsWith { + toolOf() + } + try { + toolOf() + } catch (e: SerializationException) { + e.message shouldBe "The class com.xemantic.anthropic.tool.UsableToolTest.NoAnnotationTool must be annotated with @SerializableTool" + } + } + + @Serializable + class OnlySerializableAnnotationTool : UsableTool { + override fun use(toolUseId: String) = ToolResult(toolUseId, "nothing") + } + + @Test + fun shouldFailToCreateToolWithOnlySerializableToolAnnotation() { + assertFailsWith { + toolOf() + } + try { + toolOf() + } catch (e: SerializationException) { + e.message shouldBe "The class com.xemantic.anthropic.tool.UsableToolTest.NoAnnotationTool must be annotated with @SerializableTool" + } + } + +}