Skip to content

Commit

Permalink
Add handling for Optional types
Browse files Browse the repository at this point in the history
  • Loading branch information
Tarjei400 committed Nov 5, 2024
1 parent bd88ad9 commit 0f5e158
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ public class ToolProcessor {
Object.class);
private static final Logger log = Logger.getLogger(ToolProcessor.class);

public static final DotName OPTIONAL = DotName.createSimple("java.util.Optional");
public static final DotName OPTIONAL_INT = DotName.createSimple("java.util.OptionalInt");
public static final DotName OPTIONAL_LONG = DotName.createSimple("java.util.OptionalLong");
public static final DotName OPTIONAL_DOUBLE = DotName.createSimple("java.util.OptionalDouble");

@BuildStep
public void telemetry(Capabilities capabilities, BuildProducer<AdditionalBeanBuildItem> additionalBeanProducer) {
var addOpenTelemetrySpan = capabilities.isPresent(Capability.OPENTELEMETRY_TRACER);
Expand Down Expand Up @@ -488,11 +493,18 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
ClassInfo classInfo = index.getClassByName(type.name());

List<String> required = new ArrayList<>();

if (classInfo != null) {
for (FieldInfo field : classInfo.fields()) {
String fieldName = field.name();
Type fieldType = field.type();

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(field.type(), index, null);
boolean isOptional = isJavaOptionalType(fieldType);
if (isOptional) {
fieldType = unwrapOptionalType(fieldType);
}

Iterable<JsonSchemaProperty> fieldSchema = toJsonSchemaProperties(fieldType, index, null);
Map<String, Object> fieldDescription = new HashMap<>();

for (JsonSchemaProperty fieldProperty : fieldSchema) {
Expand All @@ -506,6 +518,10 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
fieldDescription.put("description", String.join(",", descriptionValue));
}
}
if (!isOptional) {
required.add(fieldName);
}

properties.put(fieldName, fieldDescription);
}
}
Expand All @@ -517,10 +533,39 @@ private Iterable<JsonSchemaProperty> toJsonSchemaProperties(Type type, IndexView
throw new IllegalArgumentException("Unsupported type: " + type);
}

private boolean isJavaOptionalType(Type type) {
DotName typeName = type.name();
return typeName.equals(DotName.createSimple("java.util.Optional"))
|| typeName.equals(DotName.createSimple("java.util.OptionalInt"))
|| typeName.equals(DotName.createSimple("java.util.OptionalLong"))
|| typeName.equals(DotName.createSimple("java.util.OptionalDouble"));
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType.kind() == Type.Kind.PARAMETERIZED_TYPE) {
ParameterizedType parameterizedType = optionalType.asParameterizedType();
return parameterizedType.arguments().get(0);
}
return optionalType;
}

private boolean isComplexType(Type type) {
return type.kind() == Type.Kind.CLASS || type.kind() == Type.Kind.PARAMETERIZED_TYPE;
}

private boolean isOptionalField(FieldInfo field, IndexView index) {
Type fieldType = field.type();
DotName fieldTypeName = fieldType.name();

if (OPTIONAL.equals(fieldTypeName) || OPTIONAL_INT.equals(fieldTypeName) || OPTIONAL_LONG.equals(fieldTypeName)
|| OPTIONAL_DOUBLE.equals(fieldTypeName)) {
return true;
}

return false;

}

