diff --git a/src/commonMain/kotlin/message/Messages.kt b/src/commonMain/kotlin/message/Messages.kt index 73ccd41..6d44d73 100644 --- a/src/commonMain/kotlin/message/Messages.kt +++ b/src/commonMain/kotlin/message/Messages.kt @@ -2,12 +2,7 @@ package com.xemantic.anthropic.message import com.xemantic.anthropic.schema.JsonSchema import com.xemantic.anthropic.tool.UsableTool -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.InternalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.SerializationException +import kotlinx.serialization.* import kotlinx.serialization.builtins.serializer import kotlinx.serialization.descriptors.PolymorphicKind import kotlinx.serialization.descriptors.SerialDescriptor @@ -16,7 +11,7 @@ import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.json.JsonDecoder -import kotlinx.serialization.serializerOrNull +import kotlinx.serialization.json.JsonObject import kotlin.collections.mutableListOf import kotlin.reflect.KClass @@ -282,14 +277,8 @@ data class ToolUse( override val cacheControl: CacheControl? = null, val id: String, val name: String, - val input: UsableTool -) : Content() { - - fun use(): ToolResult = input.use( - toolUseId = id - ) - -} + val input: JsonObject +) : Content() @Serializable @SerialName("tool_result") diff --git a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt index 5e1968e..71ef76a 100644 --- a/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt +++ b/src/commonMain/kotlin/schema/JsonSchemaGenerator.kt @@ -5,6 +5,10 @@ import kotlinx.serialization.descriptors.* import kotlin.collections.set import kotlin.reflect.KClass +inline fun jsonSchemaOf(): JsonSchema = generateSchema( + serializer().descriptor +) + @OptIn(InternalSerializationApi::class) fun KClass<*>.toJsonSchema(): JsonSchema = generateSchema( serializer().descriptor diff --git a/src/commonMain/kotlin/tool/Tools.kt b/src/commonMain/kotlin/tool/Tools.kt index ab3482b..a546dff 100644 --- a/src/commonMain/kotlin/tool/Tools.kt +++ b/src/commonMain/kotlin/tool/Tools.kt @@ -1,11 +1,14 @@ package com.xemantic.anthropic.tool +import com.xemantic.anthropic.anthropicJson import com.xemantic.anthropic.message.CacheControl import com.xemantic.anthropic.message.Tool import com.xemantic.anthropic.message.ToolResult +import com.xemantic.anthropic.message.ToolUse import com.xemantic.anthropic.schema.toJsonSchema import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.KSerializer import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.modules.SerializersModule import kotlinx.serialization.modules.polymorphic @@ -16,7 +19,7 @@ annotation class Description( val value: String ) -@JsonClassDiscriminator("type") +@JsonClassDiscriminator("name") @OptIn(ExperimentalSerializationApi::class) //@Serializable(with = UsableToolSerializer::class) interface UsableTool { @@ -27,6 +30,10 @@ interface UsableTool { } +class ToolContentSerializer() { + +} + @OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) fun KClass.verify() { // TODO how to get class serializer correctly? @@ -72,3 +79,12 @@ fun List>.toSerializersModule(): SerializersModule = // inputSchema = jsonSchemaOf(), // cacheControl = cacheControl //) + + +fun ToolUse.use( + map: Map> +): ToolResult { + val serializer = map[name]!! + val tool = anthropicJson.decodeFromJsonElement(serializer, input) + return tool.use(toolUseId = id) +} diff --git a/src/commonTest/kotlin/AnthropicTest.kt b/src/commonTest/kotlin/AnthropicTest.kt index d827cf7..aac76b5 100644 --- a/src/commonTest/kotlin/AnthropicTest.kt +++ b/src/commonTest/kotlin/AnthropicTest.kt @@ -9,6 +9,7 @@ import com.xemantic.anthropic.message.Role import com.xemantic.anthropic.message.StopReason import com.xemantic.anthropic.message.Text import com.xemantic.anthropic.message.ToolUse +import com.xemantic.anthropic.tool.use import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList diff --git a/src/commonTest/kotlin/message/MessagesTest.kt b/src/commonTest/kotlin/message/MessagesTest.kt index 67a9410..43cab02 100644 --- a/src/commonTest/kotlin/message/MessagesTest.kt +++ b/src/commonTest/kotlin/message/MessagesTest.kt @@ -1,7 +1,8 @@ package com.xemantic.anthropic.message +import com.xemantic.anthropic.Calculator import com.xemantic.anthropic.anthropicJson -import com.xemantic.anthropic.testToolsSerializersModule +import com.xemantic.anthropic.tool.toSerializersModule import io.kotest.assertions.json.shouldEqualJson import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString @@ -20,7 +21,7 @@ class MessagesTest { prettyPrint = true @OptIn(ExperimentalSerializationApi::class) prettyPrintIndent = " " - serializersModule = testToolsSerializersModule + //serializersModule = testToolsSerializersModule } @Test @@ -59,6 +60,9 @@ class MessagesTest { @Test fun shouldDeserializeToolUseRequest() { + val json = Json(from = json) { + serializersModule = listOf(Calculator::class).toSerializersModule() + } val request = """ { "id": "msg_01PspkNzNG3nrf5upeTsmWLF",