Commit
Merge pull request #1092 from mariofusco/no-bloking-jlama
Prevent Jlama inference from blocking the Vert.x event loop
geoand authored Nov 21, 2024
2 parents c63b96e + b1fe7df commit 867216e
Showing 3 changed files with 73 additions and 17 deletions.
io/quarkiverse/langchain4j/runtime/VertxUtil.java (new file)
@@ -0,0 +1,29 @@
+package io.quarkiverse.langchain4j.runtime;
+
+import java.util.concurrent.Callable;
+
+import io.smallrye.common.vertx.VertxContext;
+import io.smallrye.mutiny.infrastructure.Infrastructure;
+import io.vertx.core.Context;
+
+public class VertxUtil {
+
+    public static void runOutEventLoop(Runnable runnable) {
+        if (Context.isOnEventLoopThread()) {
+            Context executionContext = VertxContext.getOrCreateDuplicatedContext();
+            if (executionContext != null) {
+                executionContext.executeBlocking(new Callable<Object>() {
+                    @Override
+                    public Object call() {
+                        runnable.run();
+                        return null;
+                    }
+                });
+            } else {
+                Infrastructure.getDefaultWorkerPool().execute(runnable);
+            }
+        } else {
+            runnable.run();
+        }
+    }
+}
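
The helper above is the core of the fix: when the caller is on a Vert.x event-loop thread, the Runnable is executed via executeBlocking on a duplicated context (falling back to Mutiny's default worker pool when no context is available); otherwise it simply runs in place. A minimal usage sketch follows, assuming vertx-core and this runtime module are on the classpath; the VertxUtilDemo class and its log messages are illustrative only and not part of the commit.

import io.quarkiverse.langchain4j.runtime.VertxUtil;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;

public class VertxUtilDemo extends AbstractVerticle {

    @Override
    public void start() {
        // Verticles start on an event-loop thread, so this task is offloaded to a
        // duplicated context / worker pool instead of blocking the loop.
        VertxUtil.runOutEventLoop(() -> System.out.println(
                "offloaded task on " + Thread.currentThread().getName()));
    }

    public static void main(String[] args) {
        // Outside any Vert.x context the helper just runs the task inline.
        VertxUtil.runOutEventLoop(() -> System.out.println(
                "inline task on " + Thread.currentThread().getName()));

        Vertx.vertx().deployVerticle(new VertxUtilDemo());
    }
}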
io/quarkiverse/langchain4j/jlama/JlamaStreamingChatModel.java
@@ -1,6 +1,7 @@
 package io.quarkiverse.langchain4j.jlama;
 
 import static io.quarkiverse.langchain4j.jlama.JlamaModel.toFinishReason;
+import static io.quarkiverse.langchain4j.runtime.VertxUtil.runOutEventLoop;
 
 import java.nio.file.Path;
 import java.util.List;
@@ -10,6 +11,7 @@
 import com.github.tjake.jlama.model.AbstractModel;
 import com.github.tjake.jlama.model.functions.Generator;
 import com.github.tjake.jlama.safetensors.DType;
+import com.github.tjake.jlama.safetensors.prompt.PromptContext;
 import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
 
 import dev.langchain4j.data.message.AiMessage;
@@ -32,17 +34,21 @@ public JlamaStreamingChatModel(JlamaStreamingChatModelBuilder builder) {
                 .withRetry(() -> registry.downloadModel(builder.modelName, Optional.ofNullable(builder.authToken)), 3);
 
         JlamaModel.Loader loader = jlamaModel.loader();
-        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime)
+        if (builder.quantizeModelAtRuntime != null && builder.quantizeModelAtRuntime) {
             loader = loader.quantized();
+        }
 
-        if (builder.workingQuantizedType != null)
+        if (builder.workingQuantizedType != null) {
             loader = loader.workingQuantizationType(builder.workingQuantizedType);
+        }
 
-        if (builder.threadCount != null)
+        if (builder.threadCount != null) {
             loader = loader.threadCount(builder.threadCount);
+        }
 
-        if (builder.workingDirectory != null)
+        if (builder.workingDirectory != null) {
             loader = loader.workingDirectory(builder.workingDirectory);
+        }
 
         this.model = loader.load();
         this.temperature = builder.temperature == null ? 0.7f : builder.temperature;
@@ -55,21 +61,18 @@ public static JlamaStreamingChatModelBuilder builder() {

     @Override
     public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
-        if (model.promptSupport().isEmpty())
-            throw new UnsupportedOperationException("This model does not support chat generation");
-
-        PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();
-        for (ChatMessage message : messages) {
-            switch (message.type()) {
-                case SYSTEM -> promptBuilder.addSystemMessage(message.text());
-                case USER -> promptBuilder.addUserMessage(message.text());
-                case AI -> promptBuilder.addAssistantMessage(message.text());
-                default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
+        PromptContext promptContext = createPromptContext(messages);
+        runOutEventLoop(new Runnable() {
+            @Override
+            public void run() {
+                internalGenerate(handler, promptContext);
             }
-        }
+        });
+    }
 
+    private void internalGenerate(StreamingResponseHandler<AiMessage> handler, PromptContext promptContext) {
         try {
-            Generator.Response r = model.generate(id, promptBuilder.build(), temperature, maxTokens, (token, time) -> {
+            Generator.Response r = model.generate(id, promptContext, temperature, maxTokens, (token, time) -> {
                 handler.onNext(token);
             });
 
@@ -80,6 +83,23 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
         }
     }
 
+    private PromptContext createPromptContext(List<ChatMessage> messages) {
+        if (model.promptSupport().isEmpty()) {
+            throw new UnsupportedOperationException("This model does not support chat generation");
+        }
+
+        PromptSupport.Builder promptBuilder = model.promptSupport().get().builder();
+        for (ChatMessage message : messages) {
+            switch (message.type()) {
+                case SYSTEM -> promptBuilder.addSystemMessage(message.text());
+                case USER -> promptBuilder.addUserMessage(message.text());
+                case AI -> promptBuilder.addAssistantMessage(message.text());
+                default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
+            }
+        }
+        return promptBuilder.build();
+    }
+
     @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
     public static class JlamaStreamingChatModelBuilder {
 
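With this change, generate() builds the PromptContext up front and hands the actual token generation to runOutEventLoop via the new internalGenerate method, so the Jlama inference loop no longer runs on the calling event-loop thread. A hedged caller sketch follows; the JlamaStreamingDemo class, the prompt text, and the already-built JlamaStreamingChatModel instance are assumptions for illustration, while the StreamingResponseHandler callbacks (onNext, onComplete, onError) are the langchain4j API already used above.

import java.util.List;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;

import io.quarkiverse.langchain4j.jlama.JlamaStreamingChatModel;

public class JlamaStreamingDemo {

    // The model parameter is assumed to be an already-configured JlamaStreamingChatModel.
    static void streamAnswer(JlamaStreamingChatModel model) {
        List<ChatMessage> messages = List.of(
                SystemMessage.from("You are terse."),
                UserMessage.from("Say hi"));

        model.generate(messages, new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                // After this commit the token callbacks arrive on a worker thread,
                // not on the caller's Vert.x event loop.
                System.out.print(token);
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println();
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });
    }
}
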
Llama3 streaming chat model (package io.quarkiverse.langchain4j.llama3)

@@ -4,6 +4,7 @@
 import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
 import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE;
 import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;
+import static io.quarkiverse.langchain4j.runtime.VertxUtil.runOutEventLoop;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
@@ -79,7 +80,13 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
         );
         Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(),
                 options.seed());
-        runInference(model, sampler, options, llama3Messages, handler);
+
+        runOutEventLoop(new Runnable() {
+            @Override
+            public void run() {
+                runInference(model, sampler, options, llama3Messages, handler);
+            }
+        });
     }
 
     private void runInference(Llama model, Sampler sampler, Llama3.Options options,
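The Llama3 streaming model applies the same pattern: the sampler and options are prepared on the caller's thread and runInference is dispatched through runOutEventLoop. A quick way to confirm the new behavior from any streaming callback is to check the current thread; the small helper below is a hypothetical illustration and not part of the commit.

import io.vertx.core.Context;

public final class EventLoopCheck {

    private EventLoopCheck() {
    }

    // Call from inside a streaming callback (for example onNext) in a Vert.x application:
    // after this commit the callback should no longer run on an event-loop thread.
    public static void assertOffEventLoop() {
        if (Context.isOnEventLoopThread()) {
            throw new IllegalStateException("inference callback ran on the Vert.x event loop");
        }
    }
}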
