Skip to content

Commit

Permalink
better tools support, initial attempt WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
morisil committed Oct 10, 2024
1 parent 94efe63 commit 6872a5a
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 56 deletions.
7 changes: 2 additions & 5 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ kotlin {
dependencies {
implementation(libs.kotlin.test)
implementation(libs.kotlinx.coroutines.test)
implementation(libs.kotest.assertions.core)
implementation(libs.kotest.assertions.json)
}
}
Expand Down Expand Up @@ -131,11 +132,7 @@ tasks.withType<Test> {
@Suppress("OPT_IN_USAGE")
powerAssert {
functions = listOf(
"kotlin.assert",
"kotlin.test.assertTrue",
"kotlin.test.assertFalse",
"kotlin.test.assertEquals",
"kotlin.test.assertNull"
"com.xemantic.anthropic.test.shouldBe"
)
includedSourceSets = listOf("commonTest", "jvmTest", "nativeTest")
}
Expand Down
5 changes: 3 additions & 2 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
kotlinTarget = "2.0"
javaTarget = "17"

kotlin = "2.0.20"
kotlin = "2.0.21"
kotlinxCoroutines = "1.9.0"
ktor = "3.0.0-rc-2"
ktor = "3.0.0"
kotest = "5.9.1"

log4j = "2.24.1"
Expand All @@ -30,6 +30,7 @@ ktor-serialization-kotlinx-json = { module = "io.ktor:ktor-serialization-kotlinx
ktor-client-java = { module = "io.ktor:ktor-client-java", version.ref = "ktor" }
ktor-client-curl = { module = "io.ktor:ktor-client-curl", version.ref = "ktor" }

kotest-assertions-core = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" }
kotest-assertions-json = { module = "io.kotest:kotest-assertions-json", version.ref = "kotest" }

[plugins]
Expand Down
48 changes: 42 additions & 6 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import com.xemantic.anthropic.message.Error
import com.xemantic.anthropic.message.ErrorResponse
import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.tool.UsableTool
import io.ktor.client.HttpClient
import io.ktor.client.call.body
Expand All @@ -27,7 +29,9 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.Json
import kotlinx.serialization.serializer
import kotlin.reflect.KClass

const val ANTHROPIC_API_BASE: String = "https://api.anthropic.com/"
Expand Down Expand Up @@ -67,9 +71,10 @@ fun Anthropic(
anthropicBeta = config.anthropicBeta,
apiBase = config.apiBase,
defaultModel = defaultModel,
directBrowserAccess = config.directBrowserAccess,
directBrowserAccess = config.directBrowserAccess
).apply {
usableTools = config.usableTools
)
}
}

class Anthropic internal constructor(
Expand All @@ -78,8 +83,7 @@ class Anthropic internal constructor(
val anthropicBeta: String?,
val apiBase: String,
val defaultModel: String,
val directBrowserAccess: Boolean,
val usableTools: MutableList<KClass<out UsableTool>> = mutableListOf()
val directBrowserAccess: Boolean
) {

class Config {
Expand All @@ -89,7 +93,36 @@ class Anthropic internal constructor(
var apiBase: String = ANTHROPIC_API_BASE
var defaultModel: String? = null
var directBrowserAccess: Boolean = false
var usableTools: MutableList<KClass<out UsableTool>> = mutableListOf()
var usableTools: List<KClass<out UsableTool>> = emptyList()

inline fun <reified T : UsableTool> tool(
block: T.() -> Unit = {}
) {
usableTools += T::class
}
}

private class ToolEntry(
val tool: Tool,
)

private var toolSerializerMap = mapOf<String, KSerializer<out UsableTool>>()

var usableTools: List<KClass<out UsableTool>> = emptyList()
get() = field
set(value) {
value.validate()
field = value
}

inline fun <reified T : UsableTool> tool() {
usableTools += T::class
}

fun List<KClass<out UsableTool>>.validate() {
forEach { tool ->
//tool.serializer()
}
}

private val client = HttpClient {
Expand Down Expand Up @@ -126,7 +159,10 @@ class Anthropic internal constructor(
setBody(request)
}
if (response.status.isSuccess()) {
return response.body<MessageResponse>()
return response.body<MessageResponse>().apply {
content.filterIsInstance<ToolUse>()
.forEach { it.toolSerializerMap = toolSerializerMap }
}
} else {
throw AnthropicException(
error = response.body<ErrorResponse>().error,
Expand Down
18 changes: 17 additions & 1 deletion src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.xemantic.anthropic.message

import com.xemantic.anthropic.anthropicJson
import com.xemantic.anthropic.schema.JsonSchema
import com.xemantic.anthropic.tool.UsableTool
import kotlinx.serialization.*
Expand Down Expand Up @@ -64,6 +65,10 @@ data class MessageRequest(
val topK: Int? = null
val topP: Int? = null

fun useTools() {
//too
}

fun tools(vararg classes: KClass<out UsableTool>) {
// TODO it needs access to Anthropic, therefore either needs a constructor parameter, or needs to be inner class
}
Expand Down Expand Up @@ -278,7 +283,18 @@ data class ToolUse(
val id: String,
val name: String,
val input: JsonObject
) : Content()
) : Content() {

@Transient
internal lateinit var toolSerializerMap: Map<String, KSerializer<out UsableTool>>

fun use(): ToolResult {
val serializer = toolSerializerMap[name]!!
val tool = anthropicJson.decodeFromJsonElement(serializer, input)
return tool.use(toolUseId = id)
}

}

@Serializable
@SerialName("tool_result")
Expand Down
87 changes: 60 additions & 27 deletions src/commonMain/kotlin/tool/Tools.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@ import com.xemantic.anthropic.message.CacheControl
import com.xemantic.anthropic.message.Tool
import com.xemantic.anthropic.message.ToolResult
import com.xemantic.anthropic.message.ToolUse
import com.xemantic.anthropic.schema.jsonSchemaOf
import com.xemantic.anthropic.schema.toJsonSchema
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.MetaSerializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.serializer
import kotlin.reflect.KClass

annotation class Description(
val value: String
@OptIn(ExperimentalSerializationApi::class)
@MetaSerializable
@Target(AnnotationTarget.CLASS)
annotation class SerializableTool(
val name: String,
val description: String
)

@JsonClassDiscriminator("name")
@OptIn(ExperimentalSerializationApi::class)
//@Serializable(with = UsableToolSerializer::class)
interface UsableTool {

fun use(
Expand All @@ -30,35 +34,64 @@ interface UsableTool {

}

class ToolContentSerializer() {

}
fun Tool.cacheControl(
cacheControl: CacheControl? = null
): Tool = if (cacheControl == null) this else Tool(
name,
description,
inputSchema,
cacheControl
)

@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
fun <T : UsableTool> KClass<T>.verify() {
// TODO how to get class serializer correctly?
checkNotNull(serializer()) {
"Invalid tool definition, not serializer for class ${this@verify}"
@OptIn(ExperimentalSerializationApi::class)
inline fun <reified T : UsableTool> toolOf(): Tool {
val serializer = try {
serializer<T>()
} catch (e :SerializationException) {
throw SerializationException("The class ${T::class.qualifiedName} must be annotated with @SerializableTool", e)
}
checkNotNull(serializer().descriptor.annotations.filterIsInstance<Description>().firstOrNull()) {
"Not @Description annotation specified for the tool"
val description = checkNotNull(
serializer
.descriptor
.annotations
.filterIsInstance<SerializableTool>()
.firstOrNull()
) {
"No @Description annotation found for ${T::class.qualifiedName}"
}
}

@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
fun <T : UsableTool> KClass<T>.instance(
cacheControl: CacheControl? = null
): Tool {
val descriptor = serializer().descriptor
val description = descriptor.annotations.filterIsInstance<Description>().firstOrNull()!!.value
return Tool(
name = descriptor.serialName,
description = description,
inputSchema = toJsonSchema(),
cacheControl = cacheControl
name = description.name,
description = description.description,
inputSchema = jsonSchemaOf<T>(),
cacheControl = null
)
}

//@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
//fun <T : UsableTool> KClass<T>.verify() {
// // TODO how to get class serializer correctly?
// checkNotNull(serializer()) {
// "Invalid tool definition, not serializer for class ${this@verify}"
// }
// checkNotNull(serializer().descriptor.annotations.filterIsInstance<Description>().firstOrNull()) {
// "Not @Description annotation specified for the tool"
// }
//}

//@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class)
//fun <T : UsableTool> KClass<T>.instance(
// cacheControl: CacheControl? = null
//): Tool {
// val descriptor = serializer().descriptor
// val description = descriptor.annotations.filterIsInstance<Description>().firstOrNull()!!.value
// return Tool(
// name = descriptor.serialName,
// description = description,
// inputSchema = toJsonSchema(),
// cacheControl = cacheControl
// )
//}

//inline fun <reified T> anthropicTypeOf(): String =
// T::class.qualifiedName!!.replace('.', '_')

Expand Down
Loading

0 comments on commit 6872a5a

Please sign in to comment.