Skip to content

Commit

Permalink
Refactoring: new content package, document support (PDF upload), and …
Browse files Browse the repository at this point in the history
…simplified use of tools. (#15)
  • Loading branch information
morisil authored Nov 2, 2024
1 parent 79f77b1 commit b7bed78
Show file tree
Hide file tree
Showing 31 changed files with 658 additions and 318 deletions.
31 changes: 15 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ If you want to write AI agents, you need tools, and this is where this library s
```kotlin
@AnthropicTool("get_weather")
@Description("Get the weather for a specific location")
data class WeatherTool(val location: String): ToolInput {
override fun use(
toolUseId: String
) = ToolResult(
toolUseId,
"The weather is 73f" // it should use some external service
)
data class WeatherTool(val location: String): ToolInput() {
init {
use {
// in the real world it should use some external service
+"The weather is 73f"
}
}
}

fun main() = runBlocking {
Expand Down Expand Up @@ -192,21 +192,20 @@ internet or DB connection pool to access the database.
```kotlin
@AnthropicTool("query_database")
@Description("Executes SQL on the database")
data class QueryDatabase(val sql: String): ToolInput {
data class QueryDatabase(val sql: String): ToolInput() {

@Transient
internal lateinit var connection: Connection

override fun use(
toolUseId: String
) = ToolResult(
toolUseId,
text = connection.prepareStatement(sql).use { statement ->
statement.resultSet.use { resultSet ->
resultSet.toString()
init {
use {
+connection.prepareStatement(sql).use { statement ->
statement.executeQuery().use { resultSet ->
resultSet.toString()
}
}
}
)
}

}

Expand Down
19 changes: 14 additions & 5 deletions src/commonMain/kotlin/Anthropic.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import com.xemantic.anthropic.error.AnthropicException
import com.xemantic.anthropic.error.ErrorResponse
import com.xemantic.anthropic.event.Event
import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.content.ToolUse
import com.xemantic.anthropic.message.MessageRequest
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.tool.BuiltInTool
import com.xemantic.anthropic.tool.ToolUse
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolInput
import io.ktor.client.HttpClient
Expand All @@ -27,7 +27,7 @@ import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map

Expand Down Expand Up @@ -95,6 +95,7 @@ class Anthropic internal constructor(

var tools: List<Tool> = emptyList()

// TODO in the future this should be rather Tool builder
inline fun <reified T : ToolInput> tool(
cacheControl: CacheControl? = null,
noinline inputInitializer: T.() -> Unit = {}
Expand Down Expand Up @@ -172,7 +173,14 @@ class Anthropic internal constructor(
is MessageResponse -> response.apply {
content.filterIsInstance<ToolUse>()
.forEach { toolUse ->
toolUse.tool = toolMap[toolUse.name]!!
val tool = toolMap[toolUse.name]
if (tool != null) {
toolUse.tool = tool
} else {
// Sometimes it happens that Claude is sending non-defined tool name in tool use
// TODO in the future it should go to the stderr
println("Error!!! Unexpected tool use: ${toolUse.name}")
}
}
}
is ErrorResponse -> throw AnthropicException(
Expand Down Expand Up @@ -206,8 +214,9 @@ class Anthropic internal constructor(
}
) {
incoming
.filter { it.data != null }
.map { anthropicJson.decodeFromString<Event>(it.data!!) }
.map { it.data }
.filterNotNull()
.map { anthropicJson.decodeFromString<Event>(it) }
.collect {
emit(it)
}
Expand Down
10 changes: 6 additions & 4 deletions src/commonMain/kotlin/AnthropicJson.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package com.xemantic.anthropic

import com.xemantic.anthropic.batch.MessageBatchResponse
import com.xemantic.anthropic.content.Document
import com.xemantic.anthropic.error.ErrorResponse
import com.xemantic.anthropic.image.Image
import com.xemantic.anthropic.content.Image
import com.xemantic.anthropic.message.Content
import com.xemantic.anthropic.message.MessageResponse
import com.xemantic.anthropic.text.Text
import com.xemantic.anthropic.content.Text
import com.xemantic.anthropic.content.ToolResult
import com.xemantic.anthropic.content.ToolUse
import com.xemantic.anthropic.tool.BuiltInTool
import com.xemantic.anthropic.tool.DefaultTool
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolResult
import com.xemantic.anthropic.tool.ToolUse
import com.xemantic.anthropic.tool.bash.Bash
import com.xemantic.anthropic.tool.computer.Computer
import com.xemantic.anthropic.tool.editor.TextEditor
Expand Down Expand Up @@ -42,6 +43,7 @@ private val anthropicSerializersModule = SerializersModule {
subclass(Image::class)
subclass(ToolUse::class)
subclass(ToolResult::class)
subclass(Document::class)
}
polymorphicDefaultDeserializer(Tool::class) { ToolSerializer }
polymorphicDefaultSerializer(Tool::class) { ToolSerializer }
Expand Down
25 changes: 25 additions & 0 deletions src/commonMain/kotlin/content/ContentBuilder.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.xemantic.anthropic.content

import com.xemantic.anthropic.message.Content

interface ContentBuilder {

val content: MutableList<Content>

operator fun Content.unaryPlus() {
content += this
}

operator fun String.unaryPlus() {
content += Text(this)
}

operator fun Number.unaryPlus() {
content += Text(this.toString())
}

operator fun Collection<Content>.unaryPlus() {
content += this
}

}
42 changes: 42 additions & 0 deletions src/commonMain/kotlin/content/Document.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.xemantic.anthropic.content

import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.message.Content
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
@SerialName("document")
data class Document(
val source: Source,
@SerialName("cache_control")
override val cacheControl: CacheControl? = null
) : Content() {

enum class MediaType {
@SerialName("application/pdf")
APPLICATION_PDF
}

@Serializable
data class Source(
val type: Type = Type.BASE64,
@SerialName("media_type")
val mediaType: MediaType,
val data: String
) {

enum class Type {
@SerialName("base64")
BASE64
}

}

class Builder {
var data: ByteArray? = null
var mediaType: MediaType? = null
var cacheControl: CacheControl? = null
}

}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xemantic.anthropic.image
package com.xemantic.anthropic.content

import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.message.Content
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.xemantic.anthropic.text
package com.xemantic.anthropic.content

import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.message.Content
Expand Down
103 changes: 103 additions & 0 deletions src/commonMain/kotlin/content/Tool.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package com.xemantic.anthropic.content

import com.xemantic.anthropic.anthropicJson
import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.message.Content
import com.xemantic.anthropic.message.toNullIfEmpty
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolInput
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.Transient
import kotlinx.serialization.json.JsonObject
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract

@Serializable
@SerialName("tool_use")
data class ToolUse(
val id: String,
val name: String,
val input: JsonObject,
@SerialName("cache_control")
override val cacheControl: CacheControl? = null
) : Content() {

@Transient
@PublishedApi
internal lateinit var tool: Tool

@PublishedApi
internal fun decodeInput() = anthropicJson.decodeFromJsonElement(
deserializer = tool.inputSerializer,
element = input
).apply(tool.inputInitializer)

inline fun <reified T : ToolInput> input(): T = (decodeInput() as T)

suspend fun use(): ToolResult {
return try {
if (::tool.isInitialized) {
val toolInput = decodeInput()
toolInput.use(toolUseId = id)
} else {
ToolResult(toolUseId = id) {
error("Cannot use unknown tool: $name")
}
}
} catch (e: Exception) {
e.printStackTrace()
ToolResult(toolUseId = id) {
error(e.message ?: "Unknown error occurred")
}
}
}

}

@Serializable
@SerialName("tool_result")
data class ToolResult(
@SerialName("tool_use_id")
val toolUseId: String,
val content: List<Content>? = null,
@SerialName("is_error")
val isError: Boolean? = false,
@SerialName("cache_control")
override val cacheControl: CacheControl? = null
) : Content() {

class Builder : ContentBuilder {

override val content: MutableList<Content> = mutableListOf()

var isError: Boolean? = null
var cacheControl: CacheControl? = null

fun error(message: String) {
+message
isError = true
}

}

}

@OptIn(ExperimentalContracts::class)
inline fun ToolResult(
toolUseId: String,
block: ToolResult.Builder.() -> Unit = {}
): ToolResult {
contract {
callsInPlace(block, InvocationKind.EXACTLY_ONCE)
}
val builder = ToolResult.Builder()
block(builder)
return ToolResult(
toolUseId = toolUseId,
content = builder.content.toNullIfEmpty(),
isError = if (builder.isError == null) false else null,
cacheControl = builder.cacheControl
)
}
26 changes: 8 additions & 18 deletions src/commonMain/kotlin/message/Messages.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.xemantic.anthropic.message
import com.xemantic.anthropic.Model
import com.xemantic.anthropic.Response
import com.xemantic.anthropic.cache.CacheControl
import com.xemantic.anthropic.text.Text
import com.xemantic.anthropic.content.ContentBuilder
import com.xemantic.anthropic.tool.Tool
import com.xemantic.anthropic.tool.ToolChoice
import com.xemantic.anthropic.tool.ToolInput
Expand Down Expand Up @@ -88,7 +88,7 @@ data class MessageRequest(

/**
* Sets both, the [tools] list and the [toolChoice] with
* just one tool to use, forcing the API to respond with the [com.xemantic.anthropic.tool.ToolUse].
* just one tool to use, forcing the API to respond with the [com.xemantic.anthropic.content.ToolUse].
*/
inline fun <reified T : ToolInput> singleTool() {
val name = toolName<T>()
Expand All @@ -103,14 +103,14 @@ data class MessageRequest(
/**
* Sets both, the [tools] list and the [toolChoice] with
* just one tool to use, forcing the API to respond with the
* [com.xemantic.anthropic.tool.ToolUse] instance.
* [com.xemantic.anthropic.content.ToolUse] instance.
*/
fun chooseTool(name: String) {
val tool = requireNotNull(toolMap[name]) {
"No tool with such name defined in Anthropic client: $name"
}
tools = listOf(tool)
toolChoice = ToolChoice.Tool(name = tool.name)
toolChoice = ToolChoice.Tool(name = tool.name, disableParallelToolUse = true)
}

fun messages(vararg messages: Message) {
Expand Down Expand Up @@ -176,21 +176,11 @@ data class Message(
val content: List<Content>
) {

class Builder {
var role = Role.USER
val content = mutableListOf<Content>()

operator fun Content.unaryPlus() {
content += this
}
class Builder : ContentBuilder {

operator fun List<Content>.unaryPlus() {
content += this
}
override val content = mutableListOf<Content>()

operator fun String.unaryPlus() {
content += Text(this)
}
var role = Role.USER

fun build() = Message(
role = role,
Expand All @@ -215,7 +205,7 @@ data class System(
) {

enum class Type {
@SerialName("text")
@SerialName("content/text")
TEXT
}

Expand Down
Loading

0 comments on commit b7bed78

Please sign in to comment.