Skip to content

Commit

Permalink
Merge pull request #1106 from cescoffier/response-augmenter
Browse files Browse the repository at this point in the history
Initial implementation of the response augmenter idea.
  • Loading branch information
geoand authored Nov 26, 2024
2 parents ba5350b + 0807de4 commit b71aeaa
Show file tree
Hide file tree
Showing 26 changed files with 1,584 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import io.quarkiverse.langchain4j.ToolBox;
import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.deployment.devui.ToolProviderInfo;
import io.quarkiverse.langchain4j.deployment.items.AiServicesMethodBuildItem;
import io.quarkiverse.langchain4j.deployment.items.MethodParameterAllowedAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.MethodParameterIgnoredAnnotationsBuildItem;
import io.quarkiverse.langchain4j.deployment.items.SelectedChatModelProviderBuildItem;
Expand Down Expand Up @@ -102,7 +103,6 @@
import io.quarkus.arc.deployment.UnremovableBeanBuildItem;
import io.quarkus.arc.deployment.ValidationPhaseBuildItem;
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
import io.quarkus.deployment.GeneratedClassGizmoAdaptor;
Expand All @@ -125,7 +125,6 @@
import io.quarkus.gizmo.ResultHandle;
import io.quarkus.runtime.metrics.MetricsFactory;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class AiServicesProcessor {
Expand Down Expand Up @@ -186,7 +185,7 @@ public void nativeSupport(CombinedIndexBuildItem indexBuildItem,
}
Set<DotName> returnTypesToRegister = new HashSet<>();
for (AiServicesMethodBuildItem aiServicesMethodBuildItem : aiServicesMethodBuildItems) {
Type type = aiServicesMethodBuildItem.methodInfo.returnType();
Type type = aiServicesMethodBuildItem.getMethodInfo().returnType();
if (type.kind() == Type.Kind.PRIMITIVE) {
continue;
}
Expand Down Expand Up @@ -747,18 +746,29 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
}

@BuildStep
public void markUsedOutputGuardRailsUnremovable(List<AiServicesMethodBuildItem> methods,
public void markUsedGuardRailsUnremovable(List<AiServicesMethodBuildItem> methods,
BuildProducer<UnremovableBeanBuildItem> unremovableProducer) {
for (AiServicesMethodBuildItem method : methods) {
List<String> list = new ArrayList<>(method.getOutputGuardrails());
list.addAll(method.getInputGuardrails());
for (String cn : list) {
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(DotName.createSimple(cn)));
}
DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
if (method.methodInfo.hasAnnotation(dotName)) {
unremovableProducer.produce(
UnremovableBeanBuildItem.beanTypes(method.methodInfo.annotation(dotName).value().asClass().name()));
if (method.getMethodInfo().hasAnnotation(DotNames.OUTPUT_GUARDRAIL_ACCUMULATOR)) {
DotName name = method.getMethodInfo().annotation(DotNames.OUTPUT_GUARDRAIL_ACCUMULATOR)
.value().asClass().name();
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(name));
}
}
}

@BuildStep
public void markUsedResponseAugmenterUnremovable(List<AiServicesMethodBuildItem> methods,
BuildProducer<UnremovableBeanBuildItem> unremovableProducer) {
for (AiServicesMethodBuildItem method : methods) {
var cn = method.getResponseAugmenter();
if (cn != null) {
unremovableProducer.produce(UnremovableBeanBuildItem.beanTypes(DotName.createSimple(cn)));
}
}
}
Expand Down Expand Up @@ -830,29 +840,31 @@ public void validateGuardrails(SynthesisFinishedBuildItem synthesisFinished,
}

DotName dotName = DotName.createSimple(OutputGuardrailAccumulator.class);
if (method.methodInfo.hasAnnotation(dotName)) {
if (method.getMethodInfo().hasAnnotation(dotName)) {
// We have an accumulator
// Check that the accumulator exists
var bean = method.methodInfo.annotation(dotName).value().asClass().name();
var bean = method.getMethodInfo().annotation(dotName).value().asClass().name();
if (synthesisFinished.beanStream().withBeanType(bean).isEmpty()) {
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new DeploymentException("Missing accumulator bean: " + bean.toString())));
}

// Check that the accumulator is used on a method retuning a Multi
DotName returnedType = method.methodInfo.returnType().name();
DotName returnedType = method.getMethodInfo().returnType().name();
if (!DotName.createSimple(Multi.class).equals(returnedType)) {
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new DeploymentException("OutputGuardrailAccumulator can only be used on method returning a " +
"`Multi<X>`: found `%s` for method `%s.%s`".formatted(returnedType,
method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
method.getMethodInfo().declaringClass().toString(),
method.getMethodInfo().name()))));
}

// Check that the method have output guardrails
if (method.outputGuardrails.isEmpty()) {
if (method.getOutputGuardrails().isEmpty()) {
errors.produce(new ValidationPhaseBuildItem.ValidationErrorBuildItem(
new DeploymentException("OutputGuardrailAccumulator used without OutputGuardrails in method `%s.%s`"
.formatted(method.methodInfo.declaringClass().toString(), method.methodInfo.name()))));
.formatted(method.getMethodInfo().declaringClass().toString(),
method.getMethodInfo().name()))));
}
}
}
Expand Down Expand Up @@ -1125,7 +1137,7 @@ public void handleAiServices(
aiServicesMethodProducer.produce(new AiServicesMethodBuildItem(methodInfo,
methodCreateInfo.getInputGuardrailsClassNames(),
methodCreateInfo.getOutputGuardrailsClassNames(),
gatherMethodToolClassNames(methodInfo),
methodCreateInfo.getResponseAugmenterClassName(),
methodCreateInfo));
}
}
Expand Down Expand Up @@ -1259,14 +1271,16 @@ private AiServiceMethodCreateInfo gatherMethodMetadata(

String accumulatorClassName = AiServicesMethodBuildItem.gatherAccumulator(method);

String responseAugmenterClassName = AiServicesMethodBuildItem.gatherResponseAugmenter(method);

// Detect if tools execution may block the caller thread.
boolean switchToWorkerThread = detectIfToolExecutionRequiresAWorkerThread(method, tools, methodToolClassNames);

return new AiServiceMethodCreateInfo(method.declaringClass().name().toString(), method.name(), systemMessageInfo,
userMessageInfo, memoryIdParamPosition, requiresModeration,
returnTypeSignature(method.returnType(), new TypeArgMapper(method.declaringClass(), index)),
metricsTimedInfo, metricsCountedInfo, spanInfo, responseSchemaInfo, methodToolClassNames, switchToWorkerThread,
inputGuardrails, outputGuardrails, accumulatorClassName);
inputGuardrails, outputGuardrails, accumulatorClassName, responseAugmenterClassName);
}

private boolean detectIfToolExecutionRequiresAWorkerThread(MethodInfo method, List<ToolMethodBuildItem> tools,
Expand Down Expand Up @@ -1759,70 +1773,4 @@ static Map<String, Integer> toNameToArgsPositionMap(List<TemplateParameterInfo>
}
}

public static final class AiServicesMethodBuildItem extends MultiBuildItem {

private final MethodInfo methodInfo;
private final List<String> outputGuardrails;
private final List<String> inputGuardrails;
private final List<String> tools;
private final AiServiceMethodCreateInfo methodCreateInfo;

public AiServicesMethodBuildItem(MethodInfo methodInfo, List<String> inputGuardrails, List<String> outputGuardrails,
List<String> tools,
AiServiceMethodCreateInfo methodCreateInfo) {
this.methodInfo = methodInfo;
this.inputGuardrails = inputGuardrails;
this.outputGuardrails = outputGuardrails;
this.tools = tools;
this.methodCreateInfo = methodCreateInfo;
}

public List<String> getOutputGuardrails() {
return outputGuardrails;
}

public List<String> getInputGuardrails() {
return inputGuardrails;
}

public MethodInfo getMethodInfo() {
return methodInfo;
}

public AiServiceMethodCreateInfo getMethodCreateInfo() {
return methodCreateInfo;
}

public static List<String> gatherGuardrails(MethodInfo methodInfo, DotName annotation) {
List<String> guardrails = new ArrayList<>();
AnnotationInstance instance = methodInfo.annotation(annotation);
if (instance == null) {
// Check on class
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
}
if (instance != null) {
Type[] array = instance.value().asClassArray();
for (Type type : array) {
// Make sure each guardrail is used only once
if (!guardrails.contains(type.name().toString())) {
guardrails.add(type.name().toString());
}
}
}
return guardrails;
}

public static String gatherAccumulator(MethodInfo methodInfo) {
DotName annotation = DotName.createSimple(OutputGuardrailAccumulator.class);
AnnotationInstance instance = methodInfo.annotation(annotation);
if (instance == null) {
// Check on class
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
}
if (instance != null) {
return instance.value().asClass().name().toString();
}
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import io.quarkiverse.langchain4j.auth.ModelAuthProvider;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.response.AiResponseAugmenter;
import io.quarkiverse.langchain4j.response.ResponseAugmenter;
import io.smallrye.common.annotation.Blocking;
import io.smallrye.common.annotation.NonBlocking;
import io.smallrye.common.annotation.RunOnVirtualThread;
Expand Down Expand Up @@ -58,4 +61,16 @@ public class DotNames {
public static final DotName CHAT_MODEL_LISTENER = DotName.createSimple(ChatModelListener.class);
public static final DotName MODEL_AUTH_PROVIDER = DotName.createSimple(ModelAuthProvider.class);
public static final DotName TOOL = DotName.createSimple(Tool.class);

public static final DotName OUTPUT_GUARDRAIL_ACCUMULATOR = DotName.createSimple(OutputGuardrailAccumulator.class);

/**
* The {@link AiResponseAugmenter} interface.
*/
public static final DotName AI_RESPONSE_AUGMENTER = DotName.createSimple(AiResponseAugmenter.class);

/**
* The {@link ResponseAugmenter} annotation.
*/
public static final DotName RESPONSE_AUGMENTER_ANNOTATION = DotName.createSimple(ResponseAugmenter.class);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package io.quarkiverse.langchain4j.deployment.items;

import java.util.ArrayList;
import java.util.List;

import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.DotName;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.Type;

import io.quarkiverse.langchain4j.guardrails.OutputGuardrailAccumulator;
import io.quarkiverse.langchain4j.response.ResponseAugmenter;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkus.builder.item.MultiBuildItem;

/**
* A build item representing a method from an AI service.
*/
public final class AiServicesMethodBuildItem extends MultiBuildItem {

private final MethodInfo methodInfo;
private final List<String> outputGuardrails;
private final List<String> inputGuardrails;
private final AiServiceMethodCreateInfo methodCreateInfo;
private final String responseAugmenter;

public AiServicesMethodBuildItem(MethodInfo methodInfo, List<String> inputGuardrails, List<String> outputGuardrails,
String responseAugmenter,
AiServiceMethodCreateInfo methodCreateInfo) {
this.methodInfo = methodInfo;
this.inputGuardrails = inputGuardrails;
this.outputGuardrails = outputGuardrails;
this.responseAugmenter = responseAugmenter;
this.methodCreateInfo = methodCreateInfo;
}

public List<String> getOutputGuardrails() {
return outputGuardrails;
}

public List<String> getInputGuardrails() {
return inputGuardrails;
}

public MethodInfo getMethodInfo() {
return methodInfo;
}

public AiServiceMethodCreateInfo getMethodCreateInfo() {
return methodCreateInfo;
}

public String getResponseAugmenter() {
return responseAugmenter;
}

public static List<String> gatherGuardrails(MethodInfo methodInfo, DotName annotation) {
List<String> guardrails = new ArrayList<>();
AnnotationInstance instance = methodInfo.annotation(annotation);
if (instance == null) {
// Check on class
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
}
if (instance != null) {
Type[] array = instance.value().asClassArray();
for (Type type : array) {
// Make sure each guardrail is used only once
if (!guardrails.contains(type.name().toString())) {
guardrails.add(type.name().toString());
}
}
}
return guardrails;
}

public static String gatherAccumulator(MethodInfo methodInfo) {
DotName annotation = DotName.createSimple(OutputGuardrailAccumulator.class);
AnnotationInstance instance = methodInfo.annotation(annotation);
if (instance == null) {
// Check on class
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
}
if (instance != null) {
return instance.value().asClass().name().toString();
}
return null;
}

public static String gatherResponseAugmenter(MethodInfo methodInfo) {
DotName annotation = DotName.createSimple(ResponseAugmenter.class);
AnnotationInstance instance = methodInfo.annotation(annotation);
if (instance == null) {
// Check on class
instance = methodInfo.declaringClass().declaredAnnotation(annotation);
}
if (instance != null) {
return instance.value().asClass().name().toString();
}
return null;
}
}
Loading

0 comments on commit b71aeaa

Please sign in to comment.