diff --git a/finetune/run_classifier_deepspeed.py b/finetune/run_classifier_deepspeed.py index 155d02d6..6030de6d 100644 --- a/finetune/run_classifier_deepspeed.py +++ b/finetune/run_classifier_deepspeed.py @@ -14,16 +14,13 @@ sys.path.append(tencentpretrain_dir) from tencentpretrain.opts import deepspeed_opts -from tencentpretrain.model_loader import * from finetune.run_classifier import * +from tencentpretrain.model_loader import _load_state_dict_into_model - -def read_dataset(args, path, split): - dataset, columns = [], {} - if split: - for i in range(args.world_size): - dataset.append([]) - index = 0 +def read_dataset(args, path): + dataset, instances, columns = [], [], {} + for i in range(args.world_size): + dataset.append([]) with open(path, mode="r", encoding="utf-8") as f: for line_id, line in enumerate(f): if line_id == 0: @@ -31,53 +28,69 @@ def read_dataset(args, path, split): columns[column_name] = i continue line = line.rstrip("\r\n").split("\t") - tgt = int(line[columns["label"]]) - if args.soft_targets and "logits" in columns.keys(): - soft_tgt = [float(value) for value in line[columns["logits"]].split(" ")] - if "text_b" not in columns: # Sentence classification. - text_a = line[columns["text_a"]] - src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN]) - seg = [1] * len(src) - else: # Sentence-pair classification. - text_a, text_b = line[columns["text_a"]], line[columns["text_b"]] - src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN]) - src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN]) - src = src_a + src_b - seg = [1] * len(src_a) + [2] * len(src_b) - - if len(src) > args.seq_length: - src = src[: args.seq_length] - seg = seg[: args.seq_length] - PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] - while len(src) < args.seq_length: - src.append(PAD_ID) - seg.append(0) - if split: - if args.soft_targets and "logits" in columns.keys(): - dataset[index].append((src, tgt, seg, soft_tgt)) - else: - dataset[index].append((src, tgt, seg)) - index += 1 - if index == args.world_size: - index = 0 - else: - if args.soft_targets and "logits" in columns.keys(): - dataset.append((src, tgt, seg, soft_tgt)) - else: - dataset.append((src, tgt, seg)) - if split: - max_data_num_rank_index = 0 - max_data_num = len(dataset[0]) - for i in range(args.world_size): - if len(dataset[i]) > max_data_num: - max_data_num_rank_index = i - max_data_num = len(dataset[i]) - for i in range(args.world_size): - if len(dataset[i]) < max_data_num: - dataset[i].append(dataset[max_data_num_rank_index][-1]) - + if len(columns) != len(line): + continue + instances.append(line) + rank_num = math.ceil(1.0 * len(instances) / args.world_size) + index = 0 + for line_id, line in enumerate(instances): + tgt = int(line[columns["label"]]) + if args.soft_targets and "logits" in columns.keys(): + soft_tgt = [float(value) for value in line[columns["logits"]].split(" ")] + if "text_b" not in columns: # Sentence classification. + text_a = line[columns["text_a"]] + src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN]) + seg = [1] * len(src) + else: # Sentence-pair classification. 
+ text_a, text_b = line[columns["text_a"]], line[columns["text_b"]] + src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN]) + src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN]) + src = src_a + src_b + seg = [1] * len(src_a) + [2] * len(src_b) + + if len(src) > args.seq_length: + src = src[: args.seq_length] + seg = seg[: args.seq_length] + PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0] + while len(src) < args.seq_length: + src.append(PAD_ID) + seg.append(0) + if args.soft_targets and "logits" in columns.keys(): + dataset[index].append((src, tgt, seg, 0, soft_tgt)) + else: + dataset[index].append((src, tgt, seg, 0)) + if (line_id+1) % rank_num == 0: + index += 1 + for i in range(args.world_size): + while len(dataset[i]) < rank_num: + dataset[i].append(tuple([1 if j == 3 else dataset[0][-1][j] for j in range(len(dataset[0][-1]))])) return dataset +def load_model(args, model, model_path): + if args.enable_zero3: + with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config): + if os.path.isdir(model_path): + index_filename = os.path.join(model_path, "tencentpretrain_model.bin.index.json") + with open(index_filename, "r") as f: + index = json.loads(f.read()) + shard_filenames = sorted(set(index["weight_map"].values())) + shard_filenames = [os.path.join(model_path, f) for f in shard_filenames] + for shard_file in shard_filenames: + model = _load_state_dict_into_model(model, shard_file, "") + elif model_path is not None: + model = _load_state_dict_into_model(model, model_path, "") + else: + if os.path.isdir(model_path): + index_filename = os.path.join(model_path, "tencentpretrain_model.bin.index.json") + with open(index_filename, "r") as f: + index = json.loads(f.read()) + shard_filenames = sorted(set(index["weight_map"].values())) + shard_filenames = [os.path.join(model_path, f) for f in shard_filenames] + for shard_file in shard_filenames: + model.load_state_dict(torch.load(shard_file, map_location="cpu"), strict=False) + elif model_path is not None: + model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False) + return model def train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_batch, soft_tgt_batch=None): model.zero_grad() @@ -98,6 +111,75 @@ def train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_bat return loss +def batch_loader(batch_size, src, tgt, seg, is_pad, soft_tgt=None): + instances_num = src.size()[0] + for i in range(instances_num // batch_size): + src_batch = src[i * batch_size : (i + 1) * batch_size, :] + tgt_batch = tgt[i * batch_size : (i + 1) * batch_size] + seg_batch = seg[i * batch_size : (i + 1) * batch_size, :] + is_pad_batch = is_pad[i * batch_size : (i + 1) * batch_size] + if soft_tgt is not None: + soft_tgt_batch = soft_tgt[i * batch_size : (i + 1) * batch_size, :] + yield src_batch, tgt_batch, seg_batch, is_pad_batch, soft_tgt_batch + else: + yield src_batch, tgt_batch, seg_batch, is_pad_batch, None + if instances_num > instances_num // batch_size * batch_size: + src_batch = src[instances_num // batch_size * batch_size :, :] + tgt_batch = tgt[instances_num // batch_size * batch_size :] + seg_batch = seg[instances_num // batch_size * batch_size :, :] + is_pad_batch = is_pad[instances_num // batch_size * batch_size :] + if soft_tgt is not None: + soft_tgt_batch = soft_tgt[instances_num // batch_size * batch_size :, :] + yield src_batch, tgt_batch, seg_batch, is_pad_batch, soft_tgt_batch + 
else:
+            yield src_batch, tgt_batch, seg_batch, is_pad_batch, None
+
+def predict(args, dataset):
+    src = torch.LongTensor([sample[0] for sample in dataset])
+    tgt = torch.LongTensor([sample[1] for sample in dataset])
+    seg = torch.LongTensor([sample[2] for sample in dataset])
+    is_pad = torch.LongTensor([sample[3] for sample in dataset])
+
+    batch_size = args.batch_size
+
+    args.model.eval()
+
+    result = []
+    for _, (src_batch, tgt_batch, seg_batch, is_pad_batch, _) in enumerate(batch_loader(batch_size, src, tgt, seg, is_pad)):
+        src_batch = src_batch.to(args.device)
+        tgt_batch = tgt_batch.to(args.device)
+        seg_batch = seg_batch.to(args.device)
+        is_pad_batch = is_pad_batch.to(args.device)
+        with torch.no_grad():
+            _, logits = args.model(src_batch, tgt_batch, seg_batch)
+        pred = torch.argmax(nn.Softmax(dim=1)(logits), dim=1)
+        gold = tgt_batch
+        pad = is_pad_batch
+        for j in range(pred.size()[0]):
+            result.append([pred[j], gold[j], pad[j]])
+    return result
+
+def evaluate(args, output_list):
+    # Confusion matrix.
+    correct, total = 0, 0
+    confusion = torch.zeros(args.labels_num, args.labels_num, dtype=torch.long)
+    for result in output_list:
+        for pred, gold, is_pad in result.tolist():
+            if is_pad == 1: continue
+            confusion[pred, gold] += 1
+            correct += pred == gold
+            total += 1
+    args.logger.info("Confusion matrix:")
+    args.logger.info(confusion)
+    args.logger.info("Report precision, recall, and f1:")
+    eps = 1e-9
+    for i in range(confusion.size()[0]):
+        p = confusion[i, i].item() / (confusion[i, :].sum().item() + eps)
+        r = confusion[i, i].item() / (confusion[:, i].sum().item() + eps)
+        f1 = 2 * p * r / (p + r + eps)
+        args.logger.info("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i, p, r, f1))
+    args.logger.info("Acc. (Correct/Total): {:.4f} ({}/{})".format(correct / total, correct, total))
+    return correct / total, confusion
 
 def main():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -129,17 +211,15 @@ def main():
     # Build tokenizer.
     args.tokenizer = str2tokenizer[args.tokenizer](args)
 
-    # Build classification model.
-    if args.enable_zero3:
-        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
-            model = Classifier(args)
-            if args.pretrained_model_path:
-                model = _load_state_dict_into_model(model, args.load_model_path)
+    # Build classification model.
+    model = Classifier(args)
+    if args.pretrained_model_path:
+        load_model(args, model, args.pretrained_model_path)
     else:
-        model = Classifier(args)
-
-    # Load or initialize parameters.
-    load_or_initialize_parameters(args, model)
+        # Initialize with normal distribution.
+        for n, p in list(model.named_parameters()):
+            if "gamma" not in n and "beta" not in n:
+                p.data.normal_(0, 0.02)
 
     # Get logger.
args.logger = init_logger(args) @@ -155,7 +235,7 @@ def main(): rank = dist.get_rank() args.rank = rank - trainset = read_dataset(args, args.train_path, split=True)[args.rank] + trainset = read_dataset(args, args.train_path)[args.rank] random.shuffle(trainset) instances_num = len(trainset) batch_size = args.batch_size @@ -172,11 +252,12 @@ def main(): mpu=None, dist_init_required=False) - src = torch.LongTensor([example[0] for example in trainset]) - tgt = torch.LongTensor([example[1] for example in trainset]) - seg = torch.LongTensor([example[2] for example in trainset]) + src = torch.LongTensor([sample[0] for sample in trainset]) + tgt = torch.LongTensor([sample[1] for sample in trainset]) + seg = torch.LongTensor([sample[2] for sample in trainset]) + is_pad = torch.LongTensor([sample[3] for sample in trainset]) if args.soft_targets: - soft_tgt = torch.FloatTensor([example[3] for example in trainset]) + soft_tgt = torch.FloatTensor([sample[4] for sample in trainset]) else: soft_tgt = None @@ -193,14 +274,18 @@ def main(): for epoch in range(1, args.epochs_num + 1): model.train() - for i, (src_batch, tgt_batch, seg_batch, soft_tgt_batch) in enumerate(batch_loader(batch_size, src, tgt, seg, soft_tgt)): + for i, (src_batch, tgt_batch, seg_batch, _, soft_tgt_batch) in enumerate(batch_loader(batch_size, src, tgt, seg, is_pad, soft_tgt)): loss = train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_batch, soft_tgt_batch) total_loss += loss.item() if (i + 1) % args.report_steps == 0 and args.rank == 0: args.logger.info("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i + 1, total_loss / args.report_steps)) total_loss = 0.0 + output = predict(args, read_dataset(args, args.dev_path)[args.rank]) + output = torch.as_tensor(output).to(args.device) + output_list = [torch.zeros_like(output).to(args.device) for _ in range(args.world_size)] + dist.all_gather(output_list, output) if args.rank == 0: - result = evaluate(args, read_dataset(args, args.dev_path, split=False)) + result = evaluate(args, output_list) result_tensor = torch.tensor(result[0]).to(args.device) dist.broadcast(result_tensor, 0, async_op=False) if result_tensor.float() >= best_result: @@ -208,12 +293,5 @@ def main(): best_epoch = epoch model.save_checkpoint(args.output_model_path, str(epoch)) - # Evaluation phase. - if args.test_path is not None and args.rank == 0: - args.logger.info("Test set evaluation.") - model.load_checkpoint(args.output_model_path, str(best_epoch)) - evaluate(args, read_dataset(args, args.test_path, split=False)) - - if __name__ == "__main__": main() diff --git a/finetune/run_classifier_mt_deepspeed.py b/finetune/run_classifier_mt_deepspeed.py new file mode 100644 index 00000000..00a4bbc6 --- /dev/null +++ b/finetune/run_classifier_mt_deepspeed.py @@ -0,0 +1,218 @@ +""" +This script provides an example to use DeepSpeed for multi-task classification. 
+""" +import sys +import os +import random +import argparse +import torch +import torch.nn as nn +import deepspeed +import torch.distributed as dist + +tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(tencentpretrain_dir) + +from tencentpretrain.opts import * +from finetune.run_classifier_deepspeed import * +from finetune.run_classifier_mt import MultitaskClassifier + +def pack_dataset(dataset, dataset_id, batch_size): + packed_dataset = [] + src_batch, tgt_batch, seg_batch, is_pad_batch = [], [], [], [] + for i, sample in enumerate(dataset): + src_batch.append(sample[0]) + tgt_batch.append(sample[1]) + seg_batch.append(sample[2]) + is_pad_batch.append(sample[3]) + if (i + 1) % batch_size == 0: + packed_dataset.append((dataset_id, torch.LongTensor(src_batch), torch.LongTensor(tgt_batch), torch.LongTensor(seg_batch), torch.LongTensor(is_pad_batch))) + src_batch, tgt_batch, seg_batch = [], [], [] + continue + if len(src_batch) > 0: + packed_dataset.append((dataset_id, torch.LongTensor(src_batch), torch.LongTensor(tgt_batch), torch.LongTensor(seg_batch), torch.LongTensor(is_pad_batch))) + + return packed_dataset + +def predict(args, dataset): + src = torch.LongTensor([sample[0] for sample in dataset]) + tgt = torch.LongTensor([sample[1] for sample in dataset]) + seg = torch.LongTensor([sample[2] for sample in dataset]) + is_pad = torch.LongTensor([sample[3] for sample in dataset]) + + batch_size = args.batch_size + + args.model.eval() + + result = [] + for _, (src_batch, tgt_batch, seg_batch, is_pad_batch, _) in enumerate(batch_loader(batch_size, src, tgt, seg, is_pad)): + src_batch = src_batch.to(args.device) + tgt_batch = tgt_batch.to(args.device) + seg_batch = seg_batch.to(args.device) + is_pad_batch = is_pad_batch.to(args.device) + with torch.no_grad(): + _, logits = args.model(src_batch, tgt_batch, seg_batch) + pred = torch.argmax(nn.Softmax(dim=1)(logits), dim=1) + gold = tgt_batch + pad = is_pad_batch + for j in range(pred.size()[0]): + result.append([pred[j], gold[j], pad[j]]) + return result + +def evaluate(args, output_list): + for dataset_id, _ in enumerate(args.dataset_path_list): + # Confusion matrix. + correct, total = 0, 0 + confusion = torch.zeros(args.labels_num, args.labels_num, dtype=torch.long) + for result in output_list: + for pred, gold, is_pad in result.tolist()[dataset_id]: + if is_pad == 1: continue + confusion[pred, gold] += 1 + correct += pred == gold + total += 1 + args.logger.info("Confusion matrix:") + args.logger.info(confusion) + args.logger.info("Report precision, recall, and f1:") + eps = 1e-9 + for i in range(confusion.size()[0]): + p = confusion[i, i].item() / (confusion[i, :].sum().item() + eps) + r = confusion[i, i].item() / (confusion[:, i].sum().item() + eps) + f1 = 2 * p * r / (p + r + eps) + args.logger.info("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i, p, r, f1)) + args.logger.info("Dataset_id: {} Acc. (Correct/Total): {:.4f} ({}/{}) ".format(dataset_id, correct / total, correct, total)) + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # Path options. 
+ parser.add_argument("--world_size", type=int, default=1, + help="Total number of processes (GPUs) for training.") + parser.add_argument("--pretrained_model_path", default=None, type=str, + help="Path of the pretrained model.") + parser.add_argument("--dataset_path_list", default=[], nargs='+', type=str, help="Dataset path list.") + parser.add_argument("--output_model_path", default="models/multitask_classifier_model.bin", type=str, + help="Path of the output model.") + parser.add_argument("--config_path", default="models/bert/base_config.json", type=str, + help="Path of the config file.") + parser.add_argument("--soft_targets", action='store_true', + help="Train model with logits.") + parser.add_argument("--soft_alpha", type=float, default=0.5, + help="Weight of the soft targets loss.") + + # Model options. + model_opts(parser) + + # Tokenizer options. + tokenizer_opts(parser) + + # Optimizer options. + optimization_opts(parser) + + # Training options. + training_opts(parser) + + adv_opts(parser) + + deepspeed_opts(parser) + + args = parser.parse_args() + + # Load the hyperparameters from the config file. + args = load_hyperparam(args) + + set_seed(args.seed) + + # Count the number of labels. + args.labels_num_list = [count_labels_num(os.path.join(path, "train.tsv")) for path in args.dataset_path_list] + + args.datasets_num = len(args.dataset_path_list) + + # Build tokenizer. + args.tokenizer = str2tokenizer[args.tokenizer](args) + + # Build multi-task classification model. + model = MultitaskClassifier(args) + if args.pretrained_model_path: + load_model(args, model, args.pretrained_model_path) + else: + # Initialize with normal distribution. + for n, p in list(model.named_parameters()): + if "gamma" not in n and "beta" not in n: + p.data.normal_(0, 0.02) + + # Get logger. 
+ args.logger = init_logger(args) + + param_optimizer = list(model.named_parameters()) + no_decay = ["bias", "gamma", "beta"] + optimizer_grouped_parameters = [ + {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01}, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + + deepspeed.init_distributed() + rank = dist.get_rank() + args.rank = rank + + dataset_list = [read_dataset(args, os.path.join(path, "train.tsv"))[args.rank] for path in args.dataset_path_list] + packed_dataset_list = [pack_dataset(dataset, i, args.batch_size) for i, dataset in enumerate(dataset_list)] + packed_dataset_all = [] + for packed_dataset in packed_dataset_list: + packed_dataset_all += packed_dataset + + instances_num = sum([len(dataset) for dataset in dataset_list]) + batch_size = args.batch_size + args.train_steps = int(instances_num * args.epochs_num / batch_size) + 1 + + custom_optimizer, custom_scheduler = build_optimizer(args, model) + + model, optimizer, _, scheduler = deepspeed.initialize( + model=model, + model_parameters=optimizer_grouped_parameters, + args=args, + optimizer=custom_optimizer, + lr_scheduler=custom_scheduler, + mpu=None, + dist_init_required=False) + + args.model = model + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + total_loss, result, best_result, best_epoch = 0.0, 0.0, 0.0, 0 + + if args.rank == 0: + args.logger.info("Batch size: {}".format(batch_size)) + args.logger.info("The number of training instances: {}".format(instances_num)) + args.logger.info("Start training.") + + for epoch in range(1, args.epochs_num + 1): + random.shuffle(packed_dataset_all) + model.train() + for i, (dataset_id, src_batch, tgt_batch, seg_batch, _) in enumerate(packed_dataset_all): + if hasattr(model, "module"): + model.module.change_dataset(dataset_id) + else: + model.change_dataset(dataset_id) + loss = train_model(args, model, optimizer, scheduler, src_batch, tgt_batch, seg_batch, None) + total_loss += loss.item() + if (i + 1) % args.report_steps == 0 and args.rank == 0: + args.logger.info("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i + 1, total_loss / args.report_steps)) + total_loss = 0.0 + output = [] + for dataset_id, path in enumerate(args.dataset_path_list): + args.labels_num = args.labels_num_list[dataset_id] + if hasattr(model, "module"): + model.module.change_dataset(dataset_id) + else: + model.change_dataset(dataset_id) + result = predict(args, read_dataset(args, os.path.join(path, "dev.tsv"))[args.rank]) + output.append(result) + output = torch.as_tensor(output).to(args.device) + output_list = [torch.zeros_like(output).to(args.device) for _ in range(args.world_size)] + dist.all_gather(output_list, output) + if args.rank == 0: + evaluate(args, output_list) + model.save_checkpoint(args.output_model_path, str(epoch)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/inference/run_classifier_deepspeed_infer.py b/inference/run_classifier_deepspeed_infer.py index fe26b8ed..8d6812f1 100644 --- a/inference/run_classifier_deepspeed_infer.py +++ b/inference/run_classifier_deepspeed_infer.py @@ -14,19 +14,56 @@ tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) sys.path.append(tencentpretrain_dir) - -from tencentpretrain.opts import deepspeed_opts +from tencentpretrain.opts import * from inference.run_classifier_infer import * +from tencentpretrain.utils.logging import * +from 
finetune.run_classifier_deepspeed import read_dataset, load_model
+
+def batch_loader(batch_size, src, seg, is_pad):
+    instances_num = src.size()[0]
+    for i in range(instances_num // batch_size):
+        src_batch = src[i * batch_size : (i + 1) * batch_size, :]
+        seg_batch = seg[i * batch_size : (i + 1) * batch_size, :]
+        is_pad_batch = is_pad[i * batch_size : (i + 1) * batch_size]
+        yield src_batch, seg_batch, is_pad_batch
+    if instances_num > instances_num // batch_size * batch_size:
+        src_batch = src[instances_num // batch_size * batch_size :, :]
+        seg_batch = seg[instances_num // batch_size * batch_size :, :]
+        is_pad_batch = is_pad[instances_num // batch_size * batch_size :]
+        yield src_batch, seg_batch, is_pad_batch
+
+def predict(args, dataset):
+    src = torch.LongTensor([sample[0] for sample in dataset])
+    seg = torch.LongTensor([sample[2] for sample in dataset])
+    is_pad = torch.LongTensor([sample[3] for sample in dataset])
+    batch_size = args.batch_size
+    args.model.eval()
+
+    result = []
+    for i, (src_batch, seg_batch, is_pad_batch) in enumerate(batch_loader(batch_size, src, seg, is_pad)):
+        src_batch = src_batch.to(args.device)
+        seg_batch = seg_batch.to(args.device)
+        is_pad_batch = is_pad_batch.to(args.device)
+        with torch.no_grad():
+            _, logits = args.model(src_batch, None, seg_batch)
+        is_pad_batch = is_pad_batch.view(is_pad_batch.shape[0], 1)
+        logits_all = torch.cat((logits, is_pad_batch), dim=-1)
+        result.append(logits_all)
+    result = torch.cat(result, dim=0)
+    return result
 
 def main():
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
     infer_opts(parser)
-
+
+    parser.add_argument("--world_size", type=int, default=1,
+                        help="Total number of processes (GPUs) for inference.")
     parser.add_argument("--labels_num", type=int, required=True,
                         help="Number of prediction labels.")
+    log_opts(parser)
 
     tokenizer_opts(parser)
 
@@ -46,59 +83,51 @@ def main():
     # Build classification model and load parameters.
     args.soft_targets, args.soft_alpha = False, False
     deepspeed.init_distributed()
-    if args.enable_zero3:
-        with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config):
-            model = Classifier(args)
-            model = _load_state_dict_into_model(model, args.load_model_path)
-    else:
-        model = Classifier(args)
-        model = load_model(model, args.load_model_path)
+    # Build classification model.
+    model = Classifier(args)
+    load_model(args, model, args.load_model_path)
+
+    # Get logger.
+ args.logger = init_logger(args) model = deepspeed.initialize(model=model,config_params=args.deepspeed_config)[0] + args.model = model - rank = dist.get_rank() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + args.rank = dist.get_rank() + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = predict(args, read_dataset(args, args.test_path)[args.rank]) + output = torch.as_tensor(dataset).to(args.device) + output_list = [torch.zeros_like(output).to(args.device) for _ in range(args.world_size)] + dist.all_gather(output_list, output) - dataset = read_dataset(args, args.test_path) - - src = torch.LongTensor([sample[0] for sample in dataset]) - seg = torch.LongTensor([sample[1] for sample in dataset]) - - batch_size = args.batch_size - instances_num = src.size()[0] - - print("The number of prediction instances: ", instances_num) - - model.eval() - - with open(args.prediction_path, mode="w", encoding="utf-8") as f: - if rank == 0: + if args.rank == 0: + with open(args.prediction_path, mode="w", encoding="utf-8") as f: f.write("label") if args.output_logits: f.write("\t" + "logits") if args.output_prob: f.write("\t" + "prob") f.write("\n") - for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)): - src_batch = src_batch.to(device) - seg_batch = seg_batch.to(device) - with torch.no_grad(): - _, logits = model(src_batch, None, seg_batch) - - pred = torch.argmax(logits, dim=1) - pred = pred.cpu().numpy().tolist() - prob = nn.Softmax(dim=1)(logits) - logits = logits.cpu().numpy().tolist() - prob = prob.cpu().numpy().tolist() - if rank == 0: + for logits_all in output_list: + logits = logits_all[:,:-1] + is_pad = logits_all[:,-1] + + pred = torch.argmax(logits, dim=1) + pred = pred.cpu().numpy().tolist() + prob = nn.Softmax(dim=1)(logits) + prob = prob.cpu().numpy().tolist() + logits = logits.cpu().numpy().tolist() + pad = is_pad.cpu().numpy().tolist() for j in range(len(pred)): + if pad[j] == 1: + continue f.write(str(pred[j])) if args.output_logits: f.write("\t" + " ".join([str(v) for v in logits[j]])) if args.output_prob: f.write("\t" + " ".join([str(v) for v in prob[j]])) f.write("\n") - + f.close() if __name__ == "__main__": main() diff --git a/inference/run_classifier_mt_deepspeed_infer.py b/inference/run_classifier_mt_deepspeed_infer.py new file mode 100644 index 00000000..1c36625b --- /dev/null +++ b/inference/run_classifier_mt_deepspeed_infer.py @@ -0,0 +1,115 @@ +""" + This script provides an example to use DeepSpeed for multi-task classification inference. 
+""" +import sys +import os +import torch +import argparse +import torch.nn as nn +import deepspeed +import torch.distributed as dist + +tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(tencentpretrain_dir) + +from tencentpretrain.opts import * +from finetune.run_classifier_deepspeed import * +from inference.run_classifier_deepspeed_infer import batch_loader +from inference.run_classifier_mt_infer import MultitaskClassifier + +def predict(args, dataset): + src = torch.LongTensor([sample[0] for sample in dataset]) + seg = torch.LongTensor([sample[2] for sample in dataset]) + is_pad = torch.LongTensor([sample[3] for sample in dataset]) + + batch_size = args.batch_size + + args.model.eval() + + result = [] + for i, (src_batch, seg_batch, is_pad_batch) in enumerate(batch_loader(batch_size, src, seg, is_pad)): + src_batch = src_batch.to(args.device) + seg_batch = seg_batch.to(args.device) + is_pad_batch = is_pad_batch.to(args.device) + with torch.no_grad(): + _, logits = args.model(src_batch, None, seg_batch) + logits = torch.stack(logits) + is_pad_batch = is_pad_batch.view(1,is_pad_batch.shape[0],1).repeat(logits.shape[0],1,1) + logits_all = torch.cat((logits, is_pad_batch), dim=-1) + result.append(logits_all) + result = torch.cat(result, dim=1) + return result + +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + infer_opts(parser) + + tokenizer_opts(parser) + parser.add_argument("--world_size", type=int, default=1, + help="Total number of processes (GPUs) for training.") + parser.add_argument("--output_logits", action="store_true", help="Write logits to output file.") + parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.") + parser.add_argument("--labels_num_list", default=[], nargs='+', type=int, help="Dataset labels num list.") + log_opts(parser) + + deepspeed_opts(parser) + + args = parser.parse_args() + + # Load the hyperparameters from the config file. + args = load_hyperparam(args) + + # Build tokenizer. + args.tokenizer = str2tokenizer[args.tokenizer](args) + + # Get logger. + args.logger = init_logger(args) + + # Build multi-task classification model and load parameters. + args.soft_targets, args.soft_alpha = False, False + deepspeed.init_distributed() + # Build multi-task classification model. 
+ model = MultitaskClassifier(args) + load_model(args, model, args.load_model_path) + model = deepspeed.initialize(model=model,config_params=args.deepspeed_config)[0] + args.model = model + + args.rank = dist.get_rank() + args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = predict(args, read_dataset(args, args.test_path)[args.rank]) + output = torch.as_tensor(dataset).to(args.device) + output_list = [torch.zeros_like(output).to(args.device) for _ in range(args.world_size)] + dist.all_gather(output_list, output) + + if args.rank == 0: + with open(args.prediction_path, mode="w", encoding="utf-8") as f: + f.write("label") + if args.output_logits: + f.write("\t" + "logits") + if args.output_prob: + f.write("\t" + "prob") + f.write("\n") + for logits_all in output_list: + logits = logits_all[:,:,:-1] + is_pad = logits_all[0,:,-1] + + pred = [torch.argmax(logits_i, dim=-1) for logits_i in logits] + prob = [nn.Softmax(dim=-1)(logits_i) for logits_i in logits] + logits = [x.cpu().numpy().tolist() for x in logits] + pred = [x.cpu().numpy().tolist() for x in pred] + prob = [x.cpu().numpy().tolist() for x in prob] + pad = [x.cpu().numpy().tolist() for x in is_pad] + for j in range(len(pred[0])): + if pad[j] == 1: + continue + f.write("|".join([str(v[j]) for v in pred])) + if args.output_logits: + f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in logits])) + if args.output_prob: + f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in prob])) + f.write("\n") + f.close() + +if __name__ == "__main__": + main()
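
Note on the sharding scheme this patch introduces: read_dataset() now splits instances evenly across ranks and pads short shards with duplicated samples whose fourth field (is_pad) is 1, so every rank feeds identically shaped tensors into dist.all_gather(), and evaluate() simply skips the flagged duplicates. Below is a minimal standalone sketch of that round trip; the helper names shard_with_padding and drop_padding are illustrative, not part of the patch (the patch itself pads with rank 0's last sample).

import math

def shard_with_padding(instances, world_size):
    # Split instances across ranks, tagging real samples with is_pad=0.
    per_rank = math.ceil(len(instances) / world_size)
    shards = [[] for _ in range(world_size)]
    for i, sample in enumerate(instances):
        shards[i // per_rank].append(sample + (0,))
    # Pad short shards with a repeated sample flagged is_pad=1 so that
    # every rank holds exactly per_rank entries.
    for shard in shards:
        while len(shard) < per_rank:
            shard.append(instances[-1] + (1,))
    return shards

def drop_padding(results):
    # Keep only gathered results whose trailing is_pad flag is 0.
    return [r for r in results if r[-1] == 0]

# Example: 5 samples over 2 ranks -> two shards of 3, one padded entry.
shards = shard_with_padding([("sample%d" % i,) for i in range(5)], world_size=2)
assert [len(s) for s in shards] == [3, 3]
assert shards[1][-1][-1] == 1  # the duplicate carries is_pad=1
assert len(drop_padding(shards[0] + shards[1])) == 5

This is the same flag that predict() carries through to the "is_pad == 1" checks in evaluate() and in the inference writers.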
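The inference scripts move predictions across ranks with a single dist.all_gather by concatenating the is_pad flag onto the logits as one extra trailing column (the logits_all = torch.cat((logits, is_pad_batch), dim=-1) line above); rank 0 then slices that column back off and drops the padded rows before writing predictions. A small sketch of the pack/unpack step, with made-up shapes:

import torch

# Pack: append the per-sample pad flag as a trailing column of the logits
# ([batch, labels_num] -> [batch, labels_num + 1]); torch.cat promotes the
# integer flags to the logits' floating dtype.
logits = torch.randn(4, 3)
is_pad = torch.tensor([0, 0, 1, 0])
packed = torch.cat((logits, is_pad.view(-1, 1)), dim=-1)

# Unpack (on rank 0, after all_gather): split the flag column off and
# keep only the rows that were not padding.
gathered_logits, flags = packed[:, :-1], packed[:, -1]
kept = gathered_logits[flags == 0]
assert kept.shape == (3, 3)

The multi-task variant applies the same trick per task after torch.stack(logits), which implicitly assumes every task head shares the same labels_num; tasks with differing label counts would need their logits padded to a common width before stacking.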