Merge pull request #1076 from quarkiverse/llama3-updated
Use latest version of the Llama3.java code
geoand authored Nov 14, 2024
2 parents c10b478 + 1c1a2d1 commit f4b60ad
Showing 7 changed files with 280 additions and 155 deletions.
11 changes: 11 additions & 0 deletions docs/modules/ROOT/pages/llama3.adoc
@@ -39,6 +39,17 @@ WARNING: Models are huge, so make sure you have enough disk space.
NOTE: Due to the models' large size, pulling them can take time.
=== Native mode
Currently, Llama3.java only works in native mode with Early Access versions of Oracle GraalVM 24 (which can easily be downloaded with https://sdkman.io[SDKMan]).
To achieve the best performance in native mode, it is recommended to configure the application with the following:
[source,properties,subs=attributes+]
----
quarkus.native.additional-build-args=-O3,-march=native
----
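For reference, an Early Access GraalVM build can be installed with SDKMan roughly as follows; the version identifier below is illustrative only (run `sdk list java` to see the identifiers actually available):

[source,shell]
----
# The identifier is an example, not a pinned recommendation; pick an
# Early Access Oracle GraalVM 24 entry from `sdk list java`.
sdk install java 24.ea.1-graal
sdk use java 24.ea.1-graal
----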
== Using Llama3.java
To let Llama3.java run inference on your models, add the following dependency to your project:
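The diff is truncated at this point; the dependency the page goes on to show is presumably the extension's Maven artifact, something like the following (the coordinates and version attribute are assumptions inferred from the repository name, not taken from this diff — check the rendered docs for the authoritative snippet):

[source,xml,subs=attributes+]
----
<dependency>
    <groupId>io.quarkiverse.langchain4j</groupId>
    <artifactId>quarkus-langchain4j-llama3-java</artifactId>
    <!-- assumed version attribute; use the extension's current release -->
    <version>{quarkus-langchain4j-version}</version>
</dependency>
----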
@@ -96,7 +96,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {

private InferenceResponse runInference(Llama model, Sampler sampler, Llama3.Options options,
List<ChatFormat.Message> messages) {
Llama.State state = model.createNewState();
Llama.State state = model.createNewState(Llama3.BATCH_SIZE);
ChatFormat chatFormat = new ChatFormat(model.tokenizer());

List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));
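The substantive API change in this hunk is that creating a `Llama.State` now requires a batch size, presumably so the prompt tokens can be processed in batches during prefill rather than strictly one forward pass per token. A minimal sketch of the updated pattern, which applies to both the blocking model here and the streaming model below (the `BATCH_SIZE` constant comes from the vendored `Llama3` class, as the import added in the next hunk shows):

[source,java]
----
// Before this change the state was created with no arguments:
//   Llama.State state = model.createNewState();
// Now the state is sized for batched prompt processing:
Llama.State state = model.createNewState(Llama3.BATCH_SIZE);
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));
----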
@@ -2,6 +2,7 @@

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static io.quarkiverse.langchain4j.llama3.MessageMapper.toLlama3Message;
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.BATCH_SIZE;
import static io.quarkiverse.langchain4j.llama3.copy.Llama3.selectSampler;

import java.io.IOException;
@@ -84,7 +85,7 @@ public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
private void runInference(Llama model, Sampler sampler, Llama3.Options options,
List<ChatFormat.Message> messages,
StreamingResponseHandler<AiMessage> handler) {
Llama.State state = model.createNewState();
Llama.State state = model.createNewState(BATCH_SIZE);
ChatFormat chatFormat = new ChatFormat(model.tokenizer());

List<Integer> promptTokens = new ArrayList<>(chatFormat.encodeDialogPrompt(true, messages));
@@ -51,7 +51,7 @@ public static PartialModel preLoadGGUF(String modelPath) {
* No checksum/hash is checked for performance reasons.
*/
public static Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
PartialModel preLoaded = AOT.PRELOADED_GGUF;
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
if (preLoaded == null) {
return null; // no pre-loaded model stored
}
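Taken together with `preLoadGGUF` above, the intended flow is: GGUF metadata is pre-loaded at native-image build time and stored in `AOT.PRELOADED_GGUF`, and `tryUsePreLoaded` reuses it at runtime when possible. A hedged sketch of a caller (the fallback loader shown here is an assumption; this diff does not show it):

[source,java]
----
// Reuse the model pre-loaded into the native image, if any.
Llama model = AOT.tryUsePreLoaded(modelPath, contextLength);
if (model == null) {
    // No matching pre-loaded model: fall back to a regular load.
    // ModelLoader.loadModel is an assumed API, not part of this diff.
    model = ModelLoader.loadModel(modelPath, contextLength);
}
----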
@@ -108,7 +108,7 @@ private void loadModelImpl(FileChannel fileChannel) throws IOException {
// gguf_tensor_info_t tensor_infos[header.tensor_count];
this.tensorInfos = HashMap.newHashMap(tensorCount);
for (int i = 0; i < tensorCount; ++i) {
GGUFTensorInfo ti = readTensorInfo(fileChannel);
GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel);
assert !tensorInfos.containsKey(ti.name);
tensorInfos.put(ti.name, ti);
}
@@ -156,7 +156,7 @@ private GGMLType readGGMLType(FileChannel fileChannel) throws IOException {
return GGMLType.fromId(ggmlTypeId);
}

private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
// The name of the tensor. It is a standard GGUF string, with the caveat that
// it must be at most 64 bytes long.
String name = readString(fileChannel); // gguf_string_t name;
@@ -180,7 +180,7 @@ private GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException {
// Must be a multiple of `ALIGNMENT`.
long offset = readLong(fileChannel); // uint64_t offset;
assert offset % getAlignment() == 0;
return new GGUFTensorInfo(name, dimensions, ggmlType, offset);
return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset);
}
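For orientation, the constructor call above implies the now-explicitly-qualified nested type has roughly this shape; the field types are inferred from this hunk and are not confirmed by the diff (GGUF stores dimensions as unsigned 64-bit integers, so `int[]` is an assumption):

[source,java]
----
// Inferred sketch, consistent with the field access `ti.name` earlier in
// this file; the concrete declaration is not visible in this diff.
public static final class GGUFTensorInfo {
    final String name;
    final int[] dimensions; // GGUF dims are uint64 on disk; int[] is an assumption
    final GGMLType ggmlType;
    final long offset;

    GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) {
        this.name = name;
        this.dimensions = dimensions;
        this.ggmlType = ggmlType;
        this.offset = offset;
    }
}
----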

private String readString(FileChannel fileChannel) throws IOException {
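The body of `readString` is collapsed in this view. Per the GGUF format, a `gguf_string_t` is a `uint64` length followed by that many UTF-8 bytes, so the method presumably does something along these lines (a spec-based sketch reusing the class's `readLong` helper, not the collapsed implementation):

[source,java]
----
// Spec-based sketch: gguf_string_t = uint64_t len, then `len` UTF-8 bytes.
private String readString(FileChannel fileChannel) throws IOException {
    long len = readLong(fileChannel);                          // uint64_t len;
    ByteBuffer buffer = ByteBuffer.allocate(Math.toIntExact(len));
    fileChannel.read(buffer);                                  // char string[len]; (single read; a robust version would loop)
    return new String(buffer.array(), StandardCharsets.UTF_8);
}
----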