private Iterable<JsonSchemaProperty> removeNulls(JsonSchemaProperty... properties) {
return stream(properties)
.filter(Objects::nonNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import dev.langchain4j.service.Result;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.TypeUtils;
//import dev.langchain4j.service.output.OutputParser;
import dev.langchain4j.service.output.ServiceOutputParser;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.smallrye.mutiny.Multi;
Expand All @@ -23,13 +22,17 @@ public class QuarkusServiceOutputParser extends ServiceOutputParser {

@Override
public String outputFormatInstructions(Type returnType) {
Class<?> rawClass = getRawClass(returnType);
boolean isOptional = isJavaOptional(returnType);
Type actualType = isOptional ? unwrapOptionalType(returnType) : returnType;

Class<?> rawClass = getRawClass(actualType);

if (rawClass != String.class && rawClass != AiMessage.class && rawClass != TokenStream.class
&& rawClass != Response.class && !Multi.class.equals(rawClass)) {
try {
var schema = this.toJsonSchema(returnType);
return "You must answer strictly with json according to the following json schema format: " + schema;
return "You must answer strictly with json according to the following json schema format. Use description metadata to fill data properly: "
+ schema;
} catch (Exception e) {
return "";
}
Expand Down Expand Up @@ -77,7 +80,10 @@ private String extractJsonBlock(String text) {

public String toJsonSchema(Type type) throws Exception {
Map<String, Object> schema = new HashMap<>();
Class<?> rawClass = getRawClass(type);
boolean isOptional = isJavaOptional(type);
Type actualType = isOptional ? unwrapOptionalType(type) : type;

Class<?> rawClass = getRawClass(actualType);

if (type instanceof WildcardType wildcardType) {
Type boundType = wildcardType.getUpperBounds().length > 0 ? wildcardType.getUpperBounds()[0]
Expand All @@ -104,22 +110,64 @@ public String toJsonSchema(Type type) throws Exception {
schema.put("type", "object");
Map<String, Object> properties = new HashMap<>();

List<String> required = new ArrayList<>();
for (Field field : rawClass.getDeclaredFields()) {
field.setAccessible(true);
Map<String, Object> fieldSchema = toJsonSchemaMap(field.getGenericType());
properties.put(field.getName(), fieldSchema);
if (field.isAnnotationPresent(Description.class)) {
Description description = field.getAnnotation(Description.class);
fieldSchema.put("description", description.value());
try {
field.setAccessible(true);
Type fieldType = field.getGenericType();

// Check if the field is Optional and unwrap it if necessary
boolean fieldIsOptional = isJavaOptional(fieldType);
Type fieldActualType = fieldIsOptional ? unwrapOptionalType(fieldType) : fieldType;

Map<String, Object> fieldSchema = toJsonSchemaMap(fieldActualType);
properties.put(field.getName(), fieldSchema);

if (field.isAnnotationPresent(Description.class)) {
Description description = field.getAnnotation(Description.class);
fieldSchema.put("description", String.join(",", description.value()));
}

// Only add to required if it is not Optional
if (!fieldIsOptional) {
required.add(field.getName());
} else {
fieldSchema.put("nullable", true); // Mark as nullable in the JSON schema
}

} catch (Exception e) {

}

}
schema.put("properties", properties);
if (!required.isEmpty()) {
schema.put("required", required);
}
}
if (isOptional) {
schema.put("nullable", true);
}

ObjectMapper mapper = new ObjectMapper();
return mapper.writeValueAsString(schema); // Convert the schema map to a JSON string
}

private boolean isJavaOptional(Type type) {
if (type instanceof ParameterizedType) {
Type rawType = ((ParameterizedType) type).getRawType();
return rawType == Optional.class || rawType == OptionalInt.class || rawType == OptionalLong.class
|| rawType == OptionalDouble.class;
}
return false;
}

private Type unwrapOptionalType(Type optionalType) {
if (optionalType instanceof ParameterizedType) {
return ((ParameterizedType) optionalType).getActualTypeArguments()[0];
}
return optionalType;
}

private Class<?> getRawClass(Type type) {
if (type instanceof Class<?>) {
return (Class<?>) type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

import jakarta.annotation.PreDestroy;
Expand All @@ -12,6 +13,8 @@

import org.jboss.resteasy.reactive.RestQuery;

import com.fasterxml.jackson.annotation.JsonProperty;

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
Expand All @@ -30,18 +33,24 @@ public AssistantWithToolsResource(Assistant assistant) {

public static class TestData {
@Description("Foo description for structured output")
@JsonProperty("foo")
String foo;

@Description("Foo description for structured output")
@JsonProperty("bar")
Integer bar;

@Description("Foo description for structured output")
Double baz;
@JsonProperty("baz")
Optional<Double> baz;

public TestData() {
}

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
this.baz = Optional.of(baz);
}
}

Expand All @@ -50,18 +59,11 @@ public String get(@RestQuery String message) {
return assistant.chat(message);
}

@GET
@Path("/many")
public List<TestData> getMany(@RestQuery String message) {
return assistant.chats(message);
}

@RegisterAiService(tools = Calculator.class, chatMemoryProviderSupplier = RegisterAiService.BeanChatMemoryProviderSupplier.class)
public interface Assistant {

String chat(String userMessage);

List<TestData> chats(String userMessage);
}

@Singleton
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;

import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;

import org.jboss.resteasy.reactive.RestQuery;

import com.fasterxml.jackson.annotation.JsonProperty;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
Expand All @@ -23,26 +30,40 @@ public EntityMappedResource(EntityMappedDescriber describer) {

public static class TestData {
@Description("Foo description for structured output")
@JsonProperty("foo")
String foo;

@Description("Foo description for structured output")
@JsonProperty("bar")
Integer bar;

@Description("Foo description for structured output")
Double baz;
@JsonProperty("baz")
Optional<Double> baz;

public TestData() {
}

TestData(String foo, Integer bar, Double baz) {
this.foo = foo;
this.bar = bar;
this.baz = baz;
this.baz = Optional.of(baz);
}
}

@POST
public List<String> generate(@RestQuery String message) {
var result = describer.describe(message);

return result;
public static class MirrorModelSupplier implements Supplier<ChatLanguageModel> {
@Override
public ChatLanguageModel get() {
return (messages) -> new Response<>(new AiMessage("""
[
{
"foo": "asd",
"bar": 1,
"baz": 2.0
}
]
"""));
}
}

@POST
Expand All @@ -51,15 +72,13 @@ public List<TestData> generateMapped(@RestQuery String message) {
List<TestData> inputs = new ArrayList<>();
inputs.add(new TestData(message, 100, 100.0));

return describer.describeMapped(inputs);
var test = describer.describeMapped(inputs);
return test;
}

@RegisterAiService
@RegisterAiService(chatLanguageModelSupplier = MirrorModelSupplier.class)
public interface EntityMappedDescriber {

@UserMessage("This is a describer returning a collection of strings")
List<String> describe(String url);

@UserMessage("This is a describer returning a collection of mapped entities")
List<TestData> describeMapped(List<TestData> inputs);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.acme.example.openai.aiservices;

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.*;

import java.net.URL;

Expand All @@ -18,17 +18,6 @@ public class AssistantResourceWithEntityMappingTest {
@TestHTTPResource
URL url;

@Test
public void get() {
given()
.baseUri(url.toString())
.queryParam("message", "This is a test")
.post()
.then()
.statusCode(200)
.body(containsString("MockGPT"));
}

@Test
public void getMany() {
given()
Expand All @@ -37,6 +26,9 @@ public void getMany() {
.post()
.then()
.statusCode(200)
.body(containsString("MockGPT"));
.body("$", hasSize(1)) // Ensure that the response is an array with exactly one item
.body("[0].foo", equalTo("asd")) // Check that foo is set correctly
.body("[0].bar", equalTo(1)) // Check that bar is 100
.body("[0].baz", equalTo(2.0F));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,4 @@ public void get() {
.body(containsString("MockGPT"));
}

@Test
public void getMany() {
given()
.baseUri(url.toString() + "/many")
.queryParam("message", "This is a test")
.get()
.then()
.statusCode(200)
.body(containsString("MockGPT"));
}
}

0 comments on commit 0f5e158

Please sign in to comment.