Skip to content

Commit

Permalink
better tools support, initial attempt WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil committed Oct 9, 2024
1 parent 6c01e9f commit 7338ee3
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 121 deletions.
33 changes: 4 additions & 29 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.event.Event
import com.xemantic.anthropic.event.Usage
import com.xemantic.anthropic.message.Error
import com.xemantic.anthropic.message.ErrorResponse
import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.UsableTool
import com.xemantic.anthropic.tool.UsableTool
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
Expand All @@ -28,14 +27,8 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.json.Json
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.serializer
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.typeOf

const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/"

Expand Down Expand Up @@ -75,7 +68,7 @@ fun Anthropic(
apiBase = config.apiBase,
defaultModel = defaultModel,
directBrowserAccess = config.directBrowserAccess,
context = config.context
usableTools = config.usableTools
)
}

Expand All @@ -86,8 +79,7 @@ class Anthropic internal constructor(
val apiBase: String,
val defaultModel: String,
val directBrowserAccess: Boolean,
val tools: MutableList<KClass<out UsableTool>> = mutableListOf<KClass<out UsableTool>>(),
val context: Context?
val usableTools: MutableList<KClass<out UsableTool>> = mutableListOf()
) {

class Config {
Expand All @@ -97,26 +89,9 @@ class Anthropic internal constructor(
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: String? = null
var directBrowserAccess: Boolean = false
var tools: MutableList<KClass<out UsableTool>>? = null
var context: Context? = null
var usableTools: MutableList<KClass<out UsableTool>> = mutableListOf()
}

interface Context {

fun <T> service(type: KType): T

}

companion object {
val EMPTY_CONTEXT: Context = object : Context {
override fun <T> service(type: KType): T {
throw UnsupportedOperationException("No services available")
}
}
}

inline fun <reified T> Context.service(): T = service(typeOf<T>())

private val client = HttpClient {
install(ContentNegotiation) {
json(anthropicJson)
Expand Down
39 changes: 7 additions & 32 deletions src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.xemantic.anthropic.message

import com.xemantic.anthropic.Anthropic
import com.xemantic.anthropic.schema.JsonSchema
import com.xemantic.anthropic.tool.UsableTool
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
Expand Down Expand Up @@ -69,6 +69,10 @@ data class MessageRequest(
val topK: Int? = null
val topP: Int? = null

fun tools(vararg classes: KClass<out UsableTool>) {
// TODO it needs access to Anthropic, therefore either needs a constructor parameter, or needs to be inner class
}

fun messages(vararg messages: Message) {
this.messages += messages.toList()
}
Expand Down Expand Up @@ -281,41 +285,12 @@ data class ToolUse(
val input: UsableTool
) : Content() {

// inline fun <reified T> input(): T =
// anthropicJson.decodeFromJsonElement<T>(input)

fun use(
context: Anthropic.Context = Anthropic.EMPTY_CONTEXT
): ToolResult = input.use(
toolUseId = id,
context
fun use(): ToolResult = input.use(
toolUseId = id
)

}

@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
//@Serializable(with = UsableToolSerializer::class)
interface UsableTool {

fun use(
toolUseId: String,
context: Anthropic.Context
): ToolResult

}

interface SimpleUsableTool : UsableTool {

override fun use(
toolUseId: String,
context: Anthropic.Context
): ToolResult = use(toolUseId)

fun use(toolUseId: String): ToolResult

}

@Serializable
@SerialName("tool_result")
data class ToolResult(
Expand Down
32 changes: 22 additions & 10 deletions src/commonMain/kotlin/tool/Tools.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package com.xemantic.anthropic.tool

import com.xemantic.anthropic.message.CacheControl
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.UsableTool
import com.xemantic.anthropic.message.ToolResult
import com.xemantic.anthropic.schema.toJsonSchema
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.serializer
Expand All @@ -15,6 +16,17 @@ annotation class Description(
val value: String
)

@JsonClassDiscriminator("type")
@OptIn(ExperimentalSerializationApi::class)
//@Serializable(with = UsableToolSerializer::class)
interface UsableTool {

fun use(
toolUseId: String
): ToolResult

}

@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
fun <T : UsableTool> KClass<T>.verify() {
// TODO how to get class serializer correctly?
Expand Down Expand Up @@ -51,12 +63,12 @@ fun <T : UsableTool> List<KClass<T>>.toSerializersModule(): SerializersModule =
}
}

inline fun <reified T : UsableTool> Tool(
description: String,
cacheControl: CacheControl? = null
): Tool = Tool(
name = anthropicTypeOf<T>(),
description = description,
inputSchema = jsonSchemaOf<T>(),
cacheControl = cacheControl
)
//inline fun <reified T : UsableTool> Tool(
// description: String,
// cacheControl: CacheControl? = null
//): Tool = Tool(
// name = anthropicTypeOf<T>(),
// description = description,
// inputSchema = jsonSchemaOf<T>(),
// cacheControl = cacheControl
//)
53 changes: 24 additions & 29 deletions src/commonTest/kotlin/AnthropicTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.Role
import com.xemantic.anthropic.message.StopReason
import com.xemantic.anthropic.message.Text
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolUse
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.map
Expand Down Expand Up @@ -104,43 +103,39 @@ class AnthropicTest {
fun shouldUseCalculatorTool() = runTest {
// given
val client = Anthropic()
client.tools += Calculator::class

// // when
// val response = client.messages.create {
// +Message {
// +"What's 15 multiplied by 7?"
// }
// tools = listOf(calculatorTool)
// }
//
// // then
// response.apply {
// assertTrue(content.size == 1)
// assertTrue(content[0] is ToolUse)
// val toolUse = content[0] as ToolUse
// assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator")
// val result = toolUse.use()
// assertTrue(result.toolUseId == toolUse.id)
// assertFalse(result.isError)
// assertTrue(result.content == listOf(Text(text = "${15.0 * 7.0}")))
// }
client.usableTools += Calculator::class

// when
val response = client.messages.create {
+Message {
+"What's 15 multiplied by 7?"
}
tools(Calculator::class)
}

// then
response.apply {
assertTrue(content.size == 1)
assertTrue(content[0] is ToolUse)
val toolUse = content[0] as ToolUse
assertTrue(toolUse.name == "com_xemantic_anthropic_AnthropicTest_Calculator")
val result = toolUse.use()
assertTrue(result.toolUseId == toolUse.id)
assertFalse(result.isError)
assertTrue(result.content == listOf(Text(text = "${15.0 * 7.0}")))
}
}

@Test
fun shouldUseCalculatorToolForFibonacci() = runTest {
// given
val client = Anthropic {
tools = listOf(FibonacciTool::class)
}
val fibonacciTool = Tool<FibonacciTool>(
description = "Calculates fibonacci number of a given n",
)
val client = Anthropic()
client.usableTools += FibonacciTool::class

// when
val response = client.messages.create {
+Message { +"What's fibonacci number 42" }
tools = listOf(fibonacciTool)
tools(FibonacciTool::class)
}

// then
Expand Down
27 changes: 6 additions & 21 deletions src/commonTest/kotlin/AnthropicTestTools.kt
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.message.SimpleUsableTool
import com.xemantic.anthropic.message.Text
import com.xemantic.anthropic.message.ToolResult
import com.xemantic.anthropic.message.UsableTool
import com.xemantic.anthropic.tool.Description
import com.xemantic.anthropic.tool.UsableTool
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.modules.subclass

@Serializable
@SerialName("fibonacci")
@Description("Calculate Fibonacci number n")
data class FibonacciTool(val n: Int): SimpleUsableTool {
data class FibonacciTool(val n: Int): UsableTool {

tailrec fun fibonacci(
n: Int, a: Int = 0, b: Int = 1
Expand All @@ -23,21 +19,18 @@ data class FibonacciTool(val n: Int): SimpleUsableTool {

override fun use(
toolUseId: String,
) = ToolResult(
toolUseId,
content = listOf(Text(text = "${fibonacci(n)}"))
)
) = ToolResult(toolUseId, "${fibonacci(n)}")

}

@Serializable
@SerialName("com_xemantic_anthropic_AnthropicTest_Calculator")
@SerialName("calculator")
@Description("Calculates the arithmetic outcome of an operation when given the arguments a and b")
data class Calculator(
val operation: Operation,
val a: Double,
val b: Double
): SimpleUsableTool {
): UsableTool {

@Suppress("unused") // it is used, but by Anthropic, so we skip the warning
enum class Operation(
Expand All @@ -55,11 +48,3 @@ data class Calculator(
)

}

// TODO this can be constructed on fly
val testToolsSerializersModule = SerializersModule {
polymorphic(UsableTool::class) {
subclass(Calculator::class)
subclass(FibonacciTool::class)
}
}

0 comments on commit 7338ee3

Please sign in to comment.