From a7ab0dabe63ede35016ab151c5adfdf1b86f2eb4 Mon Sep 17 00:00:00 2001 From: zhoubofan Date: Wed, 10 May 2023 12:35:39 +0800 Subject: [PATCH] Llama develop (speedup 2.x) (#504) * llama develop * format * llama rms layer norm * finish llama kernel develop * llama op develop * llama develop * format * llama develop * llama export * llama export develop * llama develop * llama develop * llama rotary fuse kernel * llama fuse transpose and rotary position emb * llama develop * llama develop * llama develop * llama develop * llama develop * llama develop * adapt export * llama develop * llama develop * format --- build.sh | 2 +- lightseq/csrc/example/CMakeLists.txt | 3 + lightseq/csrc/example/llama_example.cc | 94 ++++ lightseq/csrc/export/__init__.py | 0 lightseq/csrc/export/hf_llama_export.py | 153 ++++++ lightseq/csrc/export/util.py | 189 +++++++ lightseq/csrc/kernels/cuda/CMakeLists.txt | 1 + .../kernels/cuda/includes/kernel_headers.h | 1 + lightseq/csrc/kernels/cuda/includes/kernels.h | 1 + .../kernels/cuda/includes/llama_kernels.h | 39 ++ lightseq/csrc/kernels/cuda/includes/util.h | 2 + lightseq/csrc/kernels/cuda/llama_kernels.cu | 479 ++++++++++++++++++ lightseq/csrc/kernels/cuda/util.cc.cu | 17 + lightseq/csrc/layers_new/CMakeLists.txt | 3 + .../includes/launch_llama_emb_layer.h | 60 +++ .../includes/llama_attention_layer.h | 65 +++ .../csrc/layers_new/includes/llama_layer.h | 42 ++ .../layers_new/includes/llama_mlp_layer.h | 51 ++ .../csrc/layers_new/includes/rms_norm_layer.h | 64 +++ .../csrc/layers_new/llama_attention_layer.cpp | 109 ++++ lightseq/csrc/layers_new/llama_layer.cpp | 44 ++ lightseq/csrc/layers_new/llama_mlp_layer.cpp | 69 +++ lightseq/csrc/models/CMakeLists.txt | 2 +- lightseq/csrc/models/includes/llama.h | 61 +++ lightseq/csrc/models/llama.cc | 285 +++++++++++ lightseq/csrc/ops_new/CMakeLists.txt | 7 +- lightseq/csrc/ops_new/act_elewise_product.cpp | 35 ++ lightseq/csrc/ops_new/fuse_add2_op.cpp | 35 ++ .../csrc/ops_new/fuse_rotary_position_qkv.cpp | 41 ++ .../ops_new/includes/act_elewise_product.h | 42 ++ lightseq/csrc/ops_new/includes/fuse_add2_op.h | 39 ++ .../includes/fuse_rotary_position_qkv.h | 97 ++++ .../csrc/ops_new/includes/launch_llama_emb.h | 57 +++ lightseq/csrc/ops_new/includes/linear.h | 14 +- .../csrc/ops_new/includes/rms_layer_norm.h | 44 ++ lightseq/csrc/ops_new/includes/sampling.h | 2 + lightseq/csrc/ops_new/launch_llama_emb.cpp | 56 ++ lightseq/csrc/ops_new/linear.cpp | 21 +- lightseq/csrc/ops_new/rms_layer_norm.cpp | 59 +++ lightseq/csrc/proto/CMakeLists.txt | 2 +- lightseq/csrc/proto/includes/hdf5_util.h | 26 + lightseq/csrc/proto/includes/llama_weight.h | 93 ++++ lightseq/csrc/proto/llama.proto | 60 +++ lightseq/csrc/proto/llama_weight.cc | 329 ++++++++++++ lightseq/csrc/pybind/pybind_kernel_cuda.cpp | 199 +++++++- lightseq/csrc/pybind/pybind_model.cpp | 75 +++ .../pytorch/builder/cuda_kernel_builder.py | 1 + lightseq/csrc/tests/cuda/test_kernel.py | 211 ++++++-- 48 files changed, 3312 insertions(+), 69 deletions(-) create mode 100644 lightseq/csrc/example/llama_example.cc create mode 100644 lightseq/csrc/export/__init__.py create mode 100644 lightseq/csrc/export/hf_llama_export.py create mode 100644 lightseq/csrc/export/util.py create mode 100644 lightseq/csrc/kernels/cuda/includes/llama_kernels.h create mode 100644 lightseq/csrc/kernels/cuda/llama_kernels.cu create mode 100644 lightseq/csrc/layers_new/includes/launch_llama_emb_layer.h create mode 100644 lightseq/csrc/layers_new/includes/llama_attention_layer.h create mode 100644 
lightseq/csrc/layers_new/includes/llama_layer.h create mode 100644 lightseq/csrc/layers_new/includes/llama_mlp_layer.h create mode 100644 lightseq/csrc/layers_new/includes/rms_norm_layer.h create mode 100644 lightseq/csrc/layers_new/llama_attention_layer.cpp create mode 100644 lightseq/csrc/layers_new/llama_layer.cpp create mode 100644 lightseq/csrc/layers_new/llama_mlp_layer.cpp create mode 100644 lightseq/csrc/models/includes/llama.h create mode 100644 lightseq/csrc/models/llama.cc create mode 100644 lightseq/csrc/ops_new/act_elewise_product.cpp create mode 100644 lightseq/csrc/ops_new/fuse_add2_op.cpp create mode 100644 lightseq/csrc/ops_new/fuse_rotary_position_qkv.cpp create mode 100644 lightseq/csrc/ops_new/includes/act_elewise_product.h create mode 100644 lightseq/csrc/ops_new/includes/fuse_add2_op.h create mode 100644 lightseq/csrc/ops_new/includes/fuse_rotary_position_qkv.h create mode 100644 lightseq/csrc/ops_new/includes/launch_llama_emb.h create mode 100644 lightseq/csrc/ops_new/includes/rms_layer_norm.h create mode 100644 lightseq/csrc/ops_new/launch_llama_emb.cpp create mode 100644 lightseq/csrc/ops_new/rms_layer_norm.cpp create mode 100644 lightseq/csrc/proto/includes/hdf5_util.h create mode 100644 lightseq/csrc/proto/includes/llama_weight.h create mode 100644 lightseq/csrc/proto/llama.proto create mode 100644 lightseq/csrc/proto/llama_weight.cc diff --git a/build.sh b/build.sh index 188c4edc..532ca38c 100755 --- a/build.sh +++ b/build.sh @@ -2,6 +2,6 @@ if [ ! -d 'build' ]; then mkdir build fi # DEVICE_ARCH could be cuda/x86/arm -cd build && cmake -DUSE_NEW_ARCH=OFF -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=OFF -DFP16_MODE=OFF -DMEM_DEBUG=OFF .. && make -j${nproc} +cd build && cmake -DUSE_NEW_ARCH=ON -DDEVICE_ARCH=cuda -DUSE_TRITONBACKEND=OFF -DDEBUG_MODE=ON -DFP16_MODE=ON -DMEM_DEBUG=OFF .. && make -j${nproc} # you can use command like below to compile lightseq with pybind interface: # sudo PATH=$PATH:/usr/local/hdf5 CUDACXX=/usr/local/cuda/bin/nvcc DEVICE_ARCH=cuda ENABLE_FP32=0 ENABLE_DEBUG=0 ENABLE_NEW_ARCH=1 python3 setup.py install diff --git a/lightseq/csrc/example/CMakeLists.txt b/lightseq/csrc/example/CMakeLists.txt index 84940a43..832ab114 100644 --- a/lightseq/csrc/example/CMakeLists.txt +++ b/lightseq/csrc/example/CMakeLists.txt @@ -8,3 +8,6 @@ target_link_libraries(transformer_example PUBLIC liblightseq) add_executable(gpt_example gpt_example.cc) target_link_libraries(gpt_example PUBLIC liblightseq) + +add_executable(llama_example llama_example.cc) +target_link_libraries(llama_example PUBLIC liblightseq) diff --git a/lightseq/csrc/example/llama_example.cc b/lightseq/csrc/example/llama_example.cc new file mode 100644 index 00000000..d1e29815 --- /dev/null +++ b/lightseq/csrc/example/llama_example.cc @@ -0,0 +1,94 @@ +#include "model_base.h" +#include "llama.h" + +/** +@file +Example of how to run llama inference using our implementation.
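+ +A typical invocation (the hdf5 path below is only illustrative; batch_size and batch_seq_len are optional and default to 1 and the length of the built-in example input): + ./llama_example /path/to/lightseq_llama_7b.hdf5 [batch_size] [batch_seq_len]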
+*/ + +int main(int argc, char* argv[]) { + std::string model_weights_path = argv[1]; + std::vector<int> example_input = {1, 21784, 26539, 338, + 263, 4933, 6509, 6890}; + int eg_seq_len = example_input.size(); + + int batch_size = 1; + int batch_seq_len = eg_seq_len; + + if (argc == 4) { + batch_size = atoi(argv[2]); + batch_seq_len = atoi(argv[3]); + } + + int max_batch_size = std::max(8, batch_size); + + std::vector<int> host_input; + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < batch_seq_len; ++j) { + host_input.push_back(example_input[j % eg_seq_len]); + } + } + + auto model = lightseq::cuda::LSModelFactory::GetInstance().CreateModel( + "Llama", model_weights_path, 1); + + void* d_input; + CHECK_GPU_ERROR( + cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len)); + CHECK_GPU_ERROR(cudaMemcpy(d_input, host_input.data(), + sizeof(int) * batch_size * batch_seq_len, + cudaMemcpyHostToDevice)); + + printf("example step.1\n"); + + model->set_input_ptr(0, d_input); + model->set_input_shape(0, {batch_size, batch_seq_len}); + + for (int i = 0; i < model->get_output_size(); i++) { + void* d_output; + std::vector<int> shape = model->get_output_max_shape(i); + int total_size = 1; + for (int j = 0; j < shape.size(); j++) { + total_size *= shape[j]; + } + CHECK_GPU_ERROR(cudaMalloc(&d_output, total_size * sizeof(int))); + model->set_output_ptr(i, d_output); + } + printf("example step.2\n"); + CHECK_GPU_ERROR(cudaStreamSynchronize(0)); + std::cout << "infer preprocessing finished" << std::endl; + printf("example step.2-1\n"); + std::cout << "infer preprocessing finished 2" << std::endl; + + std::chrono::duration<double> elapsed; + int iter = 0; + /* ---step5. infer and log--- */ + for (int i = 0; i < 5; i++) { + auto start = std::chrono::high_resolution_clock::now(); + model->Infer(); + auto finish = std::chrono::high_resolution_clock::now(); + if (i) { + iter++; + elapsed += finish - start; + } + } + + std::cout << "lightseq inference latency: " << elapsed.count() * 1000 / iter + << " ms" << std::endl; + + for (int i = 0; i < model->get_output_size(); i++) { + const int* d_output; + d_output = static_cast<const int*>(model->get_output_ptr(i)); + std::vector<int> shape = model->get_output_shape(i); + std::cout << "output shape: "; + for (int j = 0; j < shape.size(); j++) { + std::cout << shape[j] << " "; + } + std::cout << std::endl; + if (i == 0) { + lightseq::print_vec(d_output, "d_output", shape[2]); + } + } + + return 0; +} diff --git a/lightseq/csrc/export/__init__.py b/lightseq/csrc/export/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightseq/csrc/export/hf_llama_export.py b/lightseq/csrc/export/hf_llama_export.py new file mode 100644 index 00000000..5beccec8 --- /dev/null +++ b/lightseq/csrc/export/hf_llama_export.py @@ -0,0 +1,153 @@ +""" +Export Hugging Face LLaMA models to hdf5 format. +""" +import __init__ +import os +import h5py +import numpy as np +from collections import OrderedDict +from util import parse_args, check_arguements, ModelArguements, fill_hdf5_layer +import torch + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + +""" +For the mapping dictionary: the key is the name of the proto parameter, and the +value is a rule whose &&-separated parts are either tensor-name matching paths or expressions. + +The sub-patterns of a matching path are separated by spaces, and an expression starts with expression_. +Each matched tensor can be operated on separately, and multiple expressions are supported. The tensors +produced by the matching paths and expressions are finally concatenated on axis = -1.
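+ +For example (an illustrative reading of one rule, using the LLaMA checkpoint tensor names listed below): the rule "self_attn q_proj weight&&self_attn k_proj weight&&self_attn v_proj weight&&expression_.transpose(0, 1)" matches model.layers.<i>.self_attn.{q,k,v}_proj.weight of the current layer, applies .transpose(0, 1) to each matched tensor, and concatenates the three results along the last axis before they are written to the attention_project_qkv dataset.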
+""" + + +""" +'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight' +""" + +dec_layer_mapping_dict = OrderedDict( + { + "attention_norm_scale": "input_layernorm weight", + "attention_project_qkv": "self_attn q_proj weight&&self_attn k_proj weight&&self_attn v_proj weight&&expression_.transpose(0, 1)", + "attention_output": "self_attn o_proj weight&&expression_.transpose(0, 1)", + "ffn_norm_scale": "post_attention_layernorm weight", + "gate_up_project_weight": "mlp gate_proj weight&&mlp up_proj weight&&expression_.transpose(0, 1)", + "down_project_weight": "mlp down_proj weight&&expression_.transpose(0, 1)", + } +) + +src_emb_mapping_dict = OrderedDict( + { + "post_norm_scale": "norm weight", + "token_embedding": "embed_tokens weight", + "logits_linear_weight": "lm_head weight&&expression_.transpose(0, 1)", + } +) + + +def extract_llama_weights( + output_file: str, + arguments: ModelArguements, +): + # load var names + state_dict = torch.load(arguments.model_file) + + head_num = arguments.head_num + enc_var_name_list = list(state_dict.keys()) + + # initialize output file + output_file += ".hdf5" + print("Saving model to hdf5...") + print("Writing to {0}".format(output_file)) + + # exit(0) + hdf5_file = h5py.File(output_file, "w") + + # fill each encoder layer's params + enc_tensor_names = {} + for name in enc_var_name_list: + name_split = name.split(".") + if len(name_split) <= 2 or not name_split[2].isdigit(): + continue + layer_id = int(name_split[2]) + enc_tensor_names.setdefault(layer_id, []).append(name) + + # fill encoder_stack + for layer_id in sorted(enc_tensor_names.keys()): + fill_hdf5_layer( + enc_tensor_names[layer_id], + state_dict, + hdf5_file, + f"decoder_layers/{layer_id}/", + dec_layer_mapping_dict, + ) + + # fill src_embedding - except for position embedding + fill_hdf5_layer( + enc_var_name_list, + state_dict, + hdf5_file, + "src_embedding/", + src_emb_mapping_dict, + ) + + # save number of layers metadata + hdf5_file.create_dataset( + "model_conf/hidden_size", data=arguments.hidden_size, dtype="i4" + ) + hdf5_file.create_dataset( + "model_conf/inner_size", data=arguments.inner_size, dtype="i4" + ) + hdf5_file.create_dataset("model_conf/max_step", data=arguments.max_step, dtype="i4") + hdf5_file.create_dataset("model_conf/head_num", data=arguments.head_num, dtype="i4") + hdf5_file.create_dataset( + "model_conf/layer_num", data=arguments.layer_num, dtype="i4" + ) + hdf5_file.create_dataset( + "model_conf/src_padding_id", data=arguments.padding_id, dtype="i4" + ) + hdf5_file.create_dataset( + "model_conf/generate_method", + data=np.array([ord(c) for c in arguments.generation_method]).astype(np.int8), + dtype="i1", + ) + hdf5_file.create_dataset("model_conf/topp", data=arguments.topp, dtype="f4") + hdf5_file.create_dataset("model_conf/topk", data=arguments.topk, dtype="i4") + hdf5_file.create_dataset("model_conf/eos_id", data=arguments.eos_id, dtype="i4") + hdf5_file.create_dataset( + "model_conf/extra_decode_length", data=arguments.extra_decode_length, dtype="i4" + ) + hdf5_file.create_dataset( + "model_conf/src_vocab_size", data=arguments.vocab_size, dtype="i4" + ) + + hdf5_file.close() + # read-in again to 
double check + hdf5_file = h5py.File(output_file, "r") + + def _print_pair(key, value): + if key == "generate_method": + value = "".join(map(chr, value[()])) + else: + value = value[()] + print(f"{key}: {value}") + + list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items())) + + +if __name__ == "__main__": + args = parse_args() + + arguments = ModelArguements(args) + basename = os.path.basename(arguments.model_repo) + output_lightseq_model_name = "_".join(["lightseq_llama", basename, "7b"]) + # default eos_id from https://huggingface.co/transformers/model_doc/gpt2.html#gpt2lmheadmodel + + arguments.eos_id = 2 # need to set + arguments.padding_id = 0 # need to set + + if not check_arguements(arguments): + exit(0) + + extract_llama_weights(output_lightseq_model_name, arguments) diff --git a/lightseq/csrc/export/util.py b/lightseq/csrc/export/util.py new file mode 100644 index 00000000..c57a65d1 --- /dev/null +++ b/lightseq/csrc/export/util.py @@ -0,0 +1,189 @@ +import argparse +from re import L +import os +import h5py +import json +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser(description="export fairseq checkpoint", usage="") + parser.add_argument( + "--model_file", + "-m", + type=str, + required=True, + help="path of pytorch model repo", + ) + parser.add_argument( + "--generation_method", + "-g", + type=str, + required=True, + choices=["beam_search", "topk_greedy", "topk", "topp", "ppl"], + help="generation method", + ) + parser.add_argument( + "--beam_size", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--topk", + type=int, + required=False, + default=None, + ) + parser.add_argument( + "--topp", + type=float, + required=False, + default=None, + ) + parser.add_argument( + "--extra_decode_length", + type=int, + required=False, + default=None, + ) + args = parser.parse_args() + return args + + +def check_arguements(args): + if args.generation_method == "beam_search": + if args.beam_size == None: + raise Exception("set 'beam_size' value while using beam search method.") + elif args.generation_method == "topk": + if args.topk == None: + raise Exception("set 'topk' value while using topk sample method.") + ... 
+ elif args.generation_method == "topp": + if args.topp == None: + raise Exception("set 'topp' value while using topp sample method.") + + if args.eos_id == None: + raise Exception("eos id should not be set as None") + + if args.padding_id == None: + raise Exception("padding id shoud not be set as None") + + if args.beam_size == None: + args.beam_size = 1 + + if args.topp == None: + args.topp = 1.0 + + if args.topk == None: + args.topk = 1 + + return True + + +class ModelArguements(object): + def __init__(self, args): + self.model_file = os.path.abspath(args.model_file) + if not os.path.isfile(self.model_file): + raise Exception(f"there is no such model file {self.model_file}") + + self.model_repo = os.path.dirname(self.model_file) + self.generation_method = args.generation_method + self.beam_size = args.beam_size + self.topk = args.topk + self.topp = args.topp + self.eos_id = None + self.bos_id = None + self.config_path = os.path.join(self.model_repo, "config.json") + + if not os.path.isfile(self.config_path): + raise Exception(f"there is no such config file {self.config_path}") + + config_file = open(self.config_path) + config = json.load(config_file) + config_file.close() + + self.padding_id = config.get("pad_token_id") + self.max_step = config.get("max_sequence_length") + self.hidden_size = config.get("hidden_size") + self.inner_size = config.get("intermediate_size") + self.head_num = config.get("num_attention_heads") + self.vocab_size = config.get("vocab_size") + self.layer_num = config.get("num_hidden_layers") + self.extra_decode_length = ( + self.max_step + if args.extra_decode_length is None + else args.extra_decode_length + ) + + +def apply_rule(proto_name, ckpt_rule, tensor_names, state_dict): + def check_rule(tensor_name, rule): + if "Adam" in tensor_name or "adam" in tensor_name: + return False + assert isinstance(rule, str) and rule + rule = rule.split("-") + assert len(rule) < 3 + if len(rule) == 2: + white, black = rule[0].split(" "), rule[1].split(" ") + else: + white, black = rule[0].split(" "), [] + for b in black: + if b in tensor_name.split("."): + return False + for w in white: + if w not in tensor_name.split("."): + return False + return True + + expression = [ele for ele in ckpt_rule.split("&&") if ele.startswith("expression_")] + + ckpt_rule = [ + ele for ele in ckpt_rule.split("&&") if not ele.startswith("expression_") + ] + + assert (len(ckpt_rule) > 0 and len(expression) < 2) or ( + len(ckpt_rule) == 0 and len(expression) > 0 + ) + + if len(expression) < 2: + expression = "" if not expression else expression[0].split("_")[1] + else: + expression = [exp.split("_")[1] for exp in expression] + + target_tn = [] + for cr in ckpt_rule: + tmp = [] + for tn in tensor_names: + if check_rule(tn, cr): + tmp.append(tn) + assert len(tmp) == 1 + target_tn.extend(tmp) + target_tensor = [state_dict[name].float() for name in target_tn] + tt = {} + if target_tensor: + exec("tt['save'] = [ele%s for ele in target_tensor]" % expression) + else: + if not isinstance(expression, list): + expression = [expression] + exec("tt['save'] = [%s]" % ",".join(expression)) + + try: + target_tensor = np.concatenate(tt["save"], axis=-1) + except: + target_tensor = tt["save"] + print( + "%s -> %s, convert finished!" 
+ % (target_tn if target_tn else "created", proto_name) + ) + return target_tensor[0] if type(target_tensor) is list else target_tensor + + +def fill_hdf5_layer( + tensor_names, state_dict, hdf5_file, hdf5_dataset_prefix, mapping_dict +): + for proto_name, ckpt_rule in mapping_dict.items(): + target_tensor = apply_rule(proto_name, ckpt_rule, tensor_names, state_dict) + hdf5_file.create_dataset( + hdf5_dataset_prefix + proto_name, data=target_tensor.flatten().tolist() + ) diff --git a/lightseq/csrc/kernels/cuda/CMakeLists.txt b/lightseq/csrc/kernels/cuda/CMakeLists.txt index 5da1ab8e..c7374c3f 100644 --- a/lightseq/csrc/kernels/cuda/CMakeLists.txt +++ b/lightseq/csrc/kernels/cuda/CMakeLists.txt @@ -11,6 +11,7 @@ set(cuda_kernel_files # fused_adam_kernel.cu general_kernels.cu gptKernels.cc.cu + llama_kernels.cu normalize_kernels.cu softmax_kernels.cu softmax_kernels_new.cu diff --git a/lightseq/csrc/kernels/cuda/includes/kernel_headers.h b/lightseq/csrc/kernels/cuda/includes/kernel_headers.h index 8660d00a..17eea766 100644 --- a/lightseq/csrc/kernels/cuda/includes/kernel_headers.h +++ b/lightseq/csrc/kernels/cuda/includes/kernel_headers.h @@ -20,3 +20,4 @@ #include "transformerKernels.h" #include "cuda_util.h" #include "cublas_wrappers.h" +#include "llama_kernels.h" diff --git a/lightseq/csrc/kernels/cuda/includes/kernels.h b/lightseq/csrc/kernels/cuda/includes/kernels.h index 2bb440fe..f224fa3c 100644 --- a/lightseq/csrc/kernels/cuda/includes/kernels.h +++ b/lightseq/csrc/kernels/cuda/includes/kernels.h @@ -7,6 +7,7 @@ #include #include #include +#include "cmath" #define MAX_THREADS 1024 #define WARP_SIZE 32 diff --git a/lightseq/csrc/kernels/cuda/includes/llama_kernels.h b/lightseq/csrc/kernels/cuda/includes/llama_kernels.h new file mode 100644 index 00000000..1b0a8cc1 --- /dev/null +++ b/lightseq/csrc/kernels/cuda/includes/llama_kernels.h @@ -0,0 +1,39 @@ +#pragma once +#include +#include +#include +#include "kernels.h" +#include + +namespace lightseq { +namespace cuda { + +template +void launch_llama_embedding(const T *token_emb, const int *tokens, T *output, + T *pad_mask_ptr, int *left_pad_len_ptr, + int batch_size, int beam_size, int hidden_dim, + int step_offset, int seq_len, int max_step, + int padding_id, cudaStream_t stream); + +template +void launch_split_rotary_position_qkv(const T *input_ptr, const T *sin_ptr, + const T *cos_ptr, T *q_out, + T *cache_k_out, T *cache_v_out, + size_t max_step, size_t batch_size, + size_t nhead, size_t offset_seq_len, + size_t query_len, size_t head_dim, + cudaStream_t stream); + +template +void launch_silu_elewise_product(const T *inp_ptr, T *out_ptr, + size_t batch_size, size_t seq_len, + size_t inner_size, cudaStream_t stream); + +template +void launch_rms_layer_norm(const T *inp_ptr, const T *scale_ptr, T *out_ptr, + T *res_ptr, T *rms_ptr, size_t batch_tokens, + size_t hidden_dim, cudaStream_t stream, + const float ln_epsilon = 1e-6f); + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/csrc/kernels/cuda/includes/util.h b/lightseq/csrc/kernels/cuda/includes/util.h index 206b7abb..0fa8e760 100644 --- a/lightseq/csrc/kernels/cuda/includes/util.h +++ b/lightseq/csrc/kernels/cuda/includes/util.h @@ -98,5 +98,7 @@ float dequantize(unsigned char i, float scale, float clip_max); void dequantize_array(std::vector& i8, std::vector& f, float clip_max, float quant_range, int start, int num); +void launch_convert_dtype(float* source_buffer, __half* target_buffer, + size_t size, int max_thread, cudaStream_t stream); } // namespace cuda } 
// namespace lightseq diff --git a/lightseq/csrc/kernels/cuda/llama_kernels.cu b/lightseq/csrc/kernels/cuda/llama_kernels.cu new file mode 100644 index 00000000..d6ae1a2e --- /dev/null +++ b/lightseq/csrc/kernels/cuda/llama_kernels.cu @@ -0,0 +1,479 @@ +/* + Copyright 2023 Bytedance Lab-nlp +*/ + +#include "kernels.h" +#include "llama_kernels.h" +#include "common.h" +#include +#include +#include "block_reduce.h" +#include +#include + +namespace cg = cooperative_groups; + +namespace lightseq { +namespace cuda { + +template +__global__ void kernel_llama_padding(const T* token_emb, const int* token_ids, + T* output, T* pad_mask_ptr, + int* left_pad_len_ptr, int batch_size, + int beam_size, int seq_len, int hidden_dim, + int padding_id, int max_step, + int step_offset) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * beam_size * seq_len * hidden_dim) { + return; + } + int batch_idx, beam_idx, seq_idx, state_idx; + decompose_4dim(idx, beam_size, seq_len, hidden_dim, &batch_idx, &beam_idx, + &seq_idx, &state_idx); + int token_idx = flat_3dim(batch_idx, beam_idx, seq_idx + step_offset, + beam_size, max_step); + int token_id = token_ids[token_idx]; + int batch_beam_idx = batch_idx * beam_size + beam_idx; + + float4& output_val = ((float4*)output)[idx]; + if (token_id == padding_id) { + if (state_idx == 0) { + pad_mask_ptr[token_idx] = CUDA_FLOAT_INF_NEG; + atomicAdd(left_pad_len_ptr + batch_beam_idx, 1); + } + output_val.x = 0.; + output_val.y = 0.; + output_val.z = 0.; + output_val.w = 0.; + } +} + +template +__global__ void kernel_llama_embedding(const T* token_emb, const int* token_ids, + T* output, T* pad_mask_ptr, + int* left_pad_len_ptr, int batch_size, + int beam_size, int seq_len, + int hidden_dim, int padding_id, + int max_step, int step_offset) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * beam_size * seq_len * hidden_dim) { + return; + } + int batch_idx, beam_idx, seq_idx, state_idx; + decompose_4dim(idx, beam_size, seq_len, hidden_dim, &batch_idx, &beam_idx, + &seq_idx, &state_idx); + int token_idx = flat_3dim(batch_idx, beam_idx, seq_idx + step_offset, + beam_size, max_step); + int token_id = token_ids[token_idx]; + + float4& output_val = ((float4*)output)[idx]; + if (token_id != padding_id) { + if (state_idx == 0) { + pad_mask_ptr[token_idx] = 0; + } + output_val = ((float4*)token_emb)[token_id * hidden_dim + state_idx]; + } +} + +template <> +__global__ void kernel_llama_padding<__half>( + const __half* token_emb, const int* token_ids, __half* output, + __half* pad_mask_ptr, int* left_pad_len_ptr, int batch_size, int beam_size, + int seq_len, int hidden_dim, int padding_id, int max_step, + int step_offset) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * beam_size * seq_len * hidden_dim) { + return; + } + int batch_idx, beam_idx, seq_idx, state_idx; + decompose_4dim(idx, beam_size, seq_len, hidden_dim, &batch_idx, &beam_idx, + &seq_idx, &state_idx); + int token_idx = flat_3dim(batch_idx, beam_idx, seq_idx + step_offset, + beam_size, max_step); + int token_id = token_ids[token_idx]; + int batch_beam_idx = batch_idx * beam_size + beam_idx; + + float4& output_val = ((float4*)output)[idx]; + if (token_id == padding_id) { + if (state_idx == 0) { + pad_mask_ptr[token_idx] = __float2half(CUDA_FLOAT_INF_NEG); + atomicAdd(left_pad_len_ptr + batch_beam_idx, 1); + } + output_val.x = 0.f; + output_val.y = 0.f; + output_val.z = 0.f; + output_val.w = 0.f; + } +} + +template <> +__global__ void 
kernel_llama_embedding<__half>( + const __half* token_emb, const int* token_ids, __half* output, + __half* pad_mask_ptr, int* left_pad_len_ptr, int batch_size, int beam_size, + int seq_len, int hidden_dim, int padding_id, int max_step, + int step_offset) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * beam_size * seq_len * hidden_dim) { + return; + } + int batch_idx, beam_idx, seq_idx, state_idx; + decompose_4dim(idx, beam_size, seq_len, hidden_dim, &batch_idx, &beam_idx, + &seq_idx, &state_idx); + int token_idx = flat_3dim(batch_idx, beam_idx, seq_idx + step_offset, + beam_size, max_step); + int token_id = token_ids[token_idx]; + + float4& output_val = ((float4*)output)[idx]; + + if (token_id != padding_id) { + if (state_idx == 0) { + pad_mask_ptr[token_idx] = __float2half(0.f); + } + output_val = ((float4*)token_emb)[token_id * hidden_dim + state_idx]; + } +} + +template <> +void launch_llama_embedding(const float* token_emb, const int* tokens, + float* output, float* pad_mask_ptr, + int* left_pad_len_ptr, int batch_size, + int beam_size, int hidden_dim, + int step_offset, int seq_len, int max_step, + int padding_id, cudaStream_t stream) { + if (seq_len + step_offset >= max_step) { + throw std::runtime_error("violate seq_len + step_offset < max_step"); + } + if (hidden_dim % 4) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nele = (batch_size * beam_size * seq_len * hidden_dim); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_llama_padding<<>>( + token_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr, batch_size, + beam_size, seq_len, hidden_dim, padding_id, max_step, step_offset); + + kernel_llama_embedding<<>>( + token_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr, batch_size, + beam_size, seq_len, hidden_dim, padding_id, max_step, step_offset); +} + +template <> +void launch_llama_embedding<__half>(const __half* token_emb, const int* tokens, + __half* output, __half* pad_mask_ptr, + int* left_pad_len_ptr, int batch_size, + int beam_size, int hidden_dim, + int step_offset, int seq_len, int max_step, + int padding_id, cudaStream_t stream) { + if (seq_len + step_offset >= max_step) { + throw std::runtime_error("violate seq_len + step_offset < max_step"); + } + if (hidden_dim % 8) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nele = (batch_size * beam_size * seq_len * hidden_dim); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_llama_padding<__half><<>>( + token_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr, batch_size, + beam_size, seq_len, hidden_dim, padding_id, max_step, step_offset); + kernel_llama_embedding<__half><<>>( + token_emb, tokens, output, pad_mask_ptr, left_pad_len_ptr, batch_size, + beam_size, seq_len, hidden_dim, padding_id, max_step, step_offset); +} + +template void launch_llama_embedding( + const float* token_emb, const int* tokens, float* output, + float* pad_mask_ptr, int* left_pad_len_ptr, int batch_size, int beam_size, + int hidden_dim, int step_offset, int seq_len, int max_step, int padding_id, + cudaStream_t stream); + +template void launch_llama_embedding<__half>( + const __half* token_emb, const int* tokens, __half* output, + __half* pad_mask_ptr, int* left_pad_len_ptr, int batch_size, int beam_size, + int hidden_dim, int step_offset, int seq_len, int max_step, int padding_id, + cudaStream_t stream); + +template +__global__ void kernel_split_rotary_position_qkv( + const T* input_ptr, const T* 
sin_ptr, const T* cos_ptr, T* q_out, + T* cache_k_out, T* cache_v_out, size_t batch_size, size_t max_step, + size_t nhead, size_t offset_seq_len, size_t query_len, size_t head_dim, + size_t max_thread_num) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= max_thread_num) { + return; + } + int batch_idx, qkv_idx, head_idx, seq_idx, head_dim_idx; + decompose_5dim(idx, query_len, 3, nhead, head_dim, &batch_idx, &seq_idx, + &qkv_idx, &head_idx, &head_dim_idx); + + size_t output_idx = 0; + if (qkv_idx) { + output_idx = flat_4dim(batch_idx, head_idx, offset_seq_len + seq_idx, + head_dim_idx, nhead, max_step, head_dim); + } else { + output_idx = flat_4dim(batch_idx, head_idx, seq_idx, head_dim_idx, nhead, + query_len, head_dim); + } + + // cos part + T state_val1 = *(input_ptr + idx); + + if (qkv_idx == 2) { + *(cache_v_out + output_idx) = state_val1; + } else { + T cos_val = *(cos_ptr + (offset_seq_len + seq_idx) * head_dim / 2 + + (head_dim_idx % (head_dim / 2))); + T sin_val = *(sin_ptr + (offset_seq_len + seq_idx) * head_dim / 2 + + (head_dim_idx % (head_dim / 2))); + T out_val = 0.; + if (head_dim_idx < head_dim / 2) { + T state_val2 = *(input_ptr + idx + head_dim / 2); + out_val = state_val1 * cos_val - state_val2 * sin_val; + } else { + T state_val2 = *(input_ptr + idx - head_dim / 2); + out_val = state_val1 * cos_val + state_val2 * sin_val; + } + + if (qkv_idx == 0) { + *(q_out + output_idx) = out_val; + } else { + *(cache_k_out + output_idx) = out_val; + } + } +} + +template +void launch_split_rotary_position_qkv(const T* input_ptr, const T* sin_ptr, + const T* cos_ptr, T* q_out, + T* cache_k_out, T* cache_v_out, + size_t max_step, size_t batch_size, + size_t nhead, size_t offset_seq_len, + size_t query_len, size_t head_dim, + cudaStream_t stream) { + size_t nele = 3 * batch_size * nhead * query_len * head_dim; + size_t nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_split_rotary_position_qkv<<>>( + input_ptr, sin_ptr, cos_ptr, q_out, cache_k_out, cache_v_out, batch_size, + max_step, nhead, offset_seq_len, query_len, head_dim, nele); +} + +template void launch_split_rotary_position_qkv( + const float* input_ptr, const float* sin_ptr, const float* cos_ptr, + float* q_out, float* cache_k_out, float* cache_v_out, size_t max_step, + size_t batch_size, size_t nhead, size_t offset_seq_len, size_t query_len, + size_t head_dim, cudaStream_t stream); + +template void launch_split_rotary_position_qkv<__half>( + const __half* input_ptr, const __half* sin_ptr, const __half* cos_ptr, + __half* q_out, __half* cache_k_out, __half* cache_v_out, size_t max_step, + size_t batch_size, size_t nhead, size_t offset_seq_len, size_t query_len, + size_t head_dim, cudaStream_t stream); + +template +__global__ void kernel_silu_elewise_product(const T* inp_ptr, T* out_ptr, + size_t seq_len, size_t inner_size, + size_t max_thread_num) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= max_thread_num) { + return; + } + int inpA_idx = idx / inner_size * inner_size * 2 + idx % inner_size; + int inpB_idx = inpA_idx + inner_size; + const T& inpA = *(inp_ptr + inpA_idx); + const T& inpB = *(inp_ptr + inpB_idx); + *(out_ptr + idx) = inpA / (1.f + __expf(-inpA)) * inpB; +} + +template <> +__global__ void kernel_silu_elewise_product<__half>(const __half* inp_ptr, + __half* out_ptr, + size_t seq_len, + size_t inner_size, + size_t max_thread_num) { + size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= max_thread_num) { + return; + } + // const 
__half& ele_product = *(inpA_ptr + idx); + int inpA_idx = idx / inner_size * inner_size * 2 + idx % inner_size; + int inpB_idx = inpA_idx + inner_size; + const __half& inpA = *(inp_ptr + inpA_idx); + const __half& inpB = *(inp_ptr + inpB_idx); + *(out_ptr + idx) = inpA / __float2half(1.f + __expf(-inpA)) * inpB; +} + +template +void launch_silu_elewise_product(const T* inp_ptr, T* out_ptr, + size_t batch_size, size_t seq_len, + size_t inner_size, cudaStream_t stream) { + size_t nele = batch_size * seq_len * inner_size; + size_t nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_silu_elewise_product<<>>( + inp_ptr, out_ptr, seq_len, inner_size, nele); +} + +template void launch_silu_elewise_product( + const float* inp_ptr, float* out_ptr, size_t batch_size, size_t seq_len, + size_t inner_size, cudaStream_t stream); +template void launch_silu_elewise_product<__half>( + const __half* inp_ptr, __half* out_ptr, size_t batch_size, size_t seq_len, + size_t inner_size, cudaStream_t stream); + +template +__global__ void ker_rms_layer_norm(const T* inp_ptr, const T* scale_ptr, + T* out_ptr, T* rms_ptr, size_t hidden_dim, + const float ln_epsilon) { + // step 0. compute local sum + float l_square_sum = 0; + const T* thread_inp = inp_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + l_square_sum += thread_inp[idx] * thread_inp[idx]; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_dim); + float kReduce[1] = {l_square_sum}; + blockReduce(kReduce); + __shared__ float s_var; + if (threadIdx.x == 0) { + s_var = rsqrtf(kReduce[0] / mean_dim + ln_epsilon); + rms_ptr[blockIdx.x] = s_var; + } + __syncthreads(); + + // step 2. layer norm result + T* thread_out = out_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + thread_out[idx] = thread_inp[idx] * scale_ptr[idx] * s_var; + } +} + +template <> +__global__ void ker_rms_layer_norm<__half>(const __half* inp_ptr, + const __half* scale_ptr, + __half* out_ptr, __half* rms_ptr, + size_t hidden_dim, + const float ln_epsilon) { + // step 0. compute local sum + float l_square_sum = 0; + const __half* thread_inp = inp_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + float float_inp = __half2float(thread_inp[idx]); + l_square_sum += float_inp * float_inp; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_dim); + float kReduce[1] = {l_square_sum}; + blockReduce(kReduce); + __shared__ __half s_var; + if (threadIdx.x == 0) { + s_var = __float2half(rsqrtf(kReduce[0] / mean_dim + ln_epsilon)); + if (rms_ptr != nullptr) rms_ptr[blockIdx.x] = s_var; + } + __syncthreads(); + + // step 2. layer norm result + __half* thread_out = out_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + thread_out[idx] = thread_inp[idx] * scale_ptr[idx] * s_var; + } +} + +template +__global__ void ker_rms_layer_norm_with_res(const T* inp_ptr, + const T* scale_ptr, T* out_ptr, + T* res_ptr, T* rms_ptr, + size_t hidden_dim, + const float ln_epsilon) { + // step 0. compute local sum + float l_square_sum = 0; + const T* thread_inp = inp_ptr + blockIdx.x * hidden_dim; + T* res_thread_out = res_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + l_square_sum += thread_inp[idx] * thread_inp[idx]; + res_thread_out[idx] = thread_inp[idx]; + } + + // step 1. 
compute reduce sum + float mean_dim = float(hidden_dim); + float kReduce[1] = {l_square_sum}; + blockReduce(kReduce); + __shared__ float s_var; + if (threadIdx.x == 0) { + s_var = rsqrtf(kReduce[0] / mean_dim + ln_epsilon); + rms_ptr[blockIdx.x] = s_var; + } + __syncthreads(); + + // step 2. layer norm result + T* thread_out = out_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + thread_out[idx] = thread_inp[idx] * scale_ptr[idx] * s_var; + } +} + +template <> +__global__ void ker_rms_layer_norm_with_res<__half>( + const __half* inp_ptr, const __half* scale_ptr, __half* out_ptr, + __half* res_ptr, __half* rms_ptr, size_t hidden_dim, + const float ln_epsilon) { + // step 0. compute local sum + float l_square_sum = 0; + const __half* thread_inp = inp_ptr + blockIdx.x * hidden_dim; + __half* res_thread_out = res_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + float float_inp = __half2float(thread_inp[idx]); + l_square_sum += float_inp * float_inp; + res_thread_out[idx] = thread_inp[idx]; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_dim); + float kReduce[1] = {l_square_sum}; + blockReduce(kReduce); + __shared__ __half s_var; + if (threadIdx.x == 0) { + s_var = __float2half(rsqrtf(kReduce[0] / mean_dim + ln_epsilon)); + if (rms_ptr != nullptr) rms_ptr[blockIdx.x] = s_var; + } + __syncthreads(); + + // step 2. layer norm result + __half* thread_out = out_ptr + blockIdx.x * hidden_dim; + for (uint idx = threadIdx.x; idx < hidden_dim; idx += blockDim.x) { + thread_out[idx] = thread_inp[idx] * scale_ptr[idx] * s_var; + } +} + +template +void launch_rms_layer_norm(const T* inp_ptr, const T* scale_ptr, T* out_ptr, + T* res_ptr, T* rms_ptr, size_t batch_tokens, + size_t hidden_dim, cudaStream_t stream, + const float ln_epsilon) { + int nthread = std::min(((hidden_dim + 31) / 32) * 32, size_t(MAX_THREADS)); + dim3 grid_dim(batch_tokens); + dim3 block_dim(nthread); + + if (res_ptr == nullptr) { + ker_rms_layer_norm<<>>( + inp_ptr, scale_ptr, out_ptr, rms_ptr, hidden_dim, ln_epsilon); + } else { + ker_rms_layer_norm_with_res<<>>( + inp_ptr, scale_ptr, out_ptr, res_ptr, rms_ptr, hidden_dim, ln_epsilon); + } +} + +template void launch_rms_layer_norm( + const float* inp_ptr, const float* scale_ptr, float* out_ptr, + float* res_ptr, float* rms_ptr, size_t batch_tokens, size_t hidden_dim, + cudaStream_t stream, const float ln_epsilon); +template void launch_rms_layer_norm<__half>( + const __half* inp_ptr, const __half* scale_ptr, __half* out_ptr, + __half* res_ptr, __half* rms_ptr, size_t batch_tokens, size_t hidden_dim, + cudaStream_t stream, const float ln_epsilon); + +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/csrc/kernels/cuda/util.cc.cu b/lightseq/csrc/kernels/cuda/util.cc.cu index 93e7b1f1..b0699828 100644 --- a/lightseq/csrc/kernels/cuda/util.cc.cu +++ b/lightseq/csrc/kernels/cuda/util.cc.cu @@ -66,5 +66,22 @@ void dequantize_array(std::vector& i8, std::vector& f, f[i] = dequantize(i8[i], quant_range, clip_max); } } + +__global__ void kernel_convert_dtype(float* source_buffer, + __half* target_buffer, size_t nele) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= nele) { + return; + } + *(target_buffer + idx) = __float2half(*(source_buffer + idx)); +} + +void launch_convert_dtype(float* source_buffer, __half* target_buffer, + size_t size, int max_thread, cudaStream_t stream) { + int nblock = (size + max_thread - 1) / max_thread; + 
kernel_convert_dtype<<>>(source_buffer, + target_buffer, size); +} + } // namespace cuda } // namespace lightseq diff --git a/lightseq/csrc/layers_new/CMakeLists.txt b/lightseq/csrc/layers_new/CMakeLists.txt index 029e55a1..a680b9db 100644 --- a/lightseq/csrc/layers_new/CMakeLists.txt +++ b/lightseq/csrc/layers_new/CMakeLists.txt @@ -1,6 +1,9 @@ set(layers_files feed_forward_layer.cpp linear_layer.cpp + llama_attention_layer.cpp + llama_mlp_layer.cpp + llama_layer.cpp generator_layer.cpp gpt_attention_layer.cpp gpt_layer.cpp diff --git a/lightseq/csrc/layers_new/includes/launch_llama_emb_layer.h b/lightseq/csrc/layers_new/includes/launch_llama_emb_layer.h new file mode 100644 index 00000000..a59f43b3 --- /dev/null +++ b/lightseq/csrc/layers_new/includes/launch_llama_emb_layer.h @@ -0,0 +1,60 @@ +#pragma once +#include "launch_llama_emb.h" +#include "layer.h" + +namespace lightseq { + +template +class LaunchLlamaEmbLayer : public Layer { + private: + // operators + LaunchLlamaEmbOp* _launch_llama_op = nullptr; + + // parameters + Variable* _token_emb; + + public: + LaunchLlamaEmbLayer(int max_batch_tokens, int max_step, int max_batch_size, + int beam_size, int pad_id, int hidden_dim) + : Layer("LaunchLlamaEmbLayer"), + _launch_llama_op(new LaunchLlamaEmbOp(max_batch_tokens, max_step, + max_batch_size, beam_size, + pad_id, hidden_dim)) { + _token_emb = new Variable("token_emb", g_dtype()); + + this->_context_ptr->exit_layer(); // necessary + } + + virtual ~LaunchLlamaEmbLayer() {} + + std::tuple operator()(Variable* inp) { + set_inputs({inp}); + + std::tuple out = + (*_launch_llama_op)(inp, _token_emb); + + set_outputs({std::get<0>(out), std::get<1>(out), std::get<2>(out)}); + return out; + } + + void before_forward(int batch_size, int seq_len, int offset) { + _launch_llama_op->before_forward(batch_size, seq_len, offset); + } + + void before_backward() {} + + int load_params(const std::vector& para_vec, int offset) { + _token_emb->set_value((char*)para_vec[offset]); + return 0; + } +}; + +template class LaunchLlamaEmbLayer; +#ifdef LIGHTSEQ_cuda +template class LaunchLlamaEmbLayer<__half>; +#endif + +template +using LaunchLlamaEmbLayerPtr = std::shared_ptr>; + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/includes/llama_attention_layer.h b/lightseq/csrc/layers_new/includes/llama_attention_layer.h new file mode 100644 index 00000000..fbb71844 --- /dev/null +++ b/lightseq/csrc/layers_new/includes/llama_attention_layer.h @@ -0,0 +1,65 @@ +#pragma once +#include "layer.h" +#include "linear.h" +#include "rms_layer_norm.h" +#include "fuse_rotary_position_qkv.h" +#include "sdpa_layer.h" +#include "transform_0213.h" +#include "fuse_add2_op.h" + +namespace lightseq { + +template +class LlamaAttentionLayer : public Layer { + private: + // operators + RMSLayerNormalizeOp* _attn_ln = nullptr; + LinearOp* _qkv_linear = nullptr; + RotaryPositionQk* _fuse_rotary = nullptr; + SDPALayer* _sdpa = nullptr; + Transform0213OP* _transform_0213 = nullptr; + LinearOp* _attn_out_linear = nullptr; + FuseAdd2Op* _add_residual = nullptr; + + // parameters + Variable* _norm_scale; + Variable* _attn_qkvw; + Variable* _attn_ow; + + // shape related + size_t _max_batch_size; + int _max_batch_tokens; + int _max_seq_len; + size_t _hidden_size; + int _nhead; + int _head_dim; + + // tensor slice + Variable* _cache_k; + Variable* _cache_v; + + public: + LlamaAttentionLayer(int max_batch_tokens, int max_seq_len, int hidden_size, + int num_heads, int beam_size); + + virtual ~LlamaAttentionLayer() {} + + Variable* 
operator()(Variable* inp, Variable* cache_k, Variable* cache_v, + Variable* pad_mask); + + void before_forward(int batch_size, int trg_seq_len, int prompt_len); + + void before_backward(); + + int load_params(const std::vector& para_vec, int offset); +}; + +template class LlamaAttentionLayer; +#ifdef LIGHTSEQ_cuda +template class LlamaAttentionLayer<__half, __half>; +#endif + +template +using LlamaAttentionLayerPtr = std::shared_ptr>; + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/includes/llama_layer.h b/lightseq/csrc/layers_new/includes/llama_layer.h new file mode 100644 index 00000000..20366141 --- /dev/null +++ b/lightseq/csrc/layers_new/includes/llama_layer.h @@ -0,0 +1,42 @@ +#pragma once +#include "layer.h" +#include "llama_attention_layer.h" +#include "llama_mlp_layer.h" + +namespace lightseq { + +template +class LlamaLayer : public Layer { + private: + LlamaAttentionLayerPtr _attn_layer; + LlamaMLPLayerPtr _mlp_layer; + + int _layer_id; + + public: + LlamaLayer(int max_batch_size, int max_seq_len, int hidden_size, + int inner_dim, int num_heads, int beam_size); + virtual ~LlamaLayer() {} + + Variable* operator()(Variable* inp, Variable* cache_k, Variable* cache_v, + Variable* pad_mask); + + void before_forward(int batch_size, int seq_len, int prompt_len) { + _attn_layer->before_forward(batch_size, seq_len, prompt_len); + _mlp_layer->before_forward(batch_size, seq_len); + } + + size_t load_para_and_grad(const T1* para_ptr, T2* grad_ptr); + + int load_params(const std::vector& para_vec, int offset); +}; + +template class LlamaLayer; +#ifdef LIGHTSEQ_cuda +template class LlamaLayer<__half, __half>; +#endif + +template +using LlamaLayerPtr = std::shared_ptr>; + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/includes/llama_mlp_layer.h b/lightseq/csrc/layers_new/includes/llama_mlp_layer.h new file mode 100644 index 00000000..79798bd0 --- /dev/null +++ b/lightseq/csrc/layers_new/includes/llama_mlp_layer.h @@ -0,0 +1,51 @@ +#pragma once + +#include "rms_layer_norm.h" +#include "linear.h" +#include "act_elewise_product.h" +#include "fuse_add2_op.h" +#include "layer.h" + +namespace lightseq { + +template +class LlamaMLPLayer : public Layer { + private: + // operators + RMSLayerNormalizeOp* _mlp_ln = nullptr; + LinearOp* _gate_up_linear = nullptr; + LinearOp* _down_linear = nullptr; + ActElewiseProductOp* _act_product = nullptr; + FuseAdd2Op* _add_residual = nullptr; + + // parameters + Variable* _norm_scale; + Variable* _gate_up_linear_weight; + Variable* _down_linear_weight; + + // shape related + int _max_batch_tokens; + size_t _hidden_dim; + size_t _inner_dim; + + public: + LlamaMLPLayer(int max_batch_tokens, int hidden_dim, int inner_dim); + + virtual ~LlamaMLPLayer() {} + + Variable* operator()(Variable* inp); + + void before_forward(int batch_size, int seq_len); + + int load_params(const std::vector& para_vec, int offset); +}; + +template class LlamaMLPLayer; +#ifdef LIGHTSEQ_cuda +template class LlamaMLPLayer<__half, __half>; +#endif + +template +using LlamaMLPLayerPtr = std::shared_ptr>; + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/includes/rms_norm_layer.h b/lightseq/csrc/layers_new/includes/rms_norm_layer.h new file mode 100644 index 00000000..2d0a5f5e --- /dev/null +++ b/lightseq/csrc/layers_new/includes/rms_norm_layer.h @@ -0,0 +1,64 @@ +#pragma once +#include "rms_layer_norm.h" +#include "layer.h" + +namespace lightseq { + +template +class RMSNormLayer : public Layer { + private: + int _hidden_size; + int _max_batch_tokens; 
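+ + // _rms_norm computes RMSNorm (no mean subtraction): out = inp * scale * rsqrt(mean(inp * inp) + ln_epsilon), + // matching ker_rms_layer_norm in llama_kernels.cu; this layer wraps that op as the final + // normalization applied before the logits projection in llama.cc.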
+ + // operators + RMSLayerNormalizeOp* _rms_norm = nullptr; + + // parameters + Variable* _norm_scale; + + public: + RMSNormLayer(int max_batch_tokens, int hidden_size) + : Layer("RMSNormLayer"), + _hidden_size(hidden_size), + _max_batch_tokens(max_batch_tokens), + _rms_norm(new RMSLayerNormalizeOp(max_batch_tokens, hidden_size, + false)) { + _norm_scale = new Variable("_norm_scale", g_dtype(), g_dtype()); + + this->_context_ptr->exit_layer(); // necessary + } + + virtual ~RMSNormLayer() {} + + Variable* operator()(Variable* inp) { + set_inputs({inp}); + + Variable* out = std::get<0>((*_rms_norm)(inp, _norm_scale)); + + set_outputs({out}); + return out; + } + + void before_forward(int batch_size, int seq_len) { + _rms_norm->before_forward(batch_size, seq_len); + } + + void before_backward() {} + + int load_params(const std::vector& para_vec, int offset) { + int size = 0; + _norm_scale->set_value((char*)para_vec[offset + size]), size++; + _norm_scale->set_shape({size_t(_hidden_size)}); + return size; + } +}; + +template class RMSNormLayer; +#ifdef LIGHTSEQ_cuda +template class RMSNormLayer<__half, __half>; +#endif + +template +using RMSNormLayerPtr = std::shared_ptr>; + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/llama_attention_layer.cpp b/lightseq/csrc/layers_new/llama_attention_layer.cpp new file mode 100644 index 00000000..0d8da72a --- /dev/null +++ b/lightseq/csrc/layers_new/llama_attention_layer.cpp @@ -0,0 +1,109 @@ +#include "llama_attention_layer.h" + +namespace lightseq { + +template +LlamaAttentionLayer::LlamaAttentionLayer(int max_batch_size, + int max_seq_len, + int hidden_size, int num_heads, + int beam_size) + : Layer("LlamaAttentionLayer"), + _max_batch_size(max_batch_size), + _max_batch_tokens(max_batch_size * max_seq_len), + _max_seq_len(max_seq_len), + _hidden_size(hidden_size), + _nhead(num_heads), + _head_dim(hidden_size / num_heads) { + // operators + _attn_ln = new RMSLayerNormalizeOp(_max_batch_tokens, hidden_size); + _qkv_linear = + new LinearOp(_max_batch_tokens, 3 * hidden_size, hidden_size); + _fuse_rotary = new RotaryPositionQk(max_batch_size, max_seq_len, + num_heads, _head_dim); + + _sdpa = new SDPALayer(_max_batch_tokens, max_seq_len, _head_dim, + num_heads, 0.f); + _transform_0213 = + new Transform0213OP(_max_batch_tokens * hidden_size); + _attn_out_linear = + new LinearOp(_max_batch_tokens, hidden_size, hidden_size); + // _add_residual = new FuseAdd2Op(_max_batch_tokens, hidden_size); + // parameters init + _norm_scale = new Variable("_norm_scale", g_dtype(), g_dtype()); + _attn_qkvw = new Variable("_attn_qkvw", g_dtype(), g_dtype()); + _attn_ow = new Variable("_attn_ow", g_dtype(), g_dtype()); + + this->_context_ptr->exit_layer(); // necessary +} + +template +Variable* LlamaAttentionLayer::operator()(Variable* inp, + Variable* cache_k, + Variable* cache_v, + Variable* pad_mask) { + set_inputs({inp, cache_k, cache_v, pad_mask}); + + std::tuple ln_out = (*_attn_ln)(inp, _norm_scale); + Variable* qkv_out = (*_qkv_linear)(std::get<0>(ln_out), _attn_qkvw); + + Variable* q_out = (*_fuse_rotary)(qkv_out, cache_k, cache_v); + + // result of Scaled Dot Product Attention + Variable* sdpa_res = (*_sdpa)(q_out, cache_k, cache_v, pad_mask); + + // [sz0, sz1, sz2, sz3] -> [sz0, sz2, sz1, sz3] + Variable* transform_0213_out = (*_transform_0213)(sdpa_res); + + Variable* attn_linear = + (*_attn_out_linear)(transform_0213_out, _attn_ow, std::get<1>(ln_out)); + + // Variable* attn_out = (*_add_residual)(inp, attn_linear); + + set_outputs({attn_linear}); + 
return attn_linear; +} + +template +void LlamaAttentionLayer::before_forward(int batch_size, int query_len, + int prompt_len) { + // all token number in this batch + int batch_tokens = batch_size * query_len; + int attn_to_len = (prompt_len <= 0) ? query_len : prompt_len + 1; + + _attn_ln->before_forward(batch_size, query_len); + + _qkv_linear->before_forward(batch_tokens); + + _fuse_rotary->before_forward(batch_size, prompt_len, query_len); + + // mask future when training or (inference and prompt_len=0) + _sdpa->before_forward(batch_size, query_len, attn_to_len, _max_seq_len, + prompt_len <= 0); + + _transform_0213->before_forward(batch_size, _nhead, query_len, _head_dim); + + _attn_out_linear->before_forward(batch_tokens); + + // _add_residual->before_forward(batch_size, query_len); +} + +template +void LlamaAttentionLayer::before_backward() {} + +template +int LlamaAttentionLayer::load_params( + const std::vector& para_vec, int offset) { // for inference + int size = 0; + _norm_scale->set_value((char*)para_vec[offset + size]), size++; + _norm_scale->set_shape({_hidden_size}); + + _attn_qkvw->set_value((char*)para_vec[offset + size]), size++; + _attn_qkvw->set_shape({_hidden_size, 3 * _hidden_size}); + + _attn_ow->set_value((char*)para_vec[offset + size]), size++; + _attn_ow->set_shape({_hidden_size, _hidden_size}); + + return size; +} + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/llama_layer.cpp b/lightseq/csrc/layers_new/llama_layer.cpp new file mode 100644 index 00000000..a41dee74 --- /dev/null +++ b/lightseq/csrc/layers_new/llama_layer.cpp @@ -0,0 +1,44 @@ +#include "llama_layer.h" + +namespace lightseq { + +template +LlamaLayer::LlamaLayer(int max_batch_size, int max_seq_len, + int hidden_size, int inner_dim, int num_heads, + int beam_size) + : Layer("LlamaLayer") { + _attn_layer.reset(new LlamaAttentionLayer( + max_batch_size, max_seq_len, hidden_size, num_heads, beam_size)); + _mlp_layer.reset(new LlamaMLPLayer(max_batch_size * max_seq_len, + hidden_size, inner_dim)); + + this->_context_ptr->exit_layer(); // necessary +} + +template +Variable* LlamaLayer::operator()(Variable* inp, Variable* cache_k, + Variable* cache_v, + Variable* pad_mask) { + set_inputs({inp, cache_k, cache_v, pad_mask}); + + Variable* attn_out = (*_attn_layer)(inp, cache_k, cache_v, pad_mask); + + Variable* ffn_out = (*_mlp_layer)(attn_out); + + set_outputs({ffn_out}); + return ffn_out; +} + +template +int LlamaLayer::load_params(const std::vector& para_vec, + int offset) { // for inference + int size = 0; + + size += _attn_layer->load_params(para_vec, offset + size); + + size += _mlp_layer->load_params(para_vec, offset + size); + + return size; +} + +} // namespace lightseq diff --git a/lightseq/csrc/layers_new/llama_mlp_layer.cpp b/lightseq/csrc/layers_new/llama_mlp_layer.cpp new file mode 100644 index 00000000..8c19e65a --- /dev/null +++ b/lightseq/csrc/layers_new/llama_mlp_layer.cpp @@ -0,0 +1,69 @@ +#include "llama_mlp_layer.h" + +namespace lightseq { + +template +LlamaMLPLayer::LlamaMLPLayer(int max_batch_tokens, int hidden_dim, + int inner_dim) + : Layer("LlamaMLPLayer"), + _max_batch_tokens(max_batch_tokens), + _hidden_dim(hidden_dim), + _inner_dim(inner_dim), + _mlp_ln(new RMSLayerNormalizeOp(max_batch_tokens, hidden_dim)), + _gate_up_linear( + new LinearOp(max_batch_tokens, 2 * inner_dim, hidden_dim)), + _act_product( + new ActElewiseProductOp(max_batch_tokens, inner_dim)), + _down_linear( + new LinearOp(max_batch_tokens, hidden_dim, inner_dim)) +// _add_residual(new 
FuseAdd2Op(max_batch_tokens, hidden_dim)) +{ + _norm_scale = new Variable("_norm_scale", g_dtype(), g_dtype()); + _gate_up_linear_weight = + new Variable("_gate_up_linear_weight", g_dtype(), g_dtype()); + _down_linear_weight = + new Variable("_down_linear_weight", g_dtype(), g_dtype()); + this->_context_ptr->exit_layer(); // necessary +} + +template +Variable* LlamaMLPLayer::operator()(Variable* inp) { + set_inputs({inp}); + std::tuple ln_out = (*_mlp_ln)(inp, _norm_scale); + Variable* gate_up_out = + (*_gate_up_linear)(std::get<0>(ln_out), _gate_up_linear_weight); + Variable* act_out = (*_act_product)(gate_up_out); + Variable* down_out = + (*_down_linear)(act_out, _down_linear_weight, std::get<1>(ln_out)); + // Variable* mlp_out = (*_add_residual)(down_out, inp); + set_outputs({down_out}); + return down_out; +} + +template +void LlamaMLPLayer::before_forward(int batch_size, int seq_len) { + _mlp_ln->before_forward(batch_size, seq_len); + _gate_up_linear->before_forward(batch_size * seq_len); + _act_product->before_forward(batch_size, seq_len); + _down_linear->before_forward(batch_size * seq_len); + // _add_residual->before_forward(batch_size, seq_len); +} + +template +int LlamaMLPLayer::load_params(const std::vector& para_vec, + int offset) { + int size = 0; + + _norm_scale->set_value((char*)para_vec[offset + size]), size++; + _norm_scale->set_shape({_hidden_dim}); + + _gate_up_linear_weight->set_value((char*)para_vec[offset + size]), size++; + _gate_up_linear_weight->set_shape({_hidden_dim, 2 * _inner_dim}); + + _down_linear_weight->set_value((char*)para_vec[offset + size]), size++; + _down_linear_weight->set_shape({_inner_dim, _hidden_dim}); + + return size; +} + +} // namespace lightseq diff --git a/lightseq/csrc/models/CMakeLists.txt b/lightseq/csrc/models/CMakeLists.txt index 24ea42ab..89decaef 100644 --- a/lightseq/csrc/models/CMakeLists.txt +++ b/lightseq/csrc/models/CMakeLists.txt @@ -1,5 +1,5 @@ add_library(liblightseq SHARED bert.cc bert_crf.cc transformer.cu gpt.cc - model_util.cc) + llama.cc model_util.cc) target_link_libraries(liblightseq PUBLIC lightseq_layers) diff --git a/lightseq/csrc/models/includes/llama.h b/lightseq/csrc/models/includes/llama.h new file mode 100644 index 00000000..aecc9014 --- /dev/null +++ b/lightseq/csrc/models/includes/llama.h @@ -0,0 +1,61 @@ +#pragma once +#include "model_base.h" + +#include "model_util.h" +#include "llama_weight.h" + +#include "launch_llama_emb_layer.h" +#include "llama_layer.h" +#include "linear_layer.h" +#include "rms_norm_layer.h" +#include "generator_layer.h" + +namespace lightseq { +namespace cuda { +class Llama : public LSModel { + private: + LlamaWeight tw_; + std::shared_ptr _context_ptr; + + LaunchLlamaEmbLayerPtr _launch_llama_emb_layer; + std::vector> _llama_layer_vec; + RMSNormLayerPtr _rms_norm_layer; + LinearLayerPtr _linear_layer; + GeneratorLayerPtr _generator_layer; + + ContextPtr context_ptr; + + Variable* _inp_tokens; // need to allocate + Variable* _out_tokens; + Variable* _pad_mask; + + Variable* _total_caches_k; + Variable* _total_caches_v; + + int* _llama_out_ptr = nullptr; + int* _input_ptr = nullptr; + float* _llama_scores_ptr = nullptr; + + int _max_batch_size; + GenerateMethod _generate_method; + + public: + Llama(const std::string weight_path, const int max_batch_size); + ~Llama(); + + void before_forward(int batch_size, int prompt_len, int steps); + + void Infer() override; + void set_input_ptr(int index, void* input_ptr) override; + void set_output_ptr(int index, void* output_ptr) override; + const 
void* get_output_ptr(int index) override; + std::vector get_input_max_shape(int index) override; + std::vector get_output_max_shape(int index) override; + DataType get_input_dtype(int index) override; + DataType get_output_dtype(int index) override; + void benchmark_mode(bool is_benchmark) override {} +}; + +LSMODEL_REGISTER(Llama); +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/csrc/models/llama.cc b/lightseq/csrc/models/llama.cc new file mode 100644 index 00000000..ef88cdae --- /dev/null +++ b/lightseq/csrc/models/llama.cc @@ -0,0 +1,285 @@ +#include "llama.h" + +namespace lightseq { +namespace cuda { +Llama::Llama(const std::string weight_path, const int max_batch_size) + : LSModel({"token_ids"}, {"llama_out"}), _max_batch_size(max_batch_size) { + /* --- step.1 initial context --- */ + Context::create_global_context(StatusType::Inference); + _context_ptr = Context::global_instance(); + + /* --- step.2 load model weights into GPU memory --- */ + // saved in custom proto file + std::string model_weights_path = weight_path; + std::string res = tw_.initializing(model_weights_path); + if (!res.empty()) { + throw std::runtime_error(res); + } + printf("*** model max_batch_size: %d ***\n", max_batch_size); + _generate_method = get_generate_method(tw_._generate_method); + if (_generate_method != GenerateMethod::BeamSearch) { + tw_._beam_size = 1; + } + tw_.print_model_config(); + + /* --- step.3 initial input Variable node --- */ + _inp_tokens = new Variable("inp_tokens", g_dtype()); + + /* --- step.4 inital operator & layer --- */ + int max_batch_tokens = tw_._max_step * _max_batch_size; + _launch_llama_emb_layer.reset(new LaunchLlamaEmbLayer( + max_batch_tokens, tw_._max_step, _max_batch_size, tw_._beam_size, + tw_._padding_id, tw_._hidden_size)); + _launch_llama_emb_layer->load_params(tw_.get_src_emb_wei(), 0); + + int enc_wei_offset = 0; + for (int idx = 0; idx < tw_._layer_num; idx++) { + LlamaLayerPtr llama_layer( + new LlamaLayer(max_batch_size, tw_._max_step, + tw_._hidden_size, tw_._inner_size, + tw_._head_num, tw_._beam_size)); + enc_wei_offset += + llama_layer->load_params(tw_.get_enc_wei(), enc_wei_offset); + _llama_layer_vec.push_back(llama_layer); + } + + _rms_norm_layer.reset( + new RMSNormLayer(max_batch_tokens, tw_._hidden_size)); + _rms_norm_layer->load_params(tw_.get_src_emb_wei(), 1); + + // intial Project hidden states to vocab logits + _linear_layer.reset(new LinearLayer( + max_batch_size * tw_._beam_size, tw_._hidden_size, tw_._src_vocab_size, + MATRIX_OP::NonTranspose, MATRIX_OP::NonTranspose, 1.f)); + _linear_layer->load_params(tw_.get_src_emb_wei(), 2); + + _generator_layer.reset(new GeneratorLayer( + _generate_method, tw_._layer_num, max_batch_size, tw_._max_step, + tw_._src_vocab_size, tw_._hidden_size, 1024, tw_._beam_size, + tw_._diverse_lambda, tw_._dim_per_head, tw_._eos_id, tw_._head_num, + tw_._length_penalty, tw_._topk, tw_._topp, false)); + + /* --- step.5 construct network --- */ + size_t cache_size = max_batch_tokens * tw_._beam_size * tw_._hidden_size; + _total_caches_k = new Variable("total_caches_k", cache_size * tw_._layer_num, + g_dtype(), DataType::kNotSupported, + VariableType::RegressiveVariable); + _total_caches_v = new Variable("total_caches_v", cache_size * tw_._layer_num, + g_dtype(), DataType::kNotSupported, + VariableType::RegressiveVariable); + + // note regress begin + _context_ptr->regress_begin(); + + std::tuple llama_emb_outs = + (*_launch_llama_emb_layer)(_inp_tokens); + Variable *llama_emb = 
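+  // llama_emb_outs: <0> is the embedded prompt hidden states, <1> the padding
+  // mask handed to every decoder layer below.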
std::get<0>(llama_emb_outs); + Variable *pad_mask = std::get<1>(llama_emb_outs); + pad_mask->set_regress_var(); + size_t cache_offset = 0; + for (auto iter : _llama_layer_vec) { + Variable *cache_k = new Variable("cache_k", _total_caches_k); + cache_k->set_offset(cache_offset, {cache_size}); + Variable *cache_v = new Variable("cache_v", _total_caches_v); + cache_v->set_offset(cache_offset, {cache_size}); + llama_emb = (*iter)(llama_emb, cache_k, cache_v, pad_mask); + cache_offset += cache_size; + } + llama_emb = (*_rms_norm_layer)(llama_emb); + Variable *logits_prob = (*_linear_layer)(llama_emb); + + std::tuple gen_outs = + (*_generator_layer)(logits_prob, _inp_tokens); + + // note regress_end + _context_ptr->regress_end(); + + _out_tokens = std::get<0>(gen_outs); + _inp_tokens->malloc_memory(max_batch_size * tw_._beam_size * tw_._max_step); + _out_tokens->malloc_memory(max_batch_size * tw_._beam_size * tw_._max_step); + + _context_ptr->build(); + printf("Finish construct network!\n"); +} + +Llama::~Llama() {} + +void Llama::before_forward(int batch_size, int prompt_len, int steps) { + if (steps == 0) { + _launch_llama_emb_layer->before_forward(batch_size, prompt_len, 0); + for (auto iter : _llama_layer_vec) { + iter->before_forward(batch_size * tw_._beam_size, prompt_len, 0); + } + _rms_norm_layer->before_forward(batch_size * tw_._beam_size, 1); + _linear_layer->before_forward(batch_size * tw_._beam_size, 1); + _generator_layer->before_forward(batch_size, prompt_len, 0); + } else { + _launch_llama_emb_layer->before_forward(batch_size, 1, + prompt_len + steps - 1); + for (auto iter : _llama_layer_vec) { + iter->before_forward(batch_size * tw_._beam_size, 1, + prompt_len + steps - 1); + } + _rms_norm_layer->before_forward(batch_size * tw_._beam_size, 1); + _linear_layer->before_forward(batch_size * tw_._beam_size, 1); + _generator_layer->before_forward(batch_size, prompt_len, steps); + } +} + +void Llama::Infer() { + int batch_size = input_shapes_[0][0], prompt_len = input_shapes_[0][1]; + + /* --- notice that the order of forward should be the same with network --- */ + +#ifdef LIGHTSEQ_cuda + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + for (int beam_idx = 0; beam_idx < tw_._beam_size; beam_idx++) { + CHECK_GPU_ERROR(cudaMemcpyAsync( + _inp_tokens->value() + + (batch_idx * tw_._beam_size + beam_idx) * tw_._max_step, + _input_ptr + batch_idx * prompt_len, prompt_len * sizeof(int), + cudaMemcpyDefault, _context_ptr->get_stream())); + } + } +#endif + + int steps = 0; + while (steps + prompt_len < tw_._max_step) { + before_forward(batch_size, prompt_len, steps); + + _launch_llama_emb_layer->forward(); + for (auto iter : _llama_layer_vec) { + iter->forward(); + } + + if (steps == 0) { + OpType_ *linear_inp_ptr = _rms_norm_layer->input(0)->value(); + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + for (int i = 0; i < tw_._beam_size; i++) { + cudaMemcpyAsync( + linear_inp_ptr + + (batch_idx * tw_._beam_size + i) * tw_._hidden_size, + linear_inp_ptr + (batch_idx * tw_._beam_size * prompt_len + + i * prompt_len + prompt_len - 1) * + tw_._hidden_size, + tw_._hidden_size * sizeof(OpType_), cudaMemcpyDefault, + _context_ptr->get_stream()); + } + } + } + _rms_norm_layer->forward(); + _linear_layer->forward(); + + _generator_layer->forward(); + + if (_generator_layer->is_stop()) { + break; + } + if (_generate_method == GenerateMethod::BeamSearch) { + _generator_layer->refresh_cache(_total_caches_k, _total_caches_v); + if (steps + prompt_len + 1 < tw_._max_step) { + 
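+        // In beam search the tokens selected at this step become the input of
+        // the next step, so the input/output token buffers are swapped here.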
Variable::swap_tensor(_inp_tokens, _out_tokens); + } + } + steps++; + } + + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + for (int beam_idx = 0; beam_idx < tw_._beam_size; beam_idx++) { + int *tmp_out_ptr = (_generate_method == GenerateMethod::BeamSearch) + ? _out_tokens->value() + : _inp_tokens->value(); + cudaMemcpyAsync( + _llama_out_ptr + + (batch_idx * tw_._beam_size + beam_idx) * (steps + prompt_len), + tmp_out_ptr + (batch_idx * tw_._beam_size + beam_idx) * tw_._max_step, + (steps + prompt_len) * sizeof(int), cudaMemcpyDefault, + _context_ptr->get_stream()); + } + } + + _context_ptr->synchronize(); + set_output_shape(0, {batch_size, tw_._beam_size, prompt_len + steps}); +} + +void Llama::set_input_ptr(int index, void *input_ptr) { + switch (index) { + case 0: + _input_ptr = (int *)input_ptr; + break; + + default: + throw std::runtime_error("invalid input index"); + break; + } +} + +void Llama::set_output_ptr(int index, void *output_ptr) { + switch (index) { + case 0: + _llama_out_ptr = (int *)output_ptr; + break; + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +const void *Llama::get_output_ptr(int index) { + switch (index) { + case 0: + return static_cast(_llama_out_ptr); + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +std::vector Llama::get_input_max_shape(int index) { + switch (index) { + case 0: + return {_max_batch_size, tw_._max_step}; + + default: + throw std::runtime_error("invalid input index"); + break; + } +} +std::vector Llama::get_output_max_shape(int index) { + switch (index) { + case 0: + return {_max_batch_size, tw_._beam_size, tw_._max_step}; + + default: + throw std::runtime_error("invalid output index"); + break; + } +} + +DataType Llama::get_input_dtype(int index) { + switch (index) { + case 0: + return DataType::kInt32; + break; + + default: + throw std::runtime_error("invalid input index"); + break; + } +} + +DataType Llama::get_output_dtype(int index) { + switch (index) { + case 0: + return DataType::kInt32; + break; + + default: + throw std::runtime_error("invalid output index"); + break; + } +} +} // namespace cuda +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/CMakeLists.txt b/lightseq/csrc/ops_new/CMakeLists.txt index c1ec773a..15c3ecae 100644 --- a/lightseq/csrc/ops_new/CMakeLists.txt +++ b/lightseq/csrc/ops_new/CMakeLists.txt @@ -1,18 +1,23 @@ set(operator_files + act_elewise_product.cpp beam_search_topk.cu - sampling.cc.cu bias_act_dropout.cpp bias_add_transform_20314.cpp bias_dropout_residual.cpp concat3_dim1.cpp crf.cpp dropout.cpp + fuse_add2_op.cpp launch_dec_emb_op.cpp launch_enc_emb.cpp launch_gpt_emb.cpp + launch_llama_emb.cpp layer_normalize.cpp split_head_op.cpp linear.cpp + rms_layer_norm.cpp + fuse_rotary_position_qkv.cpp + sampling.cc.cu softmax.cpp strided_batch_gemm.cpp transform_0213.cpp) diff --git a/lightseq/csrc/ops_new/act_elewise_product.cpp b/lightseq/csrc/ops_new/act_elewise_product.cpp new file mode 100644 index 00000000..c3e10735 --- /dev/null +++ b/lightseq/csrc/ops_new/act_elewise_product.cpp @@ -0,0 +1,35 @@ +#include "act_elewise_product.h" + +namespace lightseq { + +template +Variable* ActElewiseProductOp::operator()(Variable* inp) { + size_t max_size = _max_batch_tokens * _inner_size; + _result = new Variable("ActElewiseProductOp_out", max_size, g_dtype(), + g_dtype()); + set_parents({inp}); + this->set_children({_result}); + return _result; +} + +template +void ActElewiseProductOp::forward() { + T1* inp_val = 
(T1*)parent(0)->value(); + T1* out_val = (T1*)child(0)->value(); + + if (!_context_ptr->is_built()) { + return; + } + +#ifdef LIGHTSEQ_cuda + cudaStream_t stream = _context_ptr->get_stream(); + cuda::launch_silu_elewise_product(inp_val, out_val, _batch_size, _seq_len, + _inner_size, stream); +#endif +} + +template class ActElewiseProductOp; +#ifdef LIGHTSEQ_cuda +template class ActElewiseProductOp<__half, __half>; +#endif +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/fuse_add2_op.cpp b/lightseq/csrc/ops_new/fuse_add2_op.cpp new file mode 100644 index 00000000..fede6b62 --- /dev/null +++ b/lightseq/csrc/ops_new/fuse_add2_op.cpp @@ -0,0 +1,35 @@ +#include "fuse_add2_op.h" + +namespace lightseq { + +template +Variable* FuseAdd2Op::operator()(Variable* inpA, Variable* inpB) { + _result = new Variable("FuseAdd2Op_out", _max_batch_tokens * _hidden_dim, + g_dtype(), g_dtype()); + set_parents({inpA, inpB}); + this->set_children({_result}); + return _result; +} + +template +void FuseAdd2Op::forward() { + T1* inpA_ptr = (T1*)parent(0)->value(); + T1* inpB_ptr = (T1*)parent(1)->value(); + T1* out_ptr = (T1*)child(0)->value(); + + if (!_context_ptr->is_built()) { + return; + } + +#ifdef LIGHTSEQ_cuda + cudaStream_t stream = _context_ptr->get_stream(); + cuda::launch_fused_add2(out_ptr, inpA_ptr, inpB_ptr, _batch_size, _seq_len, + _hidden_dim, stream); +#endif +} + +template class FuseAdd2Op; +#ifdef LIGHTSEQ_cuda +template class FuseAdd2Op<__half, __half>; +#endif +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/fuse_rotary_position_qkv.cpp b/lightseq/csrc/ops_new/fuse_rotary_position_qkv.cpp new file mode 100644 index 00000000..1d451a92 --- /dev/null +++ b/lightseq/csrc/ops_new/fuse_rotary_position_qkv.cpp @@ -0,0 +1,41 @@ +#include "fuse_rotary_position_qkv.h" + +namespace lightseq { + +template +Variable* RotaryPositionQk::operator()(Variable* inp, Variable* cache_k, + Variable* cache_v) { + size_t max_size = _max_batch_size * _max_step * _head_num * _head_dim; + _result = new Variable("RotaryPositionQk_out", max_size, g_dtype(), + g_dtype()); + set_parents({inp, cache_k, cache_v}); + this->set_children({_result}); + return _result; +} + +template +void RotaryPositionQk::forward() { + T1* inp_val = (T1*)parent(0)->value(); + T1* cache_k_val = (T1*)parent(1)->value(); + T1* cache_v_val = (T1*)parent(2)->value(); + + T1* out_val = (T1*)child(0)->value(); + + if (!_context_ptr->is_built()) { + return; + } + +#ifdef LIGHTSEQ_cuda + cudaStream_t stream = _context_ptr->get_stream(); + cuda::launch_split_rotary_position_qkv( + inp_val, _device_sin_ptr, _device_cos_ptr, out_val, cache_k_val, + cache_v_val, _max_step, _batch_size, _head_num, _offset_seq_len, + _query_len, _head_dim, stream); +#endif +} + +template class RotaryPositionQk; +#ifdef LIGHTSEQ_cuda +template class RotaryPositionQk<__half, __half>; +#endif +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/includes/act_elewise_product.h b/lightseq/csrc/ops_new/includes/act_elewise_product.h new file mode 100644 index 00000000..f26a4f36 --- /dev/null +++ b/lightseq/csrc/ops_new/includes/act_elewise_product.h @@ -0,0 +1,42 @@ +#pragma once +#include "declaration.h" +#include "node.h" + +namespace lightseq { + +template +class ActElewiseProductOp : public Operator { + private: + size_t _inner_size; + size_t _max_batch_tokens; + size_t _batch_tokens; + size_t _batch_size; + size_t _seq_len; + + Variable* _result; + + public: + ActElewiseProductOp(size_t max_batch_tokens, size_t inner_size) + : 
Operator("ActElewiseProductOp"), + _max_batch_tokens(max_batch_tokens), + _inner_size(inner_size) {} + + virtual ~ActElewiseProductOp() {} + + Variable* operator()(Variable* inp); + + void forward() override; + + void before_forward(size_t batch_size, size_t seq_len) { + _batch_size = batch_size; + _seq_len = seq_len; + _batch_tokens = batch_size * seq_len; + _result->set_shape({_batch_tokens, _inner_size}); + } + + void backward() override {} + + void before_backward() {} +}; + +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/includes/fuse_add2_op.h b/lightseq/csrc/ops_new/includes/fuse_add2_op.h new file mode 100644 index 00000000..8cb458cc --- /dev/null +++ b/lightseq/csrc/ops_new/includes/fuse_add2_op.h @@ -0,0 +1,39 @@ +#pragma once +#include "declaration.h" +#include "node.h" + +namespace lightseq { + +template +class FuseAdd2Op : public Operator { + private: + size_t _max_batch_tokens; + size_t _batch_tokens; + size_t _batch_size; + size_t _seq_len; + size_t _hidden_dim; + + Variable* _result; + + public: + FuseAdd2Op(size_t max_batch_tokens, size_t hidden_dim) + : Operator("FuseAdd2"), + _max_batch_tokens(max_batch_tokens), + _hidden_dim(hidden_dim) {} + + ~FuseAdd2Op() {} + + Variable* operator()(Variable* inpA, Variable* inpB); + + void forward() override; + + void before_forward(size_t batch_size, size_t seq_len) { + _batch_size = batch_size; + _seq_len = seq_len; + _result->set_shape({batch_size, seq_len, _hidden_dim}); + } + + void backward() override {} +}; + +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/includes/fuse_rotary_position_qkv.h b/lightseq/csrc/ops_new/includes/fuse_rotary_position_qkv.h new file mode 100644 index 00000000..f1442b64 --- /dev/null +++ b/lightseq/csrc/ops_new/includes/fuse_rotary_position_qkv.h @@ -0,0 +1,97 @@ +#pragma once +#include "declaration.h" +#include "node.h" +#include "cmath" + +namespace lightseq { + +template +class RotaryPositionQk : public Operator { + private: + T1* _sin_ptr; + T1* _cos_ptr; + size_t _max_step; + size_t _max_batch_size; + size_t _batch_size; + size_t _head_num; + size_t _head_dim; + size_t _offset_seq_len; + size_t _query_len; + + T1* _device_sin_ptr; + T1* _device_cos_ptr; + + Variable* _result; + + public: + RotaryPositionQk(int max_batch_size, int max_step, int head_num, int head_dim) + : Operator("RotaryPositionQk"), + _max_batch_size(max_batch_size), + _max_step(max_step), + _head_num(head_num), + _head_dim(head_dim) { + if (head_dim & 1) { + printf( + "Error! head dim should be even number while using RotaryPositionQk " + "Operator.\n"); + exit(0); + } + + int total_size = max_step * head_dim / 2; + _sin_ptr = (T1*)malloc(total_size * sizeof(T1)); + _cos_ptr = (T1*)malloc(total_size * sizeof(T1)); + + for (int i = 0; i < head_dim / 2; i++) { + float theta = std::pow(10000, -2. 
* i / head_dim); + for (int j = 0; j < max_step; j++) { + T1 sin_val, cos_val; + if (std::is_same::value) { + sin_val = sin(j * theta), cos_val = cos(j * theta); + } else { + sin_val = __float2half(sin(j * theta)), + cos_val = __float2half(cos(j * theta)); + } + *(_sin_ptr + j * head_dim / 2 + i) = + sin_val; // shape: [max_step, head_dim / 2] + *(_cos_ptr + j * head_dim / 2 + i) = + cos_val; // shape: [max_step, head_dim / 2] + } + } + +#ifdef LIGHTSEQ_cuda + _device_sin_ptr = + (T1*)_context_ptr->allocator()->malloc_mem(total_size * sizeof(T1)); + _device_cos_ptr = + (T1*)_context_ptr->allocator()->malloc_mem(total_size * sizeof(T1)); + CHECK_GPU_ERROR(cudaMemcpy(_device_sin_ptr, _sin_ptr, + total_size * sizeof(T1), cudaMemcpyDefault)); + CHECK_GPU_ERROR(cudaMemcpy(_device_cos_ptr, _cos_ptr, + total_size * sizeof(T1), cudaMemcpyDefault)); + free(_sin_ptr); + _sin_ptr = nullptr; + free(_cos_ptr); + _cos_ptr = nullptr; +#else + _device_sin_ptr = _sin_ptr; + _device_cos_ptr = _cos_ptr; +#endif + } + + virtual ~RotaryPositionQk() {} + + void before_forward(int batch_size, int offset_seq_len, int query_len) { + _batch_size = batch_size; + _offset_seq_len = offset_seq_len; + _query_len = query_len; + _result->set_shape({_batch_size, _head_num, _query_len, _head_dim}); + } + + Variable* operator()(Variable* inp_tensor, Variable* cache_k, + Variable* cache_v); + + void forward() override; + + void backward() override {} +}; + +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/includes/launch_llama_emb.h b/lightseq/csrc/ops_new/includes/launch_llama_emb.h new file mode 100644 index 00000000..ddbd9fe5 --- /dev/null +++ b/lightseq/csrc/ops_new/includes/launch_llama_emb.h @@ -0,0 +1,57 @@ +#pragma once +#include "declaration.h" +#include "node.h" +#include "tuple" + +namespace lightseq { + +// dropout inside ffn. +template +class LaunchLlamaEmbOp : public Operator { + private: + size_t _max_batch_tokens; + int _pad_id; + size_t _hidden_dim; + + size_t _batch_size; + size_t _seq_len; + int _max_step; + int _beam_size; + int _offset; + int _max_batch_size; + + Variable* _result; + Variable* _pad_mask; + Variable* _left_pad_len; + + public: + LaunchLlamaEmbOp(size_t max_batch_tokens, int max_step, int max_batch_size, + int beam_size, int pad_id, size_t hidden_dim) + : Operator("LaunchLlamaEmbOp"), + _max_batch_tokens(max_batch_tokens), + _max_batch_size(max_batch_size), + _pad_id(pad_id), + _max_step(max_step), + _beam_size(beam_size), + _hidden_dim(hidden_dim) {} + + virtual ~LaunchLlamaEmbOp() {} + + std::tuple operator()(Variable* inp_tokens, + Variable* token_emb); + + void before_forward(size_t batch_size, size_t seq_len, int offset) { + _batch_size = batch_size, _seq_len = seq_len, _offset = offset; + _result->set_shape({batch_size * seq_len, _hidden_dim}); + _pad_mask->set_shape({batch_size, seq_len + offset}); + _left_pad_len->set_shape({_batch_size, size_t(_beam_size)}); + } + + void forward() override; + + void backward() override { + printf("ERROR! 
LaunchLlamaEmbOp can't cal backward()\n"); + exit(-1); + } +}; +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/includes/linear.h b/lightseq/csrc/ops_new/includes/linear.h index 87799d60..8a9c5986 100644 --- a/lightseq/csrc/ops_new/includes/linear.h +++ b/lightseq/csrc/ops_new/includes/linear.h @@ -14,8 +14,10 @@ class LinearOp : public Operator { std::array _gemm_algos; float _alpha; + float _beta; MATRIX_OP _opA; MATRIX_OP _opB; + bool _use_residual = false; Variable* _result; @@ -28,7 +30,7 @@ class LinearOp : public Operator { public: LinearOp(size_t max_batch_tokens, size_t output_size, size_t input_size, MATRIX_OP opA = weight_op, MATRIX_OP opB = MATRIX_OP::NonTranspose, - float alpha = float(1.)) + float alpha = float(1.), float beta = float(0.)) : Operator("LinearOp"), _max_batch_tokens(max_batch_tokens), _output_size(output_size), @@ -36,17 +38,23 @@ class LinearOp : public Operator { _opA(opA), _opB(opB), _gemm_algos(std::array({99, 99, 99})), - _alpha(alpha) {} + _alpha(alpha), + _beta(beta) {} ~LinearOp() {} Variable* operator()(Variable* inp, Variable* weight); + Variable* operator()(Variable* inp, Variable* weight, Variable* residual); void forward() override; void before_forward(size_t batch_tokens) { _batch_tokens = batch_tokens; - _result->set_shape({batch_tokens, _output_size}); + if (_use_residual) { + _result->set_offset(0, {batch_tokens, _output_size}); + } else { + _result->set_shape({batch_tokens, _output_size}); + } } void backward() override; diff --git a/lightseq/csrc/ops_new/includes/rms_layer_norm.h b/lightseq/csrc/ops_new/includes/rms_layer_norm.h new file mode 100644 index 00000000..f0784779 --- /dev/null +++ b/lightseq/csrc/ops_new/includes/rms_layer_norm.h @@ -0,0 +1,44 @@ +#pragma once +#include "declaration.h" +#include "node.h" + +namespace lightseq { + +template +class RMSLayerNormalizeOp : public Operator { + private: + size_t _max_batch_tokens; + size_t _hidden_dim; + size_t _batch_tokens; + float _epsilon; + + bool _use_mean; + bool _use_residual; + + TensorPtr _rms_vars; + Variable* _result; + Variable* _residual; + + public: + RMSLayerNormalizeOp(size_t max_batch_tokens, size_t hidden_dim, + bool use_residual = true, float epsilon = 1e-6) + : Operator("RMSLayerNormalizeOp"), + _max_batch_tokens(max_batch_tokens), + _hidden_dim(hidden_dim), + _use_residual(use_residual), + _epsilon(epsilon) { + _rms_vars.reset(new Tensor("rms_vars", g_dtype(), max_batch_tokens)); + } + + std::tuple operator()(Variable* inp, Variable* scale); + + virtual ~RMSLayerNormalizeOp(); + + void before_forward(size_t batch_size, size_t seq_len); + + void forward() override; + + void backward() override {} +}; + +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/includes/sampling.h b/lightseq/csrc/ops_new/includes/sampling.h index 30b0da95..72cd49a2 100644 --- a/lightseq/csrc/ops_new/includes/sampling.h +++ b/lightseq/csrc/ops_new/includes/sampling.h @@ -38,6 +38,8 @@ class SamplingOp : public Operator { int max_thread_per_block, int trg_vocab_size, int topk, float topp, int eos_id); + virtual ~SamplingOp() {} + // output: new_token_ids std::tuple operator()(Variable* logits, Variable* logit_bias, diff --git a/lightseq/csrc/ops_new/launch_llama_emb.cpp b/lightseq/csrc/ops_new/launch_llama_emb.cpp new file mode 100644 index 00000000..49c4979f --- /dev/null +++ b/lightseq/csrc/ops_new/launch_llama_emb.cpp @@ -0,0 +1,56 @@ +#include "launch_llama_emb.h" + +namespace lightseq { + +template +std::tuple LaunchLlamaEmbOp::operator()( + Variable* inp_tokens, 
Variable* token_emb) { + set_parents({inp_tokens, token_emb}); + + size_t max_size = _max_batch_tokens * _hidden_dim; + + _result = + new Variable("LaunchLlamaEmbOp_out", + _max_batch_tokens * _hidden_dim * _beam_size, g_dtype()); + _pad_mask = + new Variable("_pad_mask", _max_batch_tokens * _beam_size, g_dtype()); + + _left_pad_len = new Variable("_left_pad_len", _max_batch_size * _beam_size, + g_dtype(), cuda::DataType::kNotSupported, + VariableType::RegressiveVariable); + + this->set_children({_result, _pad_mask, _left_pad_len}); + return std::make_tuple(_result, _pad_mask, _left_pad_len); +} + +template +void LaunchLlamaEmbOp::forward() { + int* inp_tokens = (int*)parent(0)->value(); + const T* token_emb = (const T*)parent(1)->value(); + + T* output_ptr = (T*)child(0)->value(); + T* pad_mask_ptr = (T*)child(1)->value(); + int* left_pad_len_ptr = (int*)child(2)->value(); + + if (!_context_ptr->is_built()) { + return; + } + +#ifdef LIGHTSEQ_cuda + cudaStream_t _stream = _context_ptr->get_stream(); + if (_offset == 0) { + cudaMemsetAsync(left_pad_len_ptr, 0, _batch_size * _beam_size * sizeof(int), + _stream); + } + cuda::launch_llama_embedding(token_emb, inp_tokens, output_ptr, + pad_mask_ptr, left_pad_len_ptr, _batch_size, + _beam_size, _hidden_dim, _offset, _seq_len, + _max_step, _pad_id, _stream); +#endif +} + +template class LaunchLlamaEmbOp; +#ifdef LIGHTSEQ_cuda +template class LaunchLlamaEmbOp<__half>; +#endif +} // namespace lightseq diff --git a/lightseq/csrc/ops_new/linear.cpp b/lightseq/csrc/ops_new/linear.cpp index 58bc1fdd..867c826a 100644 --- a/lightseq/csrc/ops_new/linear.cpp +++ b/lightseq/csrc/ops_new/linear.cpp @@ -4,7 +4,6 @@ namespace lightseq { template Variable* LinearOp::operator()(Variable* inp, Variable* weight) { - // size_t max_size = _max_batch_tokens * _output_size; _result = new Variable("LinearOp_out", _max_batch_tokens * _output_size, g_dtype(), g_dtype()); set_parents({inp, weight}); @@ -13,9 +12,18 @@ Variable* LinearOp::operator()(Variable* inp, Variable* weight) { } template -void LinearOp::forward() { - float beta = float(0.); +Variable* LinearOp::operator()(Variable* inp, Variable* weight, + Variable* residual) { + _use_residual = true; + _beta = float(1.); + _result = new Variable("LinearOp_out", residual); + set_parents({inp, weight, residual}); + this->set_children({_result}); + return _result; +} +template +void LinearOp::forward() { T1* input_ptr = (T1*)parent(0)->value(); T1* weights = (T1*)parent(1)->value(); T1* out_ptr = (T1*)child(0)->value(); @@ -23,15 +31,14 @@ void LinearOp::forward() { if (!_context_ptr->is_built()) { return; } - + // _beta = float(0.); #ifdef LIGHTSEQ_cuda cublasHandle_t _cublasHandle = _context_ptr->get_cublashandle(); cuda::cublas_gemm_ex(_cublasHandle, op_from_custom(_opA), op_from_custom(_opB), _output_size, _batch_tokens, - _input_size, &_alpha, &beta, weights, input_ptr, out_ptr, - cublasGemmAlgo_t(_gemm_algos[0])); + _input_size, &_alpha, &_beta, weights, input_ptr, + out_ptr, cublasGemmAlgo_t(_gemm_algos[0])); #elif defined LIGHTSEQ_x86 - x86::matrix_gemm(weights, input_ptr, out_ptr, _output_size, _batch_tokens, _input_size); #endif diff --git a/lightseq/csrc/ops_new/rms_layer_norm.cpp b/lightseq/csrc/ops_new/rms_layer_norm.cpp new file mode 100644 index 00000000..5302b66e --- /dev/null +++ b/lightseq/csrc/ops_new/rms_layer_norm.cpp @@ -0,0 +1,59 @@ +#include "rms_layer_norm.h" + +namespace lightseq { + +template +RMSLayerNormalizeOp::~RMSLayerNormalizeOp() {} + +template +std::tuple 
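+// RMSLayerNormalizeOp implements RMSNorm over each token vector:
+//   out = scale * x * rsqrt(mean(x^2) + epsilon)
+// (no mean subtraction, no bias), matching the baseline in
+// tests/cuda/test_kernel.py. When _use_residual is set, the second output is
+// meant to carry the residual branch that a later LinearOp folds into its
+// GEMM with beta = 1.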
RMSLayerNormalizeOp::operator()( + Variable* inp, Variable* scale) { + size_t max_size = _max_batch_tokens * _hidden_dim; + _result = + new Variable("RMSLayerNormalizeOp_out", _max_batch_tokens * _hidden_dim, + g_dtype(), g_dtype()); + _residual = + new Variable("RMSLayerNormalizeOp_res", _max_batch_tokens * _hidden_dim, + g_dtype(), g_dtype()); + set_parents({inp, scale}); + this->set_children({_result, _residual}); + return std::make_tuple(_result, _residual); +} + +template +void RMSLayerNormalizeOp::before_forward(size_t batch_size, + size_t seq_len) { + _batch_tokens = batch_size * seq_len; + _result->set_shape({batch_size, seq_len, _hidden_dim}); + if (_use_residual) { + _residual->set_shape({batch_size, seq_len, _hidden_dim}); + } +} + +template +void RMSLayerNormalizeOp::forward() { + T1* inp_val = (T1*)parent(0)->value(); + T1* scale_val = (T1*)parent(1)->value(); + T1* out_val = (T1*)child(0)->value(); + T1* res_val = nullptr; + if (_use_residual) { + res_val = (T1*)child(1)->value(); + } + T1* rms_vars_val = (T1*)_rms_vars->tensor(); + + if (!_context_ptr->is_built()) { + return; + } + +#ifdef LIGHTSEQ_cuda + cudaStream_t stream = _context_ptr->get_stream(); + cuda::launch_rms_layer_norm(inp_val, scale_val, out_val, res_val, + rms_vars_val, _batch_tokens, _hidden_dim, stream); +#endif +} + +template class RMSLayerNormalizeOp; +#ifdef LIGHTSEQ_cuda +template class RMSLayerNormalizeOp<__half, __half>; +#endif +} // namespace lightseq diff --git a/lightseq/csrc/proto/CMakeLists.txt b/lightseq/csrc/proto/CMakeLists.txt index 086e35f6..ba27ed22 100644 --- a/lightseq/csrc/proto/CMakeLists.txt +++ b/lightseq/csrc/proto/CMakeLists.txt @@ -11,7 +11,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) set(PROTO_FILES bert.proto bert_crf.proto transformer.proto gpt.proto) set(WEIGHT_FILES bert_weight.cc bert_crf_weight.cc transformer_weight.cc - gpt_weight.cc) + gpt_weight.cc llama_weight.cc) protobuf_generate_cpp(PROTO_SRC PROTO_HEADER ${PROTO_FILES}) add_library(weight_lib STATIC ${WEIGHT_FILES} ${PROTO_SRC} ${PROTO_HEADER} diff --git a/lightseq/csrc/proto/includes/hdf5_util.h b/lightseq/csrc/proto/includes/hdf5_util.h new file mode 100644 index 00000000..b5c2cebe --- /dev/null +++ b/lightseq/csrc/proto/includes/hdf5_util.h @@ -0,0 +1,26 @@ +#pragma once +#include "proto_headers.h" +#include "proto_util.h" +#include "util.h" + +template +void convert_dtype_by_gpu(float* source_addr, float* source_buffer, + T* target_buffer, T* target_addr, size_t size, + cudaStream_t stream) { + if (std::is_same::value) { + cudaMemcpyAsync(source_buffer, source_addr, size * sizeof(float), + cudaMemcpyDefault, stream); + lightseq::cuda::launch_convert_dtype(source_buffer, (__half*)target_addr, + size, 1024, stream); + } else if (std::is_same::value) { + cudaMemcpyAsync(target_addr, source_addr, size * sizeof(float), + cudaMemcpyDefault, stream); + } +} + +template +T* malloc_memory(size_t size) { + T* buffer_addr = nullptr; + cudaMalloc(&buffer_addr, size * sizeof(T)); + return buffer_addr; +} diff --git a/lightseq/csrc/proto/includes/llama_weight.h b/lightseq/csrc/proto/includes/llama_weight.h new file mode 100644 index 00000000..59b9d546 --- /dev/null +++ b/lightseq/csrc/proto/includes/llama_weight.h @@ -0,0 +1,93 @@ +#pragma once +#include "proto_headers.h" +#include "proto_util.h" +#include "hdf5_util.h" + +namespace lightseq { + +/* +Load the model weights which stored in custom proto file into GPU memory. 
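+
+For this Llama loader, get_src_emb_wei() returns {token_embedding,
+post_norm_scale, logits_linear_weight} and get_enc_wei() returns six weights
+per decoder layer: {attention_norm_scale, attention_project_qkv,
+attention_output, ffn_norm_scale, gate_up_project_weight,
+down_project_weight}; see llama_weight.cc for the parsing order.
+
+Typical usage, a sketch mirroring models/llama.cc (error handling elided,
+"/path/to/llama.hdf5" is a placeholder, OpType_ stands for the configured
+compute type):
+
+  LlamaWeight<OpType_> tw;
+  std::string res = tw.initializing("/path/to/llama.hdf5");
+  if (!res.empty()) throw std::runtime_error(res);
+  const auto &emb_wei = tw.get_src_emb_wei();  // 3 device pointers
+  const auto &enc_wei = tw.get_enc_wei();      // 6 * layer_num device pointers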
+*/ +template +class LlamaWeight { + private: + cudaStream_t stream; + T float2required(float value); + + // parsing function for hdf5 + void hdf5_get_model_config(hid_t hdf5_file); + void hdf5_parse_emb_wei(hid_t hdf5_file); + void hdf5_parse_enc_wei(hid_t hdf5_file); + + // store the weights pointer + std::vector _p_d_src_emb_wei; // size: 4 + std::vector _p_d_enc_wei; // size: 12 * enc_layer_num + + // store the weights on gpu memory + std::vector _d_src_emb_wei; + std::vector _d_enc_wei; + + public: + std::string initializing(std::string weight_path); + + const std::vector &get_src_emb_wei() const { + // {token_emb, pos_emb, norm_scale, norm_bias} + return _p_d_src_emb_wei; + } + + const std::vector &get_enc_wei() const { + // {multihead_norm_scale, multihead_norm_bias, multihead_qkv_kernel, + // multihead_qkv_bias multihead_output_kernel, multihead_output_bias + // ffn_norm_scale, ffn_norm_bias} + // ffn_first_kernel, ffn_first_bias, ffn_second_kernel, ffn_second_bias} * + // encoder_layer_num + return _p_d_enc_wei; + } + + size_t _hidden_size; + int _inner_size; + int _max_step; + int _extra_decode_length; + int _src_vocab_size; + int _layer_num; // number of encoder layer + int _dim_per_head; + int _weight_per_enc_layer; // 12 + + int _head_num; + int _padding_id; // for src + std::string _generate_method = "topk"; + int _topk = 1; + float _topp = 0.75; + int _eos_id; + + int _beam_size = 1; + float _length_penalty = 1.0; + float _diverse_lambda = 0.; + bool _use_gelu = true; + + void print_model_config() { + std::cout << "***model config***" << std::endl; + std::cout << "decoder layers: " << _layer_num << std::endl; + std::cout << "hidden size: " << _hidden_size << std::endl; + std::cout << "inner size: " << _inner_size << std::endl; + std::cout << "head number: " << _head_num << std::endl; + std::cout << "dim per head: " << _dim_per_head << std::endl; + std::cout << "src vocab size: " << _src_vocab_size << std::endl; + std::cout << "use_gelu: " << _use_gelu << std::endl; + std::cout << "end_id: " << _eos_id << std::endl; + std::cout << "padding_id: " << _padding_id << std::endl; + std::cout << std::endl; + std::cout << "***generator config***" << std::endl; + std::cout << "beam size: " << _beam_size << std::endl; + std::cout << "max step: " << _max_step << std::endl; + std::cout << "extra decode length(max decode length - src input length): " + << _extra_decode_length << std::endl; + std::cout << "length penalty: " << _length_penalty << std::endl; + std::cout << "diverse lambda: " << _diverse_lambda << std::endl; + std::cout << "generate method: " << _generate_method << std::endl; + std::cout << "topk: " << _topk << std::endl; + std::cout << "topp: " << _topp << std::endl; + } +}; + +} // namespace lightseq diff --git a/lightseq/csrc/proto/llama.proto b/lightseq/csrc/proto/llama.proto new file mode 100644 index 00000000..fb9aaa37 --- /dev/null +++ b/lightseq/csrc/proto/llama.proto @@ -0,0 +1,60 @@ +syntax = "proto3"; +option optimize_for = LITE_RUNTIME; +// all the matrix are stored in row-major order, +// plz see https://en.wikipedia.org/wiki/Row-_and_column-major_order for details + +// the definition of "Multi-Head Attention", "Scaled Dot-Product Attention" and +// "Feed-Forward Networks" +// plz see https://arxiv.org/abs/1706.03762 for details + +message LlamaDecoderLayer { + // layer norm before "Llama Attention" + repeated float attention_norm_scale = 1; + repeated float attention_project_qkv = 2; + + // "Multi-Head Attention" linearly project weights kernel for output + // 
after "Scaled Dot-Product Attention", with shape (hidden_size, hidden_size) + repeated float attention_output = 5; + + // layer norm before "Llama Feed-Forward Networks" + repeated float ffn_norm_scale = 6; + + // "Llama MLP layer Networks" + repeated float gate_up_project_weight = 7; + repeated float down_project_weight = 9; +} + +message LlamaEmbeddingLayer { + // token embedding table + // for encoder, it is in [src_vocab_size, hidden_size] + // so, look it up directly will get the input token embedding + repeated float token_embedding = 1; + // the last layer_norm + repeated float post_norm_scale = 3; +} + +message LlamaModelConf { + int32 hidden_size = 1; + int32 inner_size = 2; + int32 max_step = 3; + int32 head_num = 4; + int32 layer_num = 5; + int32 src_padding_id = 6; + string generate_method = 7; + float topp = 8; + int32 topk = 9; + int32 eos_id = 10; + int32 extra_decode_length = 11; + int32 src_vocab_size = 12; + + int32 beam_size = 13; // beam size of beam search + float length_penalty = 14; // length penalty of beam search + float diverse_lambda = 15; // diverse beam search lambda + string act_method = 16; // act method of Llama MLP layer +} + +message Llama { + LlamaEmbeddingLayer src_embedding = 1; + repeated LlamaDecoderLayer decoder_layers = 2; + LlamaModelConf model_conf = 3; +} diff --git a/lightseq/csrc/proto/llama_weight.cc b/lightseq/csrc/proto/llama_weight.cc new file mode 100644 index 00000000..a141952e --- /dev/null +++ b/lightseq/csrc/proto/llama_weight.cc @@ -0,0 +1,329 @@ +#include "llama_weight.h" + +#include + +/** +@file +Load the model weights which stored in custom proto file into GPU memory. +Currently, fp16 and fp32 versions are provided. +Weights in proto file will always be in fp32. For fp16, the weights + will be casted from fp32 into fp16 +*/ + +namespace lightseq { + +/** +Cast weights into required datatype. +The datatype of weights in custom proto file will always be in fp32. +*/ +template <> +float LlamaWeight::float2required(float value) { + return value; +} + +#ifdef LIGHTSEQ_cuda +/** +fp16 version, cast fp32 into fp16 +*/ +template <> +__half LlamaWeight<__half>::float2required(float value) { + return __float2half_rn(value); +} +#endif + +/** +Read model config stored in custom hdf5 file. 
+*/ +template +void LlamaWeight::hdf5_get_model_config(hid_t hdf5_file) { + read_hdf5_dataset_scalar(hdf5_file, "model_conf/hidden_size", H5T_NATIVE_INT, + &_hidden_size); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/inner_size", H5T_NATIVE_INT, + &_inner_size); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/max_step", H5T_NATIVE_INT, + &_max_step); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/head_num", H5T_NATIVE_INT, + &_head_num); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/layer_num", H5T_NATIVE_INT, + &_layer_num); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/src_padding_id", + H5T_NATIVE_INT, &_padding_id); + + // special handling for string reading + // string were converted to numpy array of np.int8 in python + // hence needed to be read as an char array here + char _generate_method_buf[128]; // get 128 character for sampling method + int _generate_method_strlen = read_hdf5_dataset_data( + hdf5_file, "model_conf/generate_method", H5T_NATIVE_CHAR, + _generate_method_buf, [](int size) { return size > 128; }, + "Expect model_conf/generate_method to have less than 128 characters."); + std::string _generate_method_read = + std::string(_generate_method_buf, _generate_method_strlen); + if (_generate_method_read != "") { + _generate_method = _generate_method_read; + } + + int _topk_read; + read_hdf5_dataset_scalar(hdf5_file, "model_conf/topk", H5T_NATIVE_INT, + &_topk_read); + if (_topk_read != 0) { + _topk = _topk_read; + } + // _topk = 1; + + float _topp_read; + read_hdf5_dataset_scalar(hdf5_file, "model_conf/topp", H5T_NATIVE_FLOAT, + &_topp_read); + if (_topp_read != 0.0) { + _topp = _topp_read; + } + + int _eos_id_read; + read_hdf5_dataset_scalar(hdf5_file, "model_conf/eos_id", H5T_NATIVE_INT, + &_eos_id_read); + if (_eos_id_read != 0) { + _eos_id = _eos_id_read; + } + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/extra_decode_length", + H5T_NATIVE_INT, &_extra_decode_length); + + read_hdf5_dataset_scalar(hdf5_file, "model_conf/src_vocab_size", + H5T_NATIVE_INT, &_src_vocab_size); + + try { + read_hdf5_dataset_scalar(hdf5_file, "model_conf/beam_size", H5T_NATIVE_INT, + &_beam_size); + } catch (HDF5DatasetNotFoundError& e) { + _beam_size = 1; + } + + try { + read_hdf5_dataset_scalar(hdf5_file, "model_conf/length_penalty", + H5T_NATIVE_FLOAT, &_length_penalty); + } catch (HDF5DatasetNotFoundError& e) { + _length_penalty = 1.0; + } + + try { + read_hdf5_dataset_scalar(hdf5_file, "model_conf/diverse_lambda", + H5T_NATIVE_FLOAT, &_diverse_lambda); + } catch (HDF5DatasetNotFoundError& e) { + _diverse_lambda = 0.; + } + + _dim_per_head = _hidden_size / _head_num; +} + +/** +Load the weights of embedding layer into GPU memory. +*/ +template +void LlamaWeight::hdf5_parse_emb_wei(hid_t hdf5_file) { + std::string dataset_prefix = "src_embedding"; + size_t value_size = _src_vocab_size * _hidden_size + _hidden_size; + + size_t max_value_size = _src_vocab_size * _hidden_size; + + std::vector offset; + std::vector value(max_value_size); + std::cout << "loading " << value_size / (1024 * 1024) + << " M of decoder weight." 
<< std::endl; + + const size_t max_buffer_size = max_value_size; + float* source_buffer; + T* target_buffer; + cudaMalloc(&source_buffer, max_buffer_size * sizeof(float)); + cudaMalloc(&target_buffer, max_buffer_size * sizeof(T)); + T* addr = nullptr; + + size_t buffer_size; + + buffer_size = _src_vocab_size * _hidden_size; + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/token_embedding", H5T_NATIVE_FLOAT, + value.data(), [=](int size) { return size != buffer_size; }, + "Wrong token_embedding_size !"); + addr = malloc_memory(buffer_size); + _p_d_src_emb_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/post_norm_scale", H5T_NATIVE_FLOAT, + value.data(), [=](int size) { return size != _hidden_size; }, + "Wrong norm_scale_size !"); + buffer_size = _hidden_size; + addr = malloc_memory(buffer_size); + _p_d_src_emb_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/logits_linear_weight", H5T_NATIVE_FLOAT, + value.data(), + [=](int size) { return size != _src_vocab_size * _hidden_size; }, + "Wrong norm_scale_size !"); + buffer_size = _src_vocab_size * _hidden_size; + addr = malloc_memory(buffer_size); + _p_d_src_emb_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + std::cout << "finish initializing emb_wei from host to device" << std::endl; + + value.clear(); + value.shrink_to_fit(); + cudaFree(source_buffer); + cudaFree(target_buffer); +} + +/** +Load the weights of encoder into GPU memory. +*/ +template +void LlamaWeight::hdf5_parse_enc_wei(hid_t hdf5_file) { + size_t value_size = + (_hidden_size + _hidden_size * _hidden_size * 3 + + _hidden_size * _hidden_size + _hidden_size + + _hidden_size * _inner_size * 2 + _hidden_size * _inner_size) * + _layer_num; + + std::vector value_size_vec = {_hidden_size, + _hidden_size * _hidden_size * 3, + _hidden_size * _hidden_size, + _hidden_size, + _hidden_size * _inner_size * 2, + _hidden_size * _inner_size}; + size_t max_value_size = + *max_element(value_size_vec.begin(), value_size_vec.end()); + + std::vector offset; + std::vector value(max_value_size); + std::cout << "loading " << value_size / (1024 * 1024) + << " M of decoder weight." 
<< std::endl; + + const size_t max_buffer_size = max_value_size; + float* source_buffer; + T* target_buffer; + cudaMalloc(&source_buffer, max_buffer_size * sizeof(float)); + cudaMalloc(&target_buffer, max_buffer_size * sizeof(T)); + + T* addr = nullptr; + size_t buffer_size; + for (int layer_id = 0; layer_id < _layer_num; ++layer_id) { + std::string dataset_prefix = "decoder_layers/" + std::to_string(layer_id); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/attention_norm_scale", H5T_NATIVE_FLOAT, + value.data(), [=](int size) { return size != _hidden_size; }, + "Wrong attention_norm_scale_size !"); + buffer_size = _hidden_size; + addr = malloc_memory(buffer_size); + _p_d_enc_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/attention_project_qkv", H5T_NATIVE_FLOAT, + value.data(), + [=](int size) { return size != _hidden_size * _hidden_size * 3; }, + "Wrong attention_project_q_size !"); + buffer_size = _hidden_size * _hidden_size * 3; + addr = malloc_memory(buffer_size); + _p_d_enc_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/attention_output", H5T_NATIVE_FLOAT, + value.data(), + [=](int size) { return size != _hidden_size * _hidden_size; }, + "Wrong attention_output_size !"); + buffer_size = _hidden_size * _hidden_size; + addr = malloc_memory(buffer_size); + _p_d_enc_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/ffn_norm_scale", H5T_NATIVE_FLOAT, + value.data(), [=](int size) { return size != _hidden_size; }, + "Wrong ffn_norm_scale_size !"); + buffer_size = _hidden_size; + addr = malloc_memory(buffer_size); + _p_d_enc_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/gate_up_project_weight", H5T_NATIVE_FLOAT, + value.data(), + [=](int size) { return size != _hidden_size * _inner_size * 2; }, + "Wrong gate_up_project_weight_size !"); + buffer_size = _hidden_size * _inner_size * 2; + addr = malloc_memory(buffer_size); + _p_d_enc_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + + read_hdf5_dataset_data( + hdf5_file, dataset_prefix + "/down_project_weight", H5T_NATIVE_FLOAT, + value.data(), + [=](int size) { return size != _hidden_size * _inner_size; }, + "Wrong down_project_weight_size !"); + buffer_size = _hidden_size * _inner_size; + addr = malloc_memory(buffer_size); + _p_d_enc_wei.push_back(addr); + convert_dtype_by_gpu(value.data(), source_buffer, target_buffer, addr, + buffer_size, stream); + } + + std::cout << "finish initializing dec_wei from host to device" << std::endl; + + value.clear(); + value.shrink_to_fit(); + cudaFree(source_buffer); + cudaFree(target_buffer); +} + +/** +Load the proto file into CPU memory and parse it. +*/ +template +std::string LlamaWeight::initializing(std::string weight_path) { + cudaStreamCreate(&stream); + // If weight is of type pb, parse using proto parser. 
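+  // Note: only the .hdf5 branch is implemented below; any other extension
+  // (including .pb) falls through to the unsupported-extension error.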
+ if (endswith(weight_path, ".hdf5")) { + std::cout << "Parsing hdf5: " << weight_path << std::endl; + + hid_t hdf5_file = H5Fopen(weight_path.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + if (hdf5_file < 0) { + return "Unable to read HDF5 file from " + weight_path; + } + hdf5_get_model_config(hdf5_file); + + // hdf5_parse_* would throw std::runtime_error on error + hdf5_parse_emb_wei(hdf5_file); + hdf5_parse_enc_wei(hdf5_file); + H5Fclose(hdf5_file); + + cudaStreamSynchronize(stream); + std::cout << "Finish loading all weight from host to device" << std::endl; + return ""; + } else { + return "Unsupported weight extention for [" + weight_path + + "]; Supported extensions: .pb, .hdf5\n"; + } +} +#ifdef LIGHTSEQ_cuda +template class LlamaWeight<__half>; +#endif +template class LlamaWeight; + +} // namespace lightseq diff --git a/lightseq/csrc/pybind/pybind_kernel_cuda.cpp b/lightseq/csrc/pybind/pybind_kernel_cuda.cpp index 149a3345..332c4a3d 100644 --- a/lightseq/csrc/pybind/pybind_kernel_cuda.cpp +++ b/lightseq/csrc/pybind/pybind_kernel_cuda.cpp @@ -3,9 +3,15 @@ #include #include - #include "cuda_util.h" #include "kernels.h" +#include "llama_kernels.h" +#include "cmath" +#include "memory" +#include +#include +#include +#include typedef const torch::Tensor cts; typedef torch::Tensor ts; @@ -113,22 +119,23 @@ void torch_launch_attn_softmax(torch::Tensor &vals, CHECK_GPU_ERROR(cudaGetLastError()); } -template -void torch_launch_attn_softmax_new(torch::Tensor &out, torch::Tensor &inp, - const torch::Tensor &attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool is_dec_self_attn, - bool mask_future) { - const T *attn_mask_ptr = rptr(attn_mask); - if (is_dec_self_attn) { - attn_mask_ptr = nullptr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - launch_attn_softmax_new(rptr(out), rptr(inp), attn_mask_ptr, batch_size, - nhead, from_len, to_len, mask_future, stream); - // cudaStreamSynchronize(stream); - CHECK_GPU_ERROR(cudaGetLastError()); -} +// template +// void torch_launch_attn_softmax_new(torch::Tensor &out, torch::Tensor &inp, +// const torch::Tensor &attn_mask, +// int batch_size, int nhead, int from_len, +// int to_len, bool is_dec_self_attn, +// bool mask_future) { +// const T *attn_mask_ptr = rptr(attn_mask); +// if (is_dec_self_attn) { +// attn_mask_ptr = nullptr; +// } +// cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// launch_attn_softmax_new(rptr(out), rptr(inp), attn_mask_ptr, +// batch_size, +// nhead, from_len, to_len, mask_future, stream); +// // cudaStreamSynchronize(stream); +// CHECK_GPU_ERROR(cudaGetLastError()); +// } template void torch_launch_attn_softmax_bw(torch::Tensor &out_grad, @@ -295,7 +302,7 @@ void torch_launch_concat3_dim1(const torch::Tensor &inp1, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); launch_concat3_dim1(rptr(inp1), rptr(inp2), rptr(output), sz0, sz2, sz1_1, sz1_2, stream); - // cudaStreamSynchronize(stream); + cudaStreamSynchronize(stream); CHECK_GPU_ERROR(cudaGetLastError()); } @@ -420,10 +427,134 @@ void torch_launch_viterbi(const torch::Tensor &start_transition, cudaStreamSynchronize(stream); CHECK_GPU_ERROR(cudaGetLastError()); } + +class RotaryPositionWeight { + public: + float *_device_sin_ptr; + float *_device_cos_ptr; + float *_sin_ptr; + float *_cos_ptr; + + __half *_device_sin_half_ptr; + __half *_device_cos_half_ptr; + __half *_sin_half_ptr; + __half *_cos_half_ptr; + + int _max_step; + int _head_dim; + + RotaryPositionWeight(int max_step, int head_dim) + : _max_step(max_step), 
_head_dim(head_dim) { + if (head_dim & 1) { + printf( + "Error! head dim should be even number while using RotaryPositionQk " + "Operator.\n"); + exit(0); + } + + int total_size = max_step * head_dim / 2; + _sin_ptr = (float *)malloc(total_size * sizeof(float)); + _cos_ptr = (float *)malloc(total_size * sizeof(float)); + + _sin_half_ptr = (__half *)malloc(total_size * sizeof(__half)); + _cos_half_ptr = (__half *)malloc(total_size * sizeof(__half)); + + for (int i = 0; i < head_dim / 2; i++) { + float theta = std::pow(10000, -2. * i / head_dim); + for (int j = 0; j < max_step; j++) { + *(_sin_ptr + j * head_dim / 2 + i) = + sin(j * theta); // shape: [max_step, head_dim / 2] + *(_cos_ptr + j * head_dim / 2 + i) = + cos(j * theta); // shape: [max_step, head_dim / 2] + + *(_sin_half_ptr + j * head_dim / 2 + i) = + __float2half_rn(sin(j * theta)); // shape: [max_step, head_dim / 2] + *(_cos_half_ptr + j * head_dim / 2 + i) = + __float2half_rn(cos(j * theta)); // shape: [max_step, head_dim / 2] + } + } + + cudaMalloc(&_device_sin_ptr, total_size * sizeof(float)); + cudaMalloc(&_device_cos_ptr, total_size * sizeof(float)); + cudaMemcpy(_device_sin_ptr, _sin_ptr, total_size * sizeof(float), + cudaMemcpyDefault); + cudaMemcpy(_device_cos_ptr, _cos_ptr, total_size * sizeof(float), + cudaMemcpyDefault); + + cudaMalloc(&_device_sin_half_ptr, total_size * sizeof(__half)); + cudaMalloc(&_device_cos_half_ptr, total_size * sizeof(__half)); + cudaMemcpy(_device_sin_half_ptr, _sin_half_ptr, total_size * sizeof(__half), + cudaMemcpyDefault); + cudaMemcpy(_device_cos_half_ptr, _cos_half_ptr, total_size * sizeof(__half), + cudaMemcpyDefault); + } +} _rotary_position_instance(2048, 128); + +template +void torch_launch_split_rotary_position( + const torch::Tensor &input, torch::Tensor &q_out, + torch::Tensor &cache_k_out, torch::Tensor &cache_v_out, int batch_size, + int nhead, int offset_seq_len, int query_seq_len, int head_dim) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (query_seq_len + offset_seq_len > _rotary_position_instance._max_step) { + printf( + "Error! query_seq_len + offset_seq_len > " + "_rotary_position_instance._max_step\n"); + return; + } + if (_rotary_position_instance._head_dim != head_dim) { + printf("Error! 
_rotary_position_instance._head_dim != head_dim\n");
+    return;
+  }
+  if (std::is_same<T, float>::value) {
+    launch_split_rotary_position_qkv<float>(
+        rptr<float>(input), _rotary_position_instance._device_sin_ptr,
+        _rotary_position_instance._device_cos_ptr, rptr<float>(q_out),
+        rptr<float>(cache_k_out), rptr<float>(cache_v_out),
+        offset_seq_len + query_seq_len, batch_size, nhead, offset_seq_len,
+        query_seq_len, head_dim, stream);
+  } else {
+    launch_split_rotary_position_qkv<__half>(
+        rptr<__half>(input), _rotary_position_instance._device_sin_half_ptr,
+        _rotary_position_instance._device_cos_half_ptr, rptr<__half>(q_out),
+        rptr<__half>(cache_k_out), rptr<__half>(cache_v_out),
+        offset_seq_len + query_seq_len, batch_size, nhead, offset_seq_len,
+        query_seq_len, head_dim, stream);
+  }
+  cudaStreamSynchronize(stream);
+  CHECK_GPU_ERROR(cudaGetLastError());
+}
+
+template <typename T>
+void torch_silu_elewise_product(const torch::Tensor &inp, torch::Tensor out,
+                                int batch_size, int seq_len, int inner_size) {
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  launch_silu_elewise_product<T>(rptr<T>(inp), rptr<T>(out), batch_size,
+                                 seq_len, inner_size, stream);
+  cudaStreamSynchronize(stream);
+  CHECK_GPU_ERROR(cudaGetLastError());
+}
+
+template <typename T>
+void torch_rms_layer_norm(const torch::Tensor &inp, const torch::Tensor &scale,
+                          torch::Tensor &out, torch::Tensor &rms_out,
+                          int batch_tokens, int hidden_dim,
+                          const float epsilon = 1e-6) {
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  launch_rms_layer_norm<T>(rptr<T>(inp), rptr<T>(scale), rptr<T>(out), nullptr,
+                           rptr<T>(rms_out), batch_tokens, hidden_dim, stream,
+                           epsilon);
+  cudaStreamSynchronize(stream);
+  CHECK_GPU_ERROR(cudaGetLastError());
+}
+
 }  // namespace cuda
 }  // namespace lightseq
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // lightseq::cuda::_rotary_position_instance =
+  //     lightseq::cuda::RotaryPositionWeight(2048, 128);
+
   m.def("torch_launch_transform_0213_fp32",
         &lightseq::cuda::torch_launch_transform_0213<float>,
         "Test kernel wrapper");
@@ -468,12 +599,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("torch_launch_attn_softmax_fp16",
         &lightseq::cuda::torch_launch_attn_softmax<__half>,
         "Test kernel wrapper");
-  m.def("torch_launch_attn_softmax_new_fp32",
-        &lightseq::cuda::torch_launch_attn_softmax_new<float>,
-        "Test kernel wrapper");
-  m.def("torch_launch_attn_softmax_new_fp16",
-        &lightseq::cuda::torch_launch_attn_softmax_new<__half>,
-        "Test kernel wrapper");
+  // m.def("torch_launch_attn_softmax_new_fp32",
+  //       &lightseq::cuda::torch_launch_attn_softmax_new<float>,
+  //       "Test kernel wrapper");
+  // m.def("torch_launch_attn_softmax_new_fp16",
+  //       &lightseq::cuda::torch_launch_attn_softmax_new<__half>,
+  //       "Test kernel wrapper");
   m.def("torch_launch_attn_softmax_bw_fp32",
         &lightseq::cuda::torch_launch_attn_softmax_bw<float>,
         "Test kernel wrapper");
@@ -623,4 +754,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         &lightseq::cuda::torch_launch_viterbi<__half>, "Test kernel wrapper");
   m.def("torch_launch_viterbi_fp32",
         &lightseq::cuda::torch_launch_viterbi<float>, "Test kernel wrapper");
+
+  m.def("torch_launch_split_rotary_position_fp32",
+        &lightseq::cuda::torch_launch_split_rotary_position<float>,
+        "Test llama rotary position kernel");
+  m.def("torch_launch_split_rotary_position_fp16",
+        &lightseq::cuda::torch_launch_split_rotary_position<__half>,
+        "Test llama rotary position kernel");
+  m.def("torch_silu_elewise_product_fp32",
+        &lightseq::cuda::torch_silu_elewise_product<float>,
+        "Test llama silu elewise product kernel");
+  m.def("torch_silu_elewise_product_fp16",
+        &lightseq::cuda::torch_silu_elewise_product<__half>,
+        "Test llama silu elewise product kernel");
+
+  m.def("torch_rms_layer_norm_fp32",
+        &lightseq::cuda::torch_rms_layer_norm<float>,
+        "Test llama rms layer norm kernel");
+  m.def("torch_rms_layer_norm_fp16",
+        &lightseq::cuda::torch_rms_layer_norm<__half>,
+        "Test llama rms layer norm kernel");
 }
diff --git a/lightseq/csrc/pybind/pybind_model.cpp b/lightseq/csrc/pybind/pybind_model.cpp
index 9a8b3c58..c183738d 100644
--- a/lightseq/csrc/pybind/pybind_model.cpp
+++ b/lightseq/csrc/pybind/pybind_model.cpp
@@ -9,6 +9,7 @@
 #include "bert.h"
 #include "bert_crf.h"
 #include "gpt.h"
+#include "llama.h"
 #include "transformer.h"
 
 namespace py = pybind11;
@@ -306,6 +307,74 @@ class PyGpt {
     return std::make_tuple(output, scores);
   }
 };
+
+class PyLlama {
+ private:
+  LSModel *model_;
+  int *d_input_;
+  std::vector<void *> d_outputs_;
+
+ public:
+  PyLlama(std::string weight_path, int max_batch_size) {
+    model_ = LSModelFactory::GetInstance().CreateModel("Llama", weight_path,
+                                                       max_batch_size);
+    std::vector<int> max_input_shape = model_->get_input_max_shape(0);
+    int max_size =
+        std::accumulate(max_input_shape.begin(), max_input_shape.end(), 1,
+                        std::multiplies<int>());
+    CHECK_GPU_ERROR(cudaMalloc(&d_input_, sizeof(int) * max_size));
+
+    for (int i = 0; i < model_->get_output_size(); i++) {
+      void *d_output;
+      std::vector<int> shape = model_->get_output_max_shape(i);
+      int output_size = std::accumulate(shape.begin(), shape.end(), 1,
+                                        std::multiplies<int>());
+      CHECK_GPU_ERROR(cudaMalloc(&d_output, output_size * sizeof(int)));
+      model_->set_output_ptr(i, d_output);
+      d_outputs_.push_back(d_output);
+    }
+  }
+  ~PyLlama() {
+    delete model_;
+    CHECK_GPU_ERROR(cudaFree(d_input_));
+    for (auto d_output : d_outputs_) {
+      CHECK_GPU_ERROR(cudaFree(d_output));
+    }
+  }
+
+  py::array_t<int> infer(
+      py::array_t<int, py::array::c_style | py::array::forcecast> input_seq) {
+    auto input_seq_out = input_seq.mutable_unchecked<2>();
+    const int *input_seq_data = input_seq_out.data(0, 0);
+    int batch_size = input_seq_out.shape(0);
+    int batch_seq_len = input_seq_out.shape(1);
+    if (model_->get_output_dtype(0) != DataType::kInt32) {
+      throw std::runtime_error(
+          "This model is not for sample, maybe you have set the "
+          "sampling_method to "
+          "ppl");
+    }
+
+    CHECK_GPU_ERROR(cudaMemcpy(d_input_, input_seq_data,
+                               sizeof(int) * input_seq_out.size(),
+                               cudaMemcpyHostToDevice));
+
+    model_->set_input_ptr(0, d_input_);
+    model_->set_input_shape(0, {batch_size, batch_seq_len});
+
+    model_->Infer();
+
+    std::vector<int> output_shape = model_->get_output_shape(0);
+    auto output = py::array_t<int>(output_shape);
+    int *output_data = output.mutable_data(0, 0);
+    const int *d_output =
+        static_cast<const int *>(model_->get_output_ptr(0));
+    CHECK_GPU_ERROR(cudaMemcpy(output_data, d_output,
+                               sizeof(int) * output.size(),
+                               cudaMemcpyDeviceToHost));
+
+    return output;
+  }
+};
 }  // namespace cuda
 }  // namespace lightseq
 
@@ -335,4 +404,10 @@ PYBIND11_MODULE(inference, m) {
       py::arg("max_batch_size"))
       .def("infer", &lightseq::cuda::PyGpt::infer,
            py::return_value_policy::reference_internal, py::arg("input_seq"));
+
+  py::class_<lightseq::cuda::PyLlama>(m, "Llama")
+      .def(py::init<std::string, int>(), py::arg("weight_path"),
+           py::arg("max_batch_size"))
+      .def("infer", &lightseq::cuda::PyLlama::infer,
+           py::return_value_policy::reference_internal, py::arg("input_seq"));
 }
diff --git a/lightseq/csrc/pytorch/builder/cuda_kernel_builder.py b/lightseq/csrc/pytorch/builder/cuda_kernel_builder.py
index 89ec4d8a..3135b3de 100644
--- a/lightseq/csrc/pytorch/builder/cuda_kernel_builder.py
+++ b/lightseq/csrc/pytorch/builder/cuda_kernel_builder.py
@@ -31,6 +31,7 @@ def sources(self):
             "csrc/kernels/cuda/dropout_kernels.cu",
             "csrc/kernels/cuda/embedding_kernels.cu",
             "csrc/kernels/cuda/quantize_kernels.cu",
+            "csrc/kernels/cuda/llama_kernels.cu",
             "csrc/kernels/cuda/crf.cu",
             "csrc/pybind/pybind_kernel_cuda.cpp",
         ]
diff --git a/lightseq/csrc/tests/cuda/test_kernel.py b/lightseq/csrc/tests/cuda/test_kernel.py
index 3bd310b3..57f2c386 100644
--- a/lightseq/csrc/tests/cuda/test_kernel.py
+++ b/lightseq/csrc/tests/cuda/test_kernel.py
@@ -1575,40 +1575,189 @@ def baseline():
     return custom, baseline
 
 
+@kt.case(atol=1e-2, rtol=1e-3, dtypes=[torch.float, torch.half])
+def test_split_rotary_position_qkv():
+    batch_size, offset_seq_len = kt.bs_sl()
+    nhead = kt.nhead
+    head_dim = 128
+    seq_len = random.randint(1, 2048)
+    offset_seq_len = random.randint(0, 2048 - seq_len)
+    outshape = kt.rand((batch_size, nhead, seq_len, head_dim))
+
+    cachek = kt.rand((batch_size, nhead, offset_seq_len, head_dim))
+    cachev = kt.rand((batch_size, nhead, offset_seq_len, head_dim))
+    q_tensor = kt.rand((batch_size, seq_len, nhead, head_dim))
+    k_tensor = kt.rand((batch_size, seq_len, nhead, head_dim))
+    v_tensor = kt.rand((batch_size, seq_len, nhead, head_dim))
+    qkv_tensor = torch.cat((q_tensor, k_tensor, v_tensor), dim=2)
+
+    out_cachek = torch.cat((cachek, outshape), dim=2)
+    out_cachev = torch.cat((cachev, outshape), dim=2)
+
+    func = None
+    if kt.dtype == torch.float:
+        func = cuda_module.torch_launch_split_rotary_position_fp32
+    elif kt.dtype == torch.half:
+        func = cuda_module.torch_launch_split_rotary_position_fp16
+
+    custom_q = torch.empty_like(q_tensor)
+
+    def custom():
+        func(
+            qkv_tensor,
+            custom_q,
+            out_cachek,
+            out_cachev,
+            batch_size,
+            nhead,
+            offset_seq_len,
+            seq_len,
+            head_dim,
+        )
+        return [custom_q.contiguous(), out_cachek.contiguous(), out_cachev.contiguous()]
+
+    inv_freq = 1.0 / (
+        10000 ** (torch.arange(0, head_dim, 2).float().to(device="cuda:0") / head_dim)
+    )
+    t = torch.arange(2048, device="cuda:0", dtype=inv_freq.dtype)
+    freqs = torch.einsum("i,j->ij", t, inv_freq)
+    # Different from paper, but it uses a different permutation
+    # in order to obtain the same calculation
+    emb = torch.cat((freqs, freqs), dim=-1)
+    cos_cached = emb.cos()[None, None, :, :].to(device="cuda:0", dtype=kt.dtype)
+    sin_cached = emb.sin()[None, None, :, :].to(device="cuda:0", dtype=kt.dtype)
+
+    def baseline():
+        trans_q = q_tensor.transpose(1, 2)
+        trans_k = k_tensor.transpose(1, 2)
+        trans_v = v_tensor.transpose(1, 2)
+        kv_seq_len = offset_seq_len + seq_len
+        cos = cos_cached[:, :, :kv_seq_len, ...]
+        sin = sin_cached[:, :, :kv_seq_len, ...]
+        gather_indices = (
+            (torch.arange(seq_len) + offset_seq_len)[None, None, :, None]
+            .repeat(batch_size, cos.shape[1], 1, cos.shape[3])
+            .to("cuda:0")
+        )
+        cos = torch.gather(
+            cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices
+        )
+        sin = torch.gather(
+            sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices
+        )
+
+        def rotate_half(x):
+            """Rotates half the hidden dims of the input."""
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            return torch.cat((-x2, x1), dim=-1)
+
+        q_out = (trans_q * cos) + (rotate_half(trans_q) * sin)
+        k_out = (trans_k * cos) + (rotate_half(trans_k) * sin)
+        k_out = torch.cat((cachek, k_out), dim=2)
+        v_out = torch.cat((cachev, trans_v), dim=2)
+        return [q_out.contiguous(), k_out.contiguous(), v_out.contiguous()]
+
+    return custom, baseline
+
+
+from transformers import LlamaModel
+from transformers.activations import SiLUActivation
+
+
+@kt.case(atol=1e-3, rtol=1e-4, dtypes=[torch.float, torch.half])
+def test_silu_elewise_product():
+    batch_size, seq_len = 1, 256
+    hidden_size = 13824
+    inpA = kt.rand((batch_size, seq_len, hidden_size))
+    inpB = kt.rand((batch_size, seq_len, hidden_size))
+    custom_outC = torch.empty_like(inpA)
+
+    act_func = SiLUActivation()
+    func = (
+        cuda_module.torch_silu_elewise_product_fp32
+        if kt.dtype == torch.float
+        else cuda_module.torch_silu_elewise_product_fp16
+    )
+
+    def custom():
+        func(inpA, inpB, custom_outC, batch_size, seq_len, hidden_size)
+        return [custom_outC.contiguous()]
+
+    def baseline():
+        output = act_func(inpA) * inpB
+        return [output.contiguous()]
+
+    return custom, baseline
+
+
+@kt.case(atol=1e-3, rtol=1e-4, dtypes=[torch.float, torch.half])
+def test_rms_layer_norm():  # torch_rms_layer_norm
+    batch_size, seq_len = 1, 1  # kt.bs_sl()
+    hidden_size = 5120
+    inp = kt.rand((batch_size, seq_len, hidden_size))
+    scale = kt.rand((hidden_size))
+    custom_out = torch.empty_like(inp)
+    rms_out = kt.rand((batch_size, seq_len))
+
+    func = (
+        cuda_module.torch_rms_layer_norm_fp32
+        if kt.dtype == torch.float
+        else cuda_module.torch_rms_layer_norm_fp16
+    )
+
+    def custom():
+        func(inp, scale, custom_out, rms_out, batch_size * seq_len, hidden_size, 1e-6)
+        return [rms_out.contiguous(), custom_out.contiguous()]
+
+    def baseline():
+        variance = inp.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        rms_var = torch.rsqrt(variance + 1e-6).to(dtype=kt.dtype)
+        hidden_states = inp * rms_var
+        output = (scale * hidden_states).to(dtype=kt.dtype)
+        return [rms_var.contiguous(), output.contiguous()]
+
+    return custom, baseline
+
+
 if __name__ == "__main__":
     kt.init(device="cuda:0", nhead=16)
     kt.run(
         [
-            "test_launch_transform_0213",
-            "test_launch_bias_add_transform_20314",
-            "test_launch_transform4d_0213",
-            "test_launch_bias_add_transform_20314_new",
-            "test_launch_fused_add2",
-            "test_launch_ffn_bias_bwd",
-            # "test_launch_attn_softmax",  # need to fix
-            "test_launch_attn_softmax_new",
-            # "test_launch_attn_softmax_bw",  # need to fix
-            "test_launch_attn_softmax_bw_new",
-            "test_launch_layer_norm",
-            # "test_launch_ln_bw",  # need to fix
-            "test_launch_concat3_dim1",
-            # "test_adam",  # need to fix
-            "test_launch_dropout_relu_bias",
-            "test_launch_dropout_relu_bias_bwd",
-            "test_launch_dropout_gelu_bias",
-            # "test_launch_dropout_gelu_bias_bwd",  # need to fix
-            # "test_launch_layer_norm_i8O",  # need to fix
-            # "test_launch_ln_i8O_bw",  # need to fix
-            "test_launch_dropout_relu_bias_i8I_i8O",
-            # "test_launch_dropout_relu_bias_i8I_i8O_bwd",  # need to fix
-            "test_launch_dropout_gelu_bias_i8I_i8O",
-            # "test_launch_dropout_gelu_bias_i8I_i8O_bwd",  # need to fix
-            "test_launch_quant_bias_dropout_residual",
-            "test_launch_quant_bias_add_transform_20314",
-            # "test_launch_quant_transform4d_0213",  # need to fix
-            # "test_torch_launch_ls_quantize",  # need to fix
-            "test_torch_launch_ls_dequantize",
-            # "test_torch_launch_fake_quantize",  # need to fix
-            # "test_crf",  # need to fix
+            # "test_rms_layer_norm",
+            # "test_silu_elewise_product",
+            "test_split_rotary_position_qkv",
+            # "test_launch_transform_0213",
+            # "test_launch_bias_add_transform_20314",
+            # "test_launch_transform4d_0213",
+            # "test_launch_bias_add_transform_20314_new",
+            # "test_launch_fused_add2",
+            # "test_launch_ffn_bias_bwd",
+            # # "test_launch_attn_softmax",  # need to fix
+            # "test_launch_attn_softmax_new",
+            # # "test_launch_attn_softmax_bw",  # need to fix
+            # "test_launch_attn_softmax_bw_new",
+            # "test_launch_layer_norm",
+            # # "test_launch_ln_bw",  # need to fix
+            # "test_launch_concat3_dim1",
+            # # "test_adam",  # need to fix
+            # "test_launch_dropout_relu_bias",
+            # "test_launch_dropout_relu_bias_bwd",
+            # "test_launch_dropout_gelu_bias",
+            # # "test_launch_dropout_gelu_bias_bwd",  # need to fix
+            # # "test_launch_layer_norm_i8O",  # need to fix
+            # # "test_launch_ln_i8O_bw",  # need to fix
+            # "test_launch_dropout_relu_bias_i8I_i8O",
+            # # "test_launch_dropout_relu_bias_i8I_i8O_bwd",  # need to fix
+            # "test_launch_dropout_gelu_bias_i8I_i8O",
+            # # "test_launch_dropout_gelu_bias_i8I_i8O_bwd",  # need to fix
+            # "test_launch_quant_bias_dropout_residual",
+            # "test_launch_quant_bias_add_transform_20314",
+            # # "test_launch_quant_transform4d_0213",  # need to fix
+            # # "test_torch_launch_ls_quantize",  # need to fix
+            # "test_torch_launch_ls_dequantize",
+            # # "test_torch_launch_fake_quantize",  # need to fix
+            # # "test_crf",  # need to fix
         ]
    )