From 24a45f0d601e21638a247657f6c749df75bbde55 Mon Sep 17 00:00:00 2001 From: Adrian Jutrowski Date: Mon, 4 Nov 2024 23:22:29 +0100 Subject: [PATCH] Add handling for Optional types --- .../langchain4j/deployment/ToolProcessor.java | 60 +++++++++++++- .../langchain4j/QuarkusJsonCodecFactory.java | 9 +- .../runtime/QuarkusServiceOutputParser.java | 83 ++++++++++++++++--- .../AssistantWithToolsResource.java | 20 +++-- .../aiservices/EntityMappedResource.java | 43 +++++++--- ...ssistantResourceWithEntityMappingTest.java | 18 ++-- .../AssistantResourceWithToolsTest.java | 10 --- 7 files changed, 184 insertions(+), 59 deletions(-) diff --git a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java index 4fcd861e5..36b5782a5 100644 --- a/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java +++ b/core/deployment/src/main/java/io/quarkiverse/langchain4j/deployment/ToolProcessor.java @@ -83,6 +83,16 @@ public class ToolProcessor { Object.class); private static final Logger log = Logger.getLogger(ToolProcessor.class); + public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional"); + public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt"); + public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong"); + public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble"); + + private static final DotName DATE = DotName.createSimple("java.util.Date"); + private static final DotName LOCAL_DATE = DotName.createSimple("java.time.LocalDate"); + private static final DotName LOCAL_DATE_TIME = DotName.createSimple("java.time.LocalDateTime"); + private static final DotName OFFSET_DATE_TIME = DotName.createSimple("java.time.OffsetDateTime"); + @BuildStep public void telemetry(Capabilities capabilities, BuildProducer additionalBeanProducer) { var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER); @@ -453,7 +463,15 @@ private Iterable toJsonSchemaProperties(Type type, IndexView || DotNames.BIG_DECIMAL.equals(typeName)) { return removeNulls(NUMBER, description); } + if (LOCAL_DATE_TIME.equals(typeName) || OFFSET_DATE_TIME.equals(typeName)) { + return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date-time"), + description); + } + if (DATE.equals(typeName) || LOCAL_DATE.equals(typeName)) { + return removeNulls(JsonSchemaProperty.from("type", "string"), JsonSchemaProperty.from("format", "date"), + description); + } // TODO something else? if (type.kind() == Type.Kind.ARRAY || DotNames.LIST.equals(typeName) || DotNames.SET.equals(typeName)) { ParameterizedType parameterizedType = type.kind() == Type.Kind.PARAMETERIZED_TYPE ? type.asParameterizedType() @@ -488,11 +506,18 @@ private Iterable toJsonSchemaProperties(Type type, IndexView ClassInfo classInfo = index.getClassByName(type.name()); List required = new ArrayList<>(); + if (classInfo != null) { for (FieldInfo field : classInfo.fields()) { String fieldName = field.name(); + Type fieldType = field.type(); - Iterable fieldSchema = toJsonSchemaProperties(field.type(), index, null); + boolean isOptional = isJavaOptionalType(fieldType); + if (isOptional) { + fieldType = unwrapOptionalType(fieldType); + } + + Iterable fieldSchema = toJsonSchemaProperties(fieldType, index, null); Map fieldDescription = new HashMap<>(); for (JsonSchemaProperty fieldProperty : fieldSchema) { @@ -506,6 +531,10 @@ private Iterable toJsonSchemaProperties(Type type, IndexView fieldDescription.put("description", String.join(",", descriptionValue)); } } + if (!isOptional) { + required.add(fieldName); + } + properties.put(fieldName, fieldDescription); } } @@ -517,10 +546,39 @@ private Iterable toJsonSchemaProperties(Type type, IndexView throw new IllegalArgumentException("Unsupported type: " + type); } + private boolean isJavaOptionalType(Type type) { + DotName typeName = type.name(); + return typeName.equals(DotName.createSimple("java.util.Optional")) + || typeName.equals(DotName.createSimple("java.util.OptionalInt")) + || typeName.equals(DotName.createSimple("java.util.OptionalLong")) + || typeName.equals(DotName.createSimple("java.util.OptionalDouble")); + } + + private Type unwrapOptionalType(Type optionalType) { + if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) { + ParameterizedType parameterizedType = optionalType.asParameterizedType(); + return parameterizedType.arguments().get(0); + } + return optionalType; + } + private boolean isComplexType(Type type) { return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE; } + private boolean isOptionalField(FieldInfo field, IndexView index) { + Type fieldType = field.type(); + DotName fieldTypeName = fieldType.name(); + + if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName) + || OPTIONAL_DOUBLE.equals(fieldTypeName)) { + return true; + } + + return false; + + } + private Iterable removeNulls(JsonSchemaProperty... properties) { return stream(properties) .filter(Objects::nonNull) diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java index f277da900..4914a3adc 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/QuarkusJsonCodecFactory.java @@ -72,8 +72,13 @@ public T fromJson(String json, Type type) { if (e instanceof JsonParseException && isEnumType(type)) { // this is the case where LangChain4j simply passes the string value of the enum to Json.fromJson() // and Jackson does not handle it - Class enumClass = (Class) ((ParameterizedType) type).getRawType(); - return (T) Enum.valueOf(enumClass, json); + if (type instanceof ParameterizedType) { + Class enumClass = (Class) ((ParameterizedType) type).getRawType(); + return (T) Enum.valueOf(enumClass, json); + } else { + + return (T) Enum.valueOf((Class) type, json); + } } throw new UncheckedIOException(e); } diff --git a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java index d801450e8..1c4f5ccd6 100644 --- a/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java +++ b/core/runtime/src/main/java/io/quarkiverse/langchain4j/runtime/QuarkusServiceOutputParser.java @@ -1,6 +1,9 @@ package io.quarkiverse.langchain4j.runtime; import java.lang.reflect.*; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -8,12 +11,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.structured.Description; import dev.langchain4j.service.Result; import dev.langchain4j.service.TokenStream; import dev.langchain4j.service.TypeUtils; -//import dev.langchain4j.service.output.OutputParser; import dev.langchain4j.service.output.ServiceOutputParser; import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory; import io.smallrye.mutiny.Multi; @@ -23,13 +26,18 @@ public class QuarkusServiceOutputParser extends ServiceOutputParser { @Override public String outputFormatInstructions(Type returnType) { - Class rawClass = getRawClass(returnType); + boolean isOptional = isJavaOptional(returnType); + Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType; + + Class rawClass = getRawClass(actualType); if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class + && rawClass != ChatMessage.class && rawClass != Response.class && !Multi.class.equals(rawClass)) { try { var schema = this.toJsonSchema(returnType); - return "You must answer strictly with json according to the following json schema format: " + schema; + return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: " + + schema; } catch (Exception e) { return ""; } @@ -52,7 +60,7 @@ public Object parse(Response response, Type returnType) { return response; } else { AiMessage aiMessage = response.content(); - if (rawReturnClass == AiMessage.class) { + if (rawReturnClass == AiMessage.class || rawReturnClass == ChatMessage.class) { return aiMessage; } else { String text = aiMessage.text(); @@ -77,7 +85,10 @@ private String extractJsonBlock(String text) { public String toJsonSchema(Type type) throws Exception { Map schema = new HashMap<>(); - Class rawClass = getRawClass(type); + boolean isOptional = isJavaOptional(type); + Type actualType = isOptional ? unwrapOptionalType(type) : type; + + Class rawClass = getRawClass(actualType); if (type instanceof WildcardType wildcardType) { Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0] @@ -97,6 +108,12 @@ public String toJsonSchema(Type type) throws Exception { Type elementType = getElementType(type); Map itemsSchema = toJsonSchemaMap(elementType); schema.put("items", itemsSchema); + } else if (rawClass == LocalDate.class || rawClass == Date.class) { + schema.put("type", "string"); + schema.put("format", "date"); + } else if (rawClass == LocalDateTime.class || rawClass == OffsetDateTime.class) { + schema.put("type", "string"); + schema.put("format", "date-time"); } else if (rawClass.isEnum()) { schema.put("type", "string"); schema.put("enum", getEnumConstants(rawClass)); @@ -104,22 +121,64 @@ public String toJsonSchema(Type type) throws Exception { schema.put("type", "object"); Map properties = new HashMap<>(); + List required = new ArrayList<>(); for (Field field : rawClass.getDeclaredFields()) { - field.setAccessible(true); - Map fieldSchema = toJsonSchemaMap(field.getGenericType()); - properties.put(field.getName(), fieldSchema); - if (field.isAnnotationPresent(Description.class)) { - Description description = field.getAnnotation(Description.class); - fieldSchema.put("description", description.value()); + try { + field.setAccessible(true); + Type fieldType = field.getGenericType(); + + // Check if the field is Optional and unwrap it if necessary + boolean fieldIsOptional = isJavaOptional(fieldType); + Type fieldActualType = fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType; + + Map fieldSchema = toJsonSchemaMap(fieldActualType); + properties.put(field.getName(), fieldSchema); + + if (field.isAnnotationPresent(Description.class)) { + Description description = field.getAnnotation(Description.class); + fieldSchema.put("description", String.join(",", description.value())); + } + + // Only add to required if it is not Optional + if (!fieldIsOptional) { + required.add(field.getName()); + } else { + fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema + } + + } catch (Exception e) { + } + } schema.put("properties", properties); + if (!required.isEmpty()) { + schema.put("required", required); + } + } + if (isOptional) { + schema.put("nullable", true); } - ObjectMapper mapper = new ObjectMapper(); return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string } + private boolean isJavaOptional(Type type) { + if (type instanceof ParameterizedType) { + Type rawType = ((ParameterizedType) type).getRawType(); + return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class + || rawType == OptionalDouble.class; + } + return false; + } + + private Type unwrapOptionalType(Type optionalType) { + if (optionalType instanceof ParameterizedType) { + return ((ParameterizedType) optionalType).getActualTypeArguments()[0]; + } + return optionalType; + } + private Class getRawClass(Type type) { if (type instanceof Class) { return (Class) type; diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java index 17b3f1602..522c486a5 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/AssistantWithToolsResource.java @@ -2,6 +2,7 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import jakarta.annotation.PreDestroy; @@ -12,6 +13,8 @@ import org.jboss.resteasy.reactive.RestQuery; +import com.fasterxml.jackson.annotation.JsonProperty; + import dev.langchain4j.agent.tool.Tool; import dev.langchain4j.memory.ChatMemory; import dev.langchain4j.memory.chat.ChatMemoryProvider; @@ -30,18 +33,24 @@ public AssistantWithToolsResource(Assistant assistant) { public static class TestData { @Description("Foo description for structured output") + @JsonProperty("foo") String foo; @Description("Foo description for structured output") + @JsonProperty("bar") Integer bar; @Description("Foo description for structured output") - Double baz; + @JsonProperty("baz") + Optional baz; + + public TestData() { + } TestData(String foo, Integer bar, Double baz) { this.foo = foo; this.bar = bar; - this.baz = baz; + this.baz = Optional.of(baz); } } @@ -50,18 +59,11 @@ public String get(@RestQuery String message) { return assistant.chat(message); } - @GET - @Path("/many") - public List getMany(@RestQuery String message) { - return assistant.chats(message); - } - @RegisterAiService(tools = Calculator.class, chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class) public interface Assistant { String chat(String userMessage); - List chats(String userMessage); } @Singleton diff --git a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java index 84e2aff77..d798fbbfd 100644 --- a/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java +++ b/integration-tests/openai/src/main/java/org/acme/example/openai/aiservices/EntityMappedResource.java @@ -2,12 +2,19 @@ import java.util.ArrayList; import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import jakarta.ws.rs.POST; import jakarta.ws.rs.Path; import org.jboss.resteasy.reactive.RestQuery; +import com.fasterxml.jackson.annotation.JsonProperty; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.structured.Description; import dev.langchain4j.service.UserMessage; import io.quarkiverse.langchain4j.RegisterAiService; @@ -23,26 +30,40 @@ public EntityMappedResource(EntityMappedDescriber describer) { public static class TestData { @Description("Foo description for structured output") + @JsonProperty("foo") String foo; @Description("Foo description for structured output") + @JsonProperty("bar") Integer bar; @Description("Foo description for structured output") - Double baz; + @JsonProperty("baz") + Optional baz; + + public TestData() { + } TestData(String foo, Integer bar, Double baz) { this.foo = foo; this.bar = bar; - this.baz = baz; + this.baz = Optional.of(baz); } } - @POST - public List generate(@RestQuery String message) { - var result = describer.describe(message); - - return result; + public static class MirrorModelSupplier implements Supplier { + @Override + public ChatLanguageModel get() { + return (messages) -> new Response<>(new AiMessage(""" + [ + { + "foo": "asd", + "bar": 1, + "baz": 2.0 + } + ] + """)); + } } @POST @@ -51,15 +72,13 @@ public List generateMapped(@RestQuery String message) { List inputs = new ArrayList<>(); inputs.add(new TestData(message, 100, 100.0)); - return describer.describeMapped(inputs); + var test = describer.describeMapped(inputs); + return test; } - @RegisterAiService + @RegisterAiService(chatLanguageModelSupplier = MirrorModelSupplier.class) public interface EntityMappedDescriber { - @UserMessage("This is a describer returning a collection of strings") - List describe(String url); - @UserMessage("This is a describer returning a collection of mapped entities") List describeMapped(List inputs); } diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java index 46612942b..5058b935d 100644 --- a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithEntityMappingTest.java @@ -1,7 +1,7 @@ package org.acme.example.openai.aiservices; import static io.restassured.RestAssured.given; -import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.*; import java.net.URL; @@ -18,17 +18,6 @@ public class AssistantResourceWithEntityMappingTest { @TestHTTPResource URL url; - @Test - public void get() { - given() - .baseUri(url.toString()) - .queryParam("message", "This is a test") - .post() - .then() - .statusCode(200) - .body(containsString("MockGPT")); - } - @Test public void getMany() { given() @@ -37,6 +26,9 @@ public void getMany() { .post() .then() .statusCode(200) - .body(containsString("MockGPT")); + .body("$", hasSize(1)) // Ensure that the response is an array with exactly one item + .body("[0].foo", equalTo("asd")) // Check that foo is set correctly + .body("[0].bar", equalTo(1)) // Check that bar is 100 + .body("[0].baz", equalTo(2.0F)); } } diff --git a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java index ce5331bb4..524df4e25 100644 --- a/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java +++ b/integration-tests/openai/src/test/java/org/acme/example/openai/aiservices/AssistantResourceWithToolsTest.java @@ -29,14 +29,4 @@ public void get() { .body(containsString("MockGPT")); } - @Test - public void getMany() { - given() - .baseUri(url.toString() + "/many") - .queryParam("message", "This is a test") - .get() - .then() - .statusCode(200) - .body(containsString("MockGPT")); - } }