Skip to content

Commit

Permalink
Merge pull request #1110 from aldettinger/QUARKUS-LANGCHAIN4J-1087
Browse files Browse the repository at this point in the history
Enable resolution of AI services by bean name
  • Loading branch information
geoand authored Nov 26, 2024
2 parents b71aeaa + 6fe51bc commit aad4aff
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.FORCE_ALLOW;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.IGNORE;
import static io.quarkiverse.langchain4j.deployment.MethodParameterAsTemplateVariableAllowance.OPTIONAL_DENY;
import static io.quarkus.arc.processor.DotNames.NAMED;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -380,6 +381,12 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,

String imageModelName = chatModelName; // TODO: should we have a separate setting for this?

AnnotationInstance namedAnno = declarativeAiServiceClassInfo.annotation(NAMED);
Optional<String> beanName = Optional.empty();
if (namedAnno != null) {
beanName = Optional.ofNullable(namedAnno.value().asString());
}

declarativeAiServiceProducer.produce(
new DeclarativeAiServiceBuildItem(
declarativeAiServiceClassInfo,
Expand All @@ -398,7 +405,8 @@ public void findDeclarativeServices(CombinedIndexBuildItem indexBuildItem,
chatModelName,
moderationModelName,
imageModelName,
toolProviderClassName));
toolProviderClassName,
beanName));
}
toolProviderProducer.produce(new ToolProviderMetaBuildItem(toolProviderInfos));

Expand Down Expand Up @@ -705,6 +713,8 @@ public void handleDeclarativeServices(AiServicesRecorder recorder,
allToolProviders.add(toolProvider);
}

bi.getBeanName().ifPresent(beanName -> configurator.named(beanName));

configurator
.addInjectionPoint(ParameterizedType.create(DotNames.CDI_INSTANCE,
new Type[] { ClassType.create(OutputGuardrail.class) }, null))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.quarkiverse.langchain4j.deployment;

import java.util.List;
import java.util.Optional;

import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
Expand Down Expand Up @@ -30,6 +31,7 @@ public final class DeclarativeAiServiceBuildItem extends MultiBuildItem {
private final String chatModelName;
private final String moderationModelName;
private final String imageModelName;
private final Optional<String> beanName;

public DeclarativeAiServiceBuildItem(
ClassInfo serviceClassInfo,
Expand All @@ -48,7 +50,8 @@ public DeclarativeAiServiceBuildItem(
String chatModelName,
String moderationModelName,
String imageModelName,
DotName toolProviderClassDotName) {
DotName toolProviderClassDotName,
Optional<String> beanName) {
this.serviceClassInfo = serviceClassInfo;
this.chatLanguageModelSupplierClassDotName = chatLanguageModelSupplierClassDotName;
this.streamingChatLanguageModelSupplierClassDotName = streamingChatLanguageModelSupplierClassDotName;
Expand All @@ -66,6 +69,7 @@ public DeclarativeAiServiceBuildItem(
this.moderationModelName = moderationModelName;
this.imageModelName = imageModelName;
this.toolProviderClassDotName = toolProviderClassDotName;
this.beanName = beanName;
}

public ClassInfo getServiceClassInfo() {
Expand Down Expand Up @@ -135,4 +139,8 @@ public String getImageModelName() {
public DotName getToolProviderClassDotName() {
return toolProviderClassDotName;
}

public Optional<String> getBeanName() {
return beanName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.quarkiverse.langchain4j.test;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.List;

import jakarta.enterprise.inject.spi.BeanManager;
import jakarta.inject.Inject;
import jakarta.inject.Named;
import jakarta.inject.Singleton;

import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;
import io.quarkus.test.QuarkusUnitTest;

public class NamedAiServicesAreResolvableByNameTest {

private static final String MY_NAMED_SERVICE_BEAN = "myNamedServiceBean";

@Inject
BeanManager beanManager;

@RegisterExtension
static final QuarkusUnitTest unitTest = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(MyNamedService.class));

@Named(MY_NAMED_SERVICE_BEAN)
@RegisterAiService(chatMemoryProviderSupplier = RegisterAiService.NoChatMemoryProviderSupplier.class)
interface MyNamedService {
@UserMessage("Dummy prompt for " + MY_NAMED_SERVICE_BEAN)
String chat();
}

@Singleton
public static class MyLanguageModel implements ChatLanguageModel {
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return null;
}
}

@Test
void namedAiServiceCouldBeResolvedByNameTest() {
assertEquals(1, beanManager.getBeans(MY_NAMED_SERVICE_BEAN).size());
}
}

0 comments on commit aad4aff

Please sign in to comment.