From 7338ee33d90f8b9594cb154b030c155dbf8a5882 Mon Sep 17 00:00:00 2001 From: Kazik Pogoda Date: Wed, 9 Oct 2024 18:09:00 +0200 Subject: [PATCH] better tools support, initial attempt WIP --- src/commonMain/kotlin/Anthropic.kt | 33 ++----------- src/commonMain/kotlin/message/Messages.kt | 39 +++------------ src/commonMain/kotlin/tool/Tools.kt | 32 +++++++++---- src/commonTest/kotlin/AnthropicTest.kt | 53 ++++++++++----------- src/commonTest/kotlin/AnthropicTestTools.kt | 27 +++-------- 5 files changed, 63 insertions(+), 121 deletions(-) diff --git a/src/commonMain/kotlin/Anthropic.kt b/src/commonMain/kotlin/Anthropic.kt index 7c537be..0c2adea 100644 --- a/src/commonMain/kotlin/Anthropic.kt +++ b/src/commonMain/kotlin/Anthropic.kt @@ -1,12 +1,11 @@ package com.xemantic.anthropic import com.xemantic.anthropic.event.Event -import com.xemantic.anthropic.event.Usage import com.xemantic.anthropic.message.Error import com.xemantic.anthropic.message.ErrorResponse import com.xemantic.anthropic.message.MessageRequest import com.xemantic.anthropic.message.MessageResponse -import com.xemantic.anthropic.message.UsableTool +import com.xemantic.anthropic.tool.UsableTool import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.plugins.contentnegotiation.ContentNegotiation @@ -28,14 +27,8 @@ import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map -import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.json.Json -import kotlinx.serialization.modules.SerializersModule -import kotlinx.serialization.modules.polymorphic -import kotlinx.serialization.serializer import kotlin.reflect.KClass -import kotlin.reflect.KType -import kotlin.reflect.typeOf const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/" @@ -75,7 +68,7 @@ fun Anthropic( apiBase = config.apiBase, defaultModel = defaultModel, directBrowserAccess = config.directBrowserAccess, - context = config.context + usableTools = config.usableTools ) } @@ -86,8 +79,7 @@ class Anthropic internal constructor( val apiBase: String, val defaultModel: String, val directBrowserAccess: Boolean, - val tools: MutableList> = mutableListOf>(), - val context: Context? + val usableTools: MutableList> = mutableListOf() ) { class Config { @@ -97,26 +89,9 @@ class Anthropic internal constructor( var apiBase: String = ANTHROPIC_API_BASE var defaultModel: String? = null var directBrowserAccess: Boolean = false - var tools: MutableList>? = null - var context: Context? = null + var usableTools: MutableList> = mutableListOf() } - interface Context { - - fun service(type: KType): T - - } - - companion object { - val EMPTY_CONTEXT: Context = object : Context { - override fun service(type: KType): T { - throw UnsupportedOperationException("No services available") - } - } - } - - inline fun Context.service(): T = service(typeOf()) - private val client = HttpClient { install(ContentNegotiation) { json(anthropicJson) diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 4ff1b7e..73ccd41 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -1,7 +1,7 @@ package com.xemantic.anthropic.message -import com.xemantic.anthropic.Anthropic import com.xemantic.anthropic.schema.JsonSchema +import com.xemantic.anthropic.tool.UsableTool import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.KSerializer @@ -69,6 +69,10 @@ data class MessageRequest( val topK: Int? = null val topP: Int? = null + fun tools(vararg classes: KClass) { + // TODO it needs access to Anthropic, therefore either needs a constructor parameter, or needs to be inner class + } + fun messages(vararg messages: Message) { this.messages += messages.toList() } @@ -281,41 +285,12 @@ data class ToolUse( val input: UsableTool ) : Content() { -// inline fun input(): T = -// anthropicJson.decodeFromJsonElement(input) - - fun use( - context: Anthropic.Context = Anthropic.EMPTY_CONTEXT - ): ToolResult = input.use( - toolUseId = id, - context + fun use(): ToolResult = input.use( + toolUseId = id ) } -@JsonClassDiscriminator("type") -@OptIn(ExperimentalSerializationApi::class) -//@Serializable(with = UsableToolSerializer::class) -interface UsableTool { - - fun use( - toolUseId: String, - context: Anthropic.Context - ): ToolResult - -} - -interface SimpleUsableTool : UsableTool { - - override fun use( - toolUseId: String, - context: Anthropic.Context - ): ToolResult = use(toolUseId) - - fun use(toolUseId: String): ToolResult - -} - @Serializable @SerialName("tool_result") data class ToolResult( diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt index 5b26a8c..ab3482b 100644 --- a/src/commonMain/kotlin/tool/Tools.kt +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -2,10 +2,11 @@ package com.xemantic.anthropic.tool import com.xemantic.anthropic.message.CacheControl import com.xemantic.anthropic.message.Tool -import com.xemantic.anthropic.message.UsableTool +import com.xemantic.anthropic.message.ToolResult import com.xemantic.anthropic.schema.toJsonSchema import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.modules.SerializersModule import kotlinx.serialization.modules.polymorphic import kotlinx.serialization.serializer @@ -15,6 +16,17 @@ annotation class Description( val value: String ) +@JsonClassDiscriminator("type") +@OptIn(ExperimentalSerializationApi::class) +//@Serializable(with = UsableToolSerializer::class) +interface UsableTool { + + fun use( + toolUseId: String + ): ToolResult + +} + @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) fun KClass.verify() { // TODO how to get class serializer correctly? @@ -51,12 +63,12 @@ fun List>.toSerializersModule(): SerializersModule = } } -inline fun Tool( - description: String, - cacheControl: CacheControl? = null -): Tool = Tool( - name = anthropicTypeOf(), - description = description, - inputSchema = jsonSchemaOf(), - cacheControl = cacheControl -) +//inline fun Tool( +// description: String, +// cacheControl: CacheControl? = null +//): Tool = Tool( +// name = anthropicTypeOf(), +// description = description, +// inputSchema = jsonSchemaOf(), +// cacheControl = cacheControl +//) diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index f6a1ca2..d827cf7 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -8,7 +8,6 @@ import com.xemantic.anthropic.message.MessageResponse import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text -import com.xemantic.anthropic.message.Tool import com.xemantic.anthropic.message.ToolUse import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map @@ -104,43 +103,39 @@ class AnthropicTest { fun shouldUseCalculatorTool() = runTest { // given val client = Anthropic() - client.tools += Calculator::class - -// // when -// val response = client.messages.create { -// +Message { -// +"What's 15 multiplied by 7?" -// } -// tools = listOf(calculatorTool) -// } -// -// // then -// response.apply { -// assertTrue(content.size == 1) -// assertTrue(content[0] is ToolUse) -// val toolUse = content[0] as ToolUse -// assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator") -// val result = toolUse.use() -// assertTrue(result.toolUseId == toolUse.id) -// assertFalse(result.isError) -// assertTrue(result.content == listOf(Text(text = "${15.0 * 7.0}"))) -// } + client.usableTools += Calculator::class + + // when + val response = client.messages.create { + +Message { + +"What's 15 multiplied by 7?" + } + tools(Calculator::class) + } + + // then + response.apply { + assertTrue(content.size == 1) + assertTrue(content[0] is ToolUse) + val toolUse = content[0] as ToolUse + assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator") + val result = toolUse.use() + assertTrue(result.toolUseId == toolUse.id) + assertFalse(result.isError) + assertTrue(result.content == listOf(Text(text = "${15.0 * 7.0}"))) + } } @Test fun shouldUseCalculatorToolForFibonacci() = runTest { // given - val client = Anthropic { - tools = listOf(FibonacciTool::class) - } - val fibonacciTool = Tool( - description = "Calculates fibonacci number of a given n", - ) + val client = Anthropic() + client.usableTools += FibonacciTool::class // when val response = client.messages.create { +Message { +"What's fibonacci number 42" } - tools = listOf(fibonacciTool) + tools(FibonacciTool::class) } // then diff --git a/src/commonTest/kotlin/AnthropicTestTools.kt b/src/commonTest/kotlin/AnthropicTestTools.kt index c6bf573..5b03328 100644 --- a/src/commonTest/kotlin/AnthropicTestTools.kt +++ b/src/commonTest/kotlin/AnthropicTestTools.kt @@ -1,19 +1,15 @@ package com.xemantic.anthropic -import com.xemantic.anthropic.message.SimpleUsableTool -import com.xemantic.anthropic.message.Text import com.xemantic.anthropic.message.ToolResult -import com.xemantic.anthropic.message.UsableTool +import com.xemantic.anthropic.tool.Description +import com.xemantic.anthropic.tool.UsableTool import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable -import kotlinx.serialization.modules.SerializersModule -import kotlinx.serialization.modules.polymorphic -import kotlinx.serialization.modules.subclass @Serializable @SerialName("fibonacci") @Description("Calculate Fibonacci number n") -data class FibonacciTool(val n: Int): SimpleUsableTool { +data class FibonacciTool(val n: Int): UsableTool { tailrec fun fibonacci( n: Int, a: Int = 0, b: Int = 1 @@ -23,21 +19,18 @@ data class FibonacciTool(val n: Int): SimpleUsableTool { override fun use( toolUseId: String, - ) = ToolResult( - toolUseId, - content = listOf(Text(text = "${fibonacci(n)}")) - ) + ) = ToolResult(toolUseId, "${fibonacci(n)}") } @Serializable -@SerialName("com_xemantic_anthropic_AnthropicTest_Calculator") +@SerialName("calculator") @Description("Calculates the arithmetic outcome of an operation when given the arguments a and b") data class Calculator( val operation: Operation, val a: Double, val b: Double -): SimpleUsableTool { +): UsableTool { @Suppress("unused") // it is used, but by Anthropic, so we skip the warning enum class Operation( @@ -55,11 +48,3 @@ data class Calculator( ) } - -// TODO this can be constructed on fly -val testToolsSerializersModule = SerializersModule { - polymorphic(UsableTool::class) { - subclass(Calculator::class) - subclass(FibonacciTool::class) - } -}