Commit a41e722: add lora

zkh2016 committed Sep 14, 2024
1 parent 4c05653
Showing 3 changed files with 36 additions and 8 deletions.
convert_hf_to_gguf.py (5 changes: 3 additions & 2 deletions)
@@ -462,7 +462,7 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
         vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
         assert max(tokenizer.vocab.values()) < vocab_size
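The added trust_remote_code=True matters for checkpoints whose tokenizers ship custom code on the Hugging Face Hub (the MiniCPM-V checkpoints this commit targets appear to be such a case); without the flag, transformers refuses to load repos that define custom tokenizer classes. A minimal sketch of the behavior, with an illustrative model id that is not part of this commit:

    from transformers import AutoTokenizer

    # Illustrative repo id; any model that defines a custom tokenizer class behaves the same.
    model_dir = "openbmb/MiniCPM-Llama3-V-2_5"

    # Without trust_remote_code=True, transformers rejects repos with custom classes.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    print(len(tokenizer.vocab))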

@@ -512,6 +512,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         # NOTE: if you get an error here, you need to update the convert_hf_to_gguf_update.py script
         # or pull the latest version of the model from Huggingface
         # don't edit the hashes manually!
+        res = "llama-bpe"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -596,7 +597,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
# ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
res = "gpt3-finnish"

print("=============== res = ", res)
if res is None:
logger.warning("\n")
logger.warning("**************************************************************************************")
convert_lora_to_gguf.py (18 changes: 12 additions & 6 deletions)
@@ -222,9 +222,9 @@ def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):


 def get_base_tensor_name(lora_tensor_name: str) -> str:
-    base_name = lora_tensor_name.replace("base_model.model.", "")
-    base_name = base_name.replace(".lora_A.weight", ".weight")
-    base_name = base_name.replace(".lora_B.weight", ".weight")
+    base_name = lora_tensor_name.replace("base_model.model.llm.", "")
+    base_name = base_name.replace(".lora_A.default.weight", ".weight")
+    base_name = base_name.replace(".lora_B.default.weight", ".weight")
     return base_name
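The rewritten mapping assumes a PEFT LoRA checkpoint that stores tensors under the default adapter name and nests the language model in an llm submodule, which is how the targeted MiniCPM-V adapter appears to be laid out. A quick sketch of the renaming, using an illustrative tensor name that is not taken from this commit:

    # Illustrative PEFT-style tensor name:
    name = "base_model.model.llm.model.layers.0.self_attn.q_proj.lora_A.default.weight"
    print(get_base_tensor_name(name))
    # -> model.layers.0.self_attn.q_proj.weight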


@@ -338,8 +338,10 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         if self.lazy:
             tensor = LazyTorchTensor.from_eager(tensor)
         base_name = get_base_tensor_name(name)
-        is_lora_a = ".lora_A.weight" in name
-        is_lora_b = ".lora_B.weight" in name
+        is_lora_a = ".lora_A.default.weight" in name
+        is_lora_b = ".lora_B.default.weight" in name
+        print(base_name, tensor, is_lora_a, is_lora_b)
+        assert tensor is not None
         if not is_lora_a and not is_lora_b:
             if ".base_layer.weight" in name:
                 continue
@@ -351,13 +353,17 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 tensor_map[base_name].A = tensor
             else:
                 tensor_map[base_name].B = tensor
+            assert tensor is not None
+
         else:
             if is_lora_a:
                 tensor_map[base_name] = PartialLoraTensor(A=tensor)
             else:
                 tensor_map[base_name] = PartialLoraTensor(B=tensor)
-
+    assert tensor is not None
+    print()
     for name, tensor in tensor_map.items():
+        print(name, tensor)
         assert tensor.A is not None
         assert tensor.B is not None
         yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
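For reference, the adapter would then be converted by pointing this script at the LoRA directory; a hedged example assuming the script's usual upstream CLI (--base for the base Hugging Face model directory, the adapter directory as the positional argument; all paths are placeholders):

    python convert_lora_to_gguf.py --base /path/to/MiniCPM-V-hf /path/to/lora-adapter --outfile lora.gguf --outtype f16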
examples/llava/minicpmv-cli.cpp (21 changes: 21 additions & 0 deletions)
@@ -211,6 +211,27 @@ static struct llava_context * llava_init_context(gpt_params * params) {
         return NULL;
     }
 
+    llama_init_result iparams;
+
+    // load and optionally apply lora adapters
+    for (auto & la : params->lora_adapters) {
+        llama_lora_adapter_container loaded_la;
+        loaded_la.path = la.path;
+        loaded_la.scale = la.scale;
+        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+        if (loaded_la.adapter == nullptr) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+            // llama_free(lctx);
+            // llama_free_model(model);
+            // return iparams;
+            return NULL;
+        }
+        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
+    }
+    if (!params->lora_init_without_apply) {
+        llama_lora_adapters_apply(ctx_llama, iparams.lora_adapters);
+    }
+
     auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
 
     ctx_llava->ctx_llama = ctx_llama;
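With adapter loading wired into llava_init_context, a converted LoRA can be supplied on the command line; a hedged usage sketch assuming the common --lora option and the llama-minicpmv-cli example binary (paths and prompt are placeholders):

    ./llama-minicpmv-cli -m model.gguf --mmproj mmproj-model-f16.gguf --lora lora.gguf --image input.jpg -p "Describe this image."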
