Skip to content

Commit

Permalink
Add handling for Optional types
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarjei400 committed Nov 5, 2024
1 parent bd88ad9 commit 24a45f0
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ public class ToolProcessor {
Object.class);
private static final Logger log = Logger.getLogger(ToolProcessor.class);

public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional");
public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt");
public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong");
public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble");

private static final DotName DATE = DotName.createSimple("java.util.Date");
private static final DotName LOCAL_DATE = DotName.createSimple("java.time.LocalDate");
private static final DotName LOCAL_DATE_TIME = DotName.createSimple("java.time.LocalDateTime");
private static final DotName OFFSET_DATE_TIME = DotName.createSimple("java.time.OffsetDateTime");

@BuildStep
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
Expand Down Expand Up @@ -453,7 +463,15 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
|| DotNames.BIG_DECIMAL.equals(typeName)) {
return removeNulls(NUMBER, description);
}
if (LOCAL_DATE_TIME.equals(typeName) || OFFSET_DATE_TIME.equals(typeName)) {
return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date-time"),
description);
}

if (DATE.equals(typeName) || LOCAL_DATE.equals(typeName)) {
return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date"),
description);
}
// TODO something else?
if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) {
ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType()
Expand Down Expand Up @@ -488,11 +506,18 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
ClassInfo classInfo = index.getClassByName(type.name());

List<String> required = new ArrayList<>();

if (classInfo != null) {
for (FieldInfo field : classInfo.fields()) {
String fieldName = field.name();
Type fieldType = field.type();

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
boolean isOptional = isJavaOptionalType(fieldType);
if (isOptional) {
fieldType = unwrapOptionalType(fieldType);
}

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(fieldType, index, null);
Map<String, Object> fieldDescription = new HashMap<>();

for (JsonSchemaProperty fieldProperty : fieldSchema) {
Expand All @@ -506,6 +531,10 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
fieldDescription.put("description", String.join(",", descriptionValue));
}
}
if (!isOptional) {
required.add(fieldName);
}

properties.put(fieldName, fieldDescription);
}
}
Expand All @@ -517,10 +546,39 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
throw new IllegalArgumentException("Unsupported type: " + type);
}

private boolean isJavaOptionalType(Type type) {
DotName typeName = type.name();
return typeName.equals(DotName.createSimple("java.util.Optional"))
|| typeName.equals(DotName.createSimple("java.util.OptionalInt"))
|| typeName.equals(DotName.createSimple("java.util.OptionalLong"))
|| typeName.equals(DotName.createSimple("java.util.OptionalDouble"));
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) {
ParameterizedType parameterizedType = optionalType.asParameterizedType();
return parameterizedType.arguments().get(0);
}
return optionalType;
}

private boolean isComplexType(Type type) {
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
}

private boolean isOptionalField(FieldInfo field, IndexView index) {
Type fieldType = field.type();
DotName fieldTypeName = fieldType.name();

if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName)
|| OPTIONAL_DOUBLE.equals(fieldTypeName)) {
return true;
}

return false;

}

private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
return stream(properties)
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,13 @@ public <T> T fromJson(String json, Type type) {
if (e instanceof JsonParseException && isEnumType(type)) {
// this is the case where LangChain4j simply passes the string value of the enum to Json.fromJson()
// and Jackson does not handle it
Class<? extends Enum> enumClass = (Class<? extends Enum>) ((ParameterizedType) type).getRawType();
return (T) Enum.valueOf(enumClass, json);
if (type instanceof ParameterizedType) {
Class<? extends Enum> enumClass = (Class<? extends Enum>) ((ParameterizedType) type).getRawType();
return (T) Enum.valueOf(enumClass, json);
} else {

return (T) Enum.valueOf((Class<? extends Enum>) type, json);
}
}
throw new UncheckedIOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
package io.quarkiverse.langchain4j.runtime;

import java.lang.reflect.*;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.fasterxml.jackson.databind.ObjectMapper;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
//import dev.langchain4j.service.output.OutputParser;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.smallrye.mutiny.Multi;
Expand All @@ -23,13 +26,18 @@ public class QuarkusServiceOutputParser extends ServiceOutputParser {

@Override
public String outputFormatInstructions(Type returnType) {
Class<?> rawClass = getRawClass(returnType);
boolean isOptional = isJavaOptional(returnType);
Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType;

Class<?> rawClass = getRawClass(actualType);

if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class
&& rawClass != ChatMessage.class
&& rawClass != Response.class && !Multi.class.equals(rawClass)) {
try {
var schema = this.toJsonSchema(returnType);
return "You must answer strictly with json according to the following json schema format: " + schema;
return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: "
+ schema;
} catch (Exception e) {
return "";
}
Expand All @@ -52,7 +60,7 @@ public Object parse(Response<AiMessage> response, Type returnType) {
return response;
} else {
AiMessage aiMessage = response.content();
if (rawReturnClass == AiMessage.class) {
if (rawReturnClass == AiMessage.class || rawReturnClass == ChatMessage.class) {
return aiMessage;
} else {
String text = aiMessage.text();
Expand All @@ -77,7 +85,10 @@ private String extractJsonBlock(String text) {

public String toJsonSchema(Type type) throws Exception {
Map<String, Object> schema = new HashMap<>();
Class<?> rawClass = getRawClass(type);
boolean isOptional = isJavaOptional(type);
Type actualType = isOptional ? unwrapOptionalType(type) : type;

Class<?> rawClass = getRawClass(actualType);

if (type instanceof WildcardType wildcardType) {
Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0]
Expand All @@ -97,29 +108,77 @@ public String toJsonSchema(Type type) throws Exception {
Type elementType = getElementType(type);
Map<String, Object> itemsSchema = toJsonSchemaMap(elementType);
schema.put("items", itemsSchema);
} else if (rawClass == LocalDate.class || rawClass == Date.class) {
schema.put("type", "string");
schema.put("format", "date");
} else if (rawClass == LocalDateTime.class || rawClass == OffsetDateTime.class) {
schema.put("type", "string");
schema.put("format", "date-time");
} else if (rawClass.isEnum()) {
schema.put("type", "string");
schema.put("enum", getEnumConstants(rawClass));
} else {
schema.put("type", "object");
Map<String, Object> properties = new HashMap<>();

List<String> required = new ArrayList<>();
for (Field field : rawClass.getDeclaredFields()) {
field.setAccessible(true);
Map<String, Object> fieldSchema = toJsonSchemaMap(field.getGenericType());
properties.put(field.getName(), fieldSchema);
if (field.isAnnotationPresent(Description.class)) {
Description description = field.getAnnotation(Description.class);
fieldSchema.put("description", description.value());
try {
field.setAccessible(true);
Type fieldType = field.getGenericType();

// Check if the field is Optional and unwrap it if necessary
boolean fieldIsOptional = isJavaOptional(fieldType);
Type fieldActualType = fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType;

Map<String, Object> fieldSchema = toJsonSchemaMap(fieldActualType);
properties.put(field.getName(), fieldSchema);

if (field.isAnnotationPresent(Description.class)) {
Description description = field.getAnnotation(Description.class);
fieldSchema.put("description", String.join(",", description.value()));
}

// Only add to required if it is not Optional
if (!fieldIsOptional) {
required.add(field.getName());
} else {
fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema
}

} catch (Exception e) {

}

}
schema.put("properties", properties);
if (!required.isEmpty()) {
schema.put("required", required);
}
}
if (isOptional) {
schema.put("nullable", true);
}

ObjectMapper mapper = new ObjectMapper();
return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string
}

private boolean isJavaOptional(Type type) {
if (type instanceof ParameterizedType) {
Type rawType = ((ParameterizedType) type).getRawType();
return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class
|| rawType == OptionalDouble.class;
}
return false;
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType instanceof ParameterizedType) {
return ((ParameterizedType) optionalType).getActualTypeArguments()[0];
}
return optionalType;
}

private Class<?> getRawClass(Type type) {
if (type instanceof Class<?>) {
return (Class<?>) type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.annotation.PreDestroy;
Expand All @@ -12,6 +13,8 @@

import org.jboss.resteasy.reactive.RestQuery;

import com.fasterxml.jackson.annotation.JsonProperty;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -30,18 +33,24 @@ public AssistantWithToolsResource(Assistant assistant) {

public static class TestData {
@Description("Foo description for structured output")
@JsonProperty("foo")
String foo;

@Description("Foo description for structured output")
@JsonProperty("bar")
Integer bar;

@Description("Foo description for structured output")
Double baz;
@JsonProperty("baz")
Optional<Double> baz;

public TestData() {
}

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
this.baz = Optional.of(baz);
}
}

Expand All @@ -50,18 +59,11 @@ public String get(@RestQuery String message) {
return assistant.chat(message);
}

@GET
@Path("/many")
public List<TestData> getMany(@RestQuery String message) {
return assistant.chats(message);
}

@RegisterAiService(tools = Calculator.class, chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class)
public interface Assistant {

String chat(String userMessage);

List<TestData> chats(String userMessage);
}

@Singleton
Expand Down
Loading

0 comments on commit 24a45f0

Please sign in to comment.