Llama develop (speedup 2.x) (#504)
* llama develop

* format

* llama rms layer norm

* finish llama kernel develop

* llama op develop

* llama develop

* format

* llama develop

* llama export

* llama export develop

* llama develop

* llama develop

* llama rotary fuse kernel

* llama fuse transpose and rotary position emb

* llama develop

* llama develop

* llama develop

* llama develop

* llama develop

* llama develop

* adapt export

* llama develop

* llama develop

* format
hexisyztem authored May 10, 2023
1 parent 7e5bed6 commit a7ab0da
Showing 48 changed files with 3,312 additions and 69 deletions.
2 changes: 1 addition & 1 deletion build.sh
@@ -2,6 +2,6 @@ if [ ! -d 'build' ]; then
mkdir build
fi
# DEVICE_ARCH could be cuda/x86/arm
cd build && cmake -DUSE_NEW_ARCH=OFF -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=OFF -DFP16_MODE=OFF -DMEM_DEBUG=OFF .. && make -j${nproc}
cd build && cmake -DUSE_NEW_ARCH=ON -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=ON -DFP16_MODE=ON -DMEM_DEBUG=OFF .. && make -j${nproc}
# you can use a command like below to compile lightseq with the pybind interface:
# sudo PATH=$PATH:/usr/local/hdf5 CUDACXX=/usr/local/cuda/bin/nvcc DEVICE_ARCH=cuda ENABLE_FP32=0 ENABLE_DEBUG=0 ENABLE_NEW_ARCH=1 python3 setup.py install
3 changes: 3 additions & 0 deletions lightseq/csrc/example/CMakeLists.txt
@@ -8,3 +8,6 @@ target_link_libraries(transformer_example PUBLIC liblightseq)

add_executable(gpt_example gpt_example.cc)
target_link_libraries(gpt_example PUBLIC liblightseq)

add_executable(llama_example llama_example.cc)
target_link_libraries(llama_example PUBLIC liblightseq)
94 changes: 94 additions & 0 deletions lightseq/csrc/example/llama_example.cc
@@ -0,0 +1,94 @@
#include "model_base.h"
#include "llama.h"

/**
@file
Example of how to run llama inference using our implementation.
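Usage: ./llama_example <model_weights_path> [batch_size] [batch_seq_len]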
*/

int main(int argc, char* argv[]) {
std::string model_weights_path = argv[1];
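// hard-coded LLaMA tokenizer ids used as the example prompt (1 is the LLaMA BOS id)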
std::vector<int> example_input = {1, 21784, 26539, 338,
263, 4933, 6509, 6890};
int eg_seq_len = example_input.size();

int batch_size = 1;
int batch_seq_len = eg_seq_len;

if (argc == 4) {
batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}

int max_batch_size = std::max(8, batch_size);

std::vector<int> host_input;
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < batch_seq_len; ++j) {
host_input.push_back(example_input[j % eg_seq_len]);
}
}

auto model = lightseq::cuda::LSModelFactory::GetInstance().CreateModel(
"Llama", model_weights_path, 1);

void* d_input;
CHECK_GPU_ERROR(
cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len));
CHECK_GPU_ERROR(cudaMemcpy(d_input, host_input.data(),
sizeof(int) * batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

printf("example step.1\n");

model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});

for (int i = 0; i < model->get_output_size(); i++) {
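// size this output's device buffer by the maximum shape reported by the model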
void* d_output;
std::vector<int> shape = model->get_output_max_shape(i);
int total_size = 1;
for (int j = 0; j < shape.size(); j++) {
total_size *= shape[j];
}
CHECK_GPU_ERROR(cudaMalloc(&d_output, total_size * sizeof(int)));
model->set_output_ptr(i, d_output);
}
printf("example step.2\n");
CHECK_GPU_ERROR(cudaStreamSynchronize(0));
std::cout << "infer preprocessing finished" << std::endl;
printf("example step.2-1\n");
std::cout << "infer preprocessing finished 2" << std::endl;

std::chrono::duration<double> elapsed;
int iter = 0;
/* ---step5. infer and log--- */
for (int i = 0; i < 5; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
auto finish = std::chrono::high_resolution_clock::now();
if (i) {
iter++;
elapsed += finish - start;
}
}

std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter
<< " ms" << std::endl;

for (int i = 0; i < model->get_output_size(); i++) {
const int* d_output;
d_output = static_cast<const int*>(model->get_output_ptr(i));
std::vector<int> shape = model->get_output_shape(i);
std::cout << "output shape: ";
for (int j = 0; j < shape.size(); j++) {
std::cout << shape[j] << " ";
}
std::cout << std::endl;
if (i == 0) {
lightseq::print_vec(d_output, "d_output", shape[2]);
}
}

return 0;
}
Empty file.
153 changes: 153 additions & 0 deletions lightseq/csrc/export/hf_llama_export.py
@@ -0,0 +1,153 @@
"""
Export Hugging Face LLaMA models to hdf5 format.
"""
import __init__
import os
import h5py
import numpy as np
from collections import OrderedDict
from util import parse_args, check_arguements, ModelArguements, fill_hdf5_layer
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


"""
For the mapping dictionary: key is the value of the proto parameter,
value is a powerful expression, each && split tensor name of the matching path or expression.
The sub-pattern of the path is separated by spaces, and the expression starts with a expression_.
You can operate separately on each tensor and support multiple expressions. Multiple matching paths
and the expression will finally be concatenated on axis = -1.
"""


"""
'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight'
"""

dec_layer_mapping_dict = OrderedDict(
{
"attention_norm_scale": "input_layernorm weight",
"attention_project_qkv": "self_attn q_proj weight&&self_attn k_proj weight&&self_attn v_proj weight&&expression_.transpose(0, 1)",
"attention_output": "self_attn o_proj weight&&expression_.transpose(0, 1)",
"ffn_norm_scale": "post_attention_layernorm weight",
"gate_up_project_weight": "mlp gate_proj weight&&mlp up_proj weight&&expression_.transpose(0, 1)",
"down_project_weight": "mlp down_proj weight&&expression_.transpose(0, 1)",
}
)

src_emb_mapping_dict = OrderedDict(
{
"post_norm_scale": "norm weight",
"token_embedding": "embed_tokens weight",
"logits_linear_weight": "lm_head weight&&expression_.transpose(0, 1)",
}
)
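
# As a rough, hedged illustration of how such a rule could be resolved against the
# checkpoint: the real logic lives in fill_hdf5_layer (imported from util above);
# resolve_rule below is only a hypothetical helper, and its sub-pattern matching and
# order of operations are approximations of that behavior. torch is imported above.
def resolve_rule(rule, named_tensors):
    parts = rule.split("&&")
    patterns = [p for p in parts if not p.startswith("expression_")]
    exprs = [p[len("expression_"):] for p in parts if p.startswith("expression_")]
    matched = []
    for pattern in patterns:
        sub_patterns = pattern.split(" ")
        for name, tensor in named_tensors.items():
            # approximate matching: every sub-pattern must appear in the dotted path
            if all(sub in name for sub in sub_patterns):
                for expr in exprs:
                    tensor = eval("tensor" + expr)  # e.g. ".transpose(0, 1)"
                matched.append(tensor)
    # all matched (and transformed) tensors are concatenated on the last axis
    return torch.cat(matched, dim=-1)

# e.g. qkv = resolve_rule(dec_layer_mapping_dict["attention_project_qkv"],
#                         {k: v for k, v in state_dict.items() if ".layers.0." in k})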


def extract_llama_weights(
output_file: str,
arguments: ModelArguements,
):
# load the checkpoint and collect variable names
state_dict = torch.load(arguments.model_file)

head_num = arguments.head_num
enc_var_name_list = list(state_dict.keys())

# initialize output file
output_file += ".hdf5"
print("Saving model to hdf5...")
print("Writing to {0}".format(output_file))

# exit(0)
hdf5_file = h5py.File(output_file, "w")

# fill each decoder layer's params
enc_tensor_names = {}
for name in enc_var_name_list:
name_split = name.split(".")
if len(name_split) <= 2 or not name_split[2].isdigit():
continue
layer_id = int(name_split[2])
enc_tensor_names.setdefault(layer_id, []).append(name)

# fill the decoder layer stack
for layer_id in sorted(enc_tensor_names.keys()):
fill_hdf5_layer(
enc_tensor_names[layer_id],
state_dict,
hdf5_file,
f"decoder_layers/{layer_id}/",
dec_layer_mapping_dict,
)

# fill src_embedding (LLaMA uses rotary position embeddings, so there is no separate
# position embedding table to export)
fill_hdf5_layer(
enc_var_name_list,
state_dict,
hdf5_file,
"src_embedding/",
src_emb_mapping_dict,
)

# save model configuration metadata
hdf5_file.create_dataset(
"model_conf/hidden_size", data=arguments.hidden_size, dtype="i4"
)
hdf5_file.create_dataset(
"model_conf/inner_size", data=arguments.inner_size, dtype="i4"
)
hdf5_file.create_dataset("model_conf/max_step", data=arguments.max_step, dtype="i4")
hdf5_file.create_dataset("model_conf/head_num", data=arguments.head_num, dtype="i4")
hdf5_file.create_dataset(
"model_conf/layer_num", data=arguments.layer_num, dtype="i4"
)
hdf5_file.create_dataset(
"model_conf/src_padding_id", data=arguments.padding_id, dtype="i4"
)
hdf5_file.create_dataset(
"model_conf/generate_method",
data=np.array([ord(c) for c in arguments.generation_method]).astype(np.int8),
dtype="i1",
)
hdf5_file.create_dataset("model_conf/topp", data=arguments.topp, dtype="f4")
hdf5_file.create_dataset("model_conf/topk", data=arguments.topk, dtype="i4")
hdf5_file.create_dataset("model_conf/eos_id", data=arguments.eos_id, dtype="i4")
hdf5_file.create_dataset(
"model_conf/extra_decode_length", data=arguments.extra_decode_length, dtype="i4"
)
hdf5_file.create_dataset(
"model_conf/src_vocab_size", data=arguments.vocab_size, dtype="i4"
)

hdf5_file.close()
# read-in again to double check
hdf5_file = h5py.File(output_file, "r")

def _print_pair(key, value):
if key == "generate_method":
value = "".join(map(chr, value[()]))
else:
value = value[()]
print(f"{key}: {value}")

list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items()))


if __name__ == "__main__":
args = parse_args()

arguments = ModelArguements(args)
basename = os.path.basename(arguments.model_repo)
output_lightseq_model_name = "_".join(["lightseq_llama", basename, "7b"])
# eos_id and padding_id must be set explicitly to match the LLaMA tokenizer

arguments.eos_id = 2 # need to set
arguments.padding_id = 0 # need to set

if not check_arguements(arguments):
exit(0)

extract_llama_weights(output_lightseq_model_name, arguments)
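
# As a further sanity check after export, one might also list the per-layer datasets that
# fill_hdf5_layer wrote under decoder_layers/. A minimal standalone sketch; the file name
# below is hypothetical, and the dataset keys follow the mapping dictionaries above.
import h5py

with h5py.File("lightseq_llama_model_7b.hdf5", "r") as f:  # hypothetical file name
    for key, dataset in f["decoder_layers/0"].items():
        print(key, dataset.shape, dataset.dtype)
    print("src_embedding keys:", list(f["src_embedding"].keys()))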