From 696e6b06708672639d66285bad6d88e61c48300d Mon Sep 17 00:00:00 2001 From: "Ethan, Wenjun Hou" Date: Fri, 22 Dec 2023 17:08:15 +0800 Subject: [PATCH] Add Code of Stage 2 --- src_stage2/__init__.py | 0 src_stage2/chexbert_eval.py | 325 +++++ src_stage2/data_arguments.py | 98 ++ src_stage2/data_collator_ende.py | 213 +++ src_stage2/data_process_ende.py | 254 ++++ src_stage2/dataset_ende.py | 440 ++++++ src_stage2/metrics.py | 39 + src_stage2/model_arguments.py | 84 ++ src_stage2/models/activations.py | 184 +++ src_stage2/models/modeling_bart.py | 1575 +++++++++++++++++++++ src_stage2/models/rgcn.py | 70 + src_stage2/optimizer.py | 65 + src_stage2/run_ende.py | 364 +++++ src_stage2/seq2seqtrainer_metrics_ende.py | 99 ++ src_stage2/tokenizer.py | 255 ++++ src_stage2/train_eval_ende_full.py | 275 ++++ 16 files changed, 4340 insertions(+) create mode 100644 src_stage2/__init__.py create mode 100644 src_stage2/chexbert_eval.py create mode 100644 src_stage2/data_arguments.py create mode 100644 src_stage2/data_collator_ende.py create mode 100644 src_stage2/data_process_ende.py create mode 100644 src_stage2/dataset_ende.py create mode 100644 src_stage2/metrics.py create mode 100644 src_stage2/model_arguments.py create mode 100644 src_stage2/models/activations.py create mode 100644 src_stage2/models/modeling_bart.py create mode 100644 src_stage2/models/rgcn.py create mode 100644 src_stage2/optimizer.py create mode 100644 src_stage2/run_ende.py create mode 100644 src_stage2/seq2seqtrainer_metrics_ende.py create mode 100644 src_stage2/tokenizer.py create mode 100644 src_stage2/train_eval_ende_full.py diff --git a/src_stage2/__init__.py b/src_stage2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src_stage2/chexbert_eval.py b/src_stage2/chexbert_eval.py new file mode 100644 index 0000000..0ae0516 --- /dev/null +++ b/src_stage2/chexbert_eval.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python +# coding=utf-8 +import torch +from collections import OrderedDict +import numpy as np +from sklearn.metrics import precision_recall_fscore_support + +from tqdm import tqdm +from collections import defaultdict + +CONDITIONS = [ + "Enlarged Cardiomediastinum", + "Cardiomegaly", + "Lung Opacity", + "Lung Lesion", + "Edema", + "Consolidation", + "Pneumonia", + "Atelectasis", + "Pneumothorax", + "Pleural Effusion", + "Pleural Other", + "Fracture", + "Support Devices", + "No Finding", +] + + +def load_chexbert(checkpoint_path): + import sys + + sys.path.append("./CheXbert/src/") + from models.bert_labeler import bert_labeler + + chexbert = bert_labeler() + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + new_state_dict = OrderedDict() + for k, v in checkpoint["model_state_dict"].items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + chexbert.load_state_dict(new_state_dict, strict=False) + print("Loaded reward model from {}".format(checkpoint_path)) + chexbert.eval() + return chexbert.cuda() + + +def compute_ce_metric( + references, hypotheses, is_temporals, chexbert, bert_tokenizer, batch_size=128 +): + def pad_strings(strs): + max_len = max([len(s) for s in strs]) + return [s + " " * (max_len - len(s)) for s in strs] + + chexbert.eval() + CLASS_MAPPING = {0: "Blank", 1: "Positive", 2: "Negative", 3: "Positive"} + NO_FINDING_CLASS_MAPPING = {0: "Negative", 1: "Positive"} + LABEL_MAPPING = {0: 0, 1: 1, 2: 2, 3: 1} + TEM_keywords = { + "bigger", + "change", + "cleared", + "constant", + "decrease", + "decreased", + "decreasing", + "elevated", + "elevation", + 
"enlarged", + "enlargement", + "enlarging", + "expanded", + "greater", + "growing", + "improved", + "improvement", + "improving", + "increase", + "increased", + "increasing", + "larger", + "new", + "persistence", + "persistent", + "persisting", + "progression", + "progressive", + "reduced", + "removal", + "resolution", + "resolved", + "resolving", + "smaller", + "stability", + "stable", + "stably", + "unchanged", + "unfolded", + "worse", + "worsen", + "worsened", + "worsening", + "unaltered", + } + ref_observations = [] + hyp_observations = [] + y_preds = [] + y_trues = [] + macro_y_preds = [] + macro_y_trues = [] + for i in tqdm(range(0, len(references), batch_size), desc="Calculating CE Scores"): + ref = [r.replace(" .", ".") for r in references[i : i + batch_size]] + hyp = [h.replace(" .", ".") for h in hypotheses[i : i + batch_size]] + ref_input = bert_tokenizer.batch_encode_plus( + ref, return_tensors="pt", padding=True, truncation=True, max_length=512 + ) + hyp_input = bert_tokenizer.batch_encode_plus( + hyp, return_tensors="pt", padding=True, truncation=True, max_length=512 + ) + ref_input = {k: v.cuda() for k, v in ref_input.items()} + hyp_input = {k: v.cuda() for k, v in hyp_input.items()} + ref_logits = chexbert( + source_padded=ref_input["input_ids"], + attention_mask=ref_input["attention_mask"], + ) + hyp_logits = chexbert( + source_padded=hyp_input["input_ids"], + attention_mask=hyp_input["attention_mask"], + ) + ref_status = [l.argmax(dim=1).tolist() for l in ref_logits] + hyp_status = [l.argmax(dim=1).tolist() for l in hyp_logits] + y_pred = np.zeros((len(ref_status[0]), len(CONDITIONS))) + y_true = np.zeros((len(hyp_status[0]), len(CONDITIONS))) + macro_y_pred = np.zeros((len(ref_status[0]), len(CONDITIONS))) + macro_y_true = np.zeros((len(hyp_status[0]), len(CONDITIONS))) + ref_obs = [[] for _ in range(len(ref_status[0]))] + hyp_obs = [[] for _ in range(len(hyp_status[0]))] + for i, c in enumerate(CONDITIONS): + i_ref_status = ref_status[i] + i_hyp_status = hyp_status[i] + if c == "No Finding": + class_mapping = NO_FINDING_CLASS_MAPPING + else: + class_mapping = CLASS_MAPPING + for j in range(len(i_hyp_status)): # batch_size + macro_y_pred[j][i] = i_hyp_status[j] + macro_y_true[j][i] = i_ref_status[j] + if LABEL_MAPPING[i_hyp_status[j]] == 1: + y_pred[j][i] = 1 + if LABEL_MAPPING[i_ref_status[j]] == 1: + y_true[j][i] = 1 + if i_hyp_status[j] != 0 or c == "No Finding": + hyp_obs[j].append(":".join((c, class_mapping[i_hyp_status[j]]))) + if i_ref_status[j] != 0 or c == "No Finding": + ref_obs[j].append(":".join((c, class_mapping[i_ref_status[j]]))) + + y_preds.append(y_pred) + y_trues.append(y_true) + macro_y_preds.append(macro_y_pred) + macro_y_trues.append(macro_y_true) + ref_observations.extend(ref_obs) + hyp_observations.extend(hyp_obs) + y_preds = np.concatenate(y_preds, axis=0) + y_trues = np.concatenate(y_trues, axis=0) + macro_y_preds = np.concatenate(macro_y_preds, axis=0) + macro_y_trues = np.concatenate(macro_y_trues, axis=0) + ce_prf = [0, 0, 0] + macro_ce_prf = [0, 0, 0] + temporal_ce_prf = [0, 0, 0] + macro_temporal_ce_prf = [0, 0, 0] + + print("--------------------------------------------------------------") + pad_conditions = pad_strings(CONDITIONS) + for i, c in enumerate(CONDITIONS): + # for all reports + y_true = y_trues[:, i] + y_pred = y_preds[:, i] + i_prf = precision_recall_fscore_support( + y_true=y_true, y_pred=y_pred, average="binary", pos_label=1 + ) + ce_prf = [ce_prf[j] + i_prf[j] for j in range(3)] + + print( + "%s\tPrec. %0.4f\tRec. 
%0.4f\tF1 %0.4f" + % (pad_conditions[i], i_prf[0], i_prf[1], i_prf[2]) + ) + + y_true = macro_y_trues[:, i] + y_pred = macro_y_preds[:, i] + i_prf = precision_recall_fscore_support( + y_true=y_true, y_pred=y_pred, average="macro" + ) + macro_ce_prf = [macro_ce_prf[j] + i_prf[j] for j in range(3)] + + # for reports with temporal information + y_true = [z for z, k in zip(y_trues[:, i], is_temporals) if k] + y_pred = [z for z, k in zip(y_preds[:, i], is_temporals) if k] + i_prf = precision_recall_fscore_support( + y_true=y_true, y_pred=y_pred, average="binary", pos_label=1 + ) + temporal_ce_prf = [temporal_ce_prf[j] + i_prf[j] for j in range(3)] + + y_true = [z for z, k in zip(macro_y_trues[:, i], is_temporals) if k] + y_pred = [z for z, k in zip(macro_y_preds[:, i], is_temporals) if k] + i_prf = precision_recall_fscore_support( + y_true=y_true, y_pred=y_pred, average="macro" + ) + macro_temporal_ce_prf = [macro_temporal_ce_prf[j] + i_prf[j] for j in range(3)] + print("--------------------------------------------------------------") + ce_prf = [ce_prf[j] / len(CONDITIONS) for j in range(3)] + macro_ce_prf = [macro_ce_prf[j] / len(CONDITIONS) for j in range(3)] + temporal_ce_prf = [temporal_ce_prf[j] / len(CONDITIONS) for j in range(3)] + macro_temporal_ce_prf = [ + macro_temporal_ce_prf[j] / len(CONDITIONS) for j in range(3) + ] + + tp = 0 + count_gen = 0 + count_ref = 0 + for ref, hyp, is_temporal in zip(references, hypotheses, is_temporals): + if not is_temporal: + continue + ref_tem = set([z for z in ref.split() if z in TEM_keywords]) + hyp_tem = set([z for z in hyp.split() if z in TEM_keywords]) + tp += len(ref_tem & hyp_tem) + count_gen += len(hyp_tem) + count_ref += len(ref_tem) + tem_prec = tp / max(count_gen, 1) + tem_rec = tp / max(count_ref, 1) + tem_f1 = 2 * tem_prec * tem_rec / max((tem_prec + tem_rec), 0.1) + tem_score = [tem_prec, tem_rec, tem_f1] + return ( + ref_observations, + hyp_observations, + ce_prf, + temporal_ce_prf, + macro_ce_prf, + macro_temporal_ce_prf, + tem_score, + ) + + +def build_progression_graph( + progression_triples, + observations, + topk_entity=5, + tokenizer=None, +): + print("******************************************") + print("******Constructing Progression Graph******") + print("******Constructing Progression Graph******") + print("******Constructing Progression Graph******") + print("TopK:", topk_entity) + print("******************************************") + print("******************************************") + relation2id = { + "Better": 0, + "Worse": 1, + "No status change": 2, + "S2O": 3, + "O2O": 4, + } + id2relation = {v: k for k, v in relation2id.items()} + entity = set() + + for head in progression_triples: + entity.update(progression_triples[head][:topk_entity]) + + entity2subid = {} + for e in entity: + e_wo_head = e.split("-")[-1] + if e_wo_head not in tokenizer.token2idx: + continue + entity2subid[e] = tokenizer.token2idx[e_wo_head] + + observations = [o + ":Positive" for o in observations] + [ + o + ":Negative" for o in observations + ] + + entity = ( + ["pre_" + obs for obs in observations] + + observations + + list({e for e in entity if e in entity2subid}) + ) + entity = sorted(entity) + id2entity = {i: e for i, e in enumerate(entity)} + entity2id = {e: i for i, e in enumerate(entity)} + triples = defaultdict(list) + # prior observation->current observation + for obs in observations: + obs_, _ = obs.split(":") + triples[(entity2id["pre_" + obs], relation2id["O2O"])].append( + entity2id[obs_ + ":Positive"] + ) + 
triples[(entity2id["pre_" + obs], relation2id["O2O"])].append( + entity2id[obs_ + ":Negative"] + ) + for head in progression_triples: + tails = progression_triples[head][:topk_entity] + head = head.split("_") + if len(head) == 1: + head.append("S2O") + # current observation->entity + triples[(entity2id[head[0]], relation2id[head[1]])].extend( + [entity2id[tail] for tail in tails] + ) + # entity->prior observation + for tail in tails: + triples[(entity2id[tail], relation2id[head[1]])].append( + entity2id["pre_" + head[0]] + ) + + triples = {k: list(set(v)) for k, v in triples.items()} + print("******************************************") + print("******************************************") + print("***********Num of Entity: %d*************" % len(entity2id)) + print("******************************************") + print("******************************************") + return { + "triples": triples, + "entity2id": entity2id, + "id2entity": id2entity, + "relation2id": relation2id, + "id2relation": id2relation, + "entity2subid": entity2subid, + } diff --git a/src_stage2/data_arguments.py b/src_stage2/data_arguments.py new file mode 100644 index 0000000..9463941 --- /dev/null +++ b/src_stage2/data_arguments.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." + }, + ) + image_path: Optional[str] = field( + default=None, + metadata={ + "help": "The text model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + annotation_file: Optional[str] = field( + default=None, + metadata={ + "help": "The text model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + miss_annotation_file: Optional[str] = field( + default=None, + metadata={ + "help": "The text model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + progression_graph: Optional[str] = field( + default=None, + ) + history: Optional[str] = field( + default=None, + metadata={ + "help": "The text model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + chexbert_label: Optional[str] = field(default=None) + debug_model: Optional[bool] = field(default=False) + max_tgt_length: Optional[int] = field( + default=64, + ) + eval_on_gen: Optional[bool] = field(default=False) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + + block_size: Optional[int] = field( + default=None, + metadata={ + "help": "Optional input sequence length after tokenization. " + "The training dataset will be truncated in block of this size for training. 
" + "Default to the model max input length for single sentence inputs (take into account special tokens)." + }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}, + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + keep_linebreaks: bool = field( + default=True, + metadata={"help": "Whether to keep line breaks when using TXT files or not."}, + ) \ No newline at end of file diff --git a/src_stage2/data_collator_ende.py b/src_stage2/data_collator_ende.py new file mode 100644 index 0000000..e9bd665 --- /dev/null +++ b/src_stage2/data_collator_ende.py @@ -0,0 +1,213 @@ +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +from transformers import DataCollatorForSeq2Seq + +from transformers.file_utils import PaddingStrategy +from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase + + +@dataclass +class DataCollatorForEnDe(DataCollatorForSeq2Seq): + tokenizer: PreTrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + import numpy as np + + if return_tensors is None: + return_tensors = self.return_tensors + labels = ( + [feature["labels"] for feature in features] + if "labels" in features[0].keys() + else None + ) + input_ids = ( + [feature["input_ids"] for feature in features] + if "input_ids" in features[0].keys() + else None + ) + progression_input_ids = ( + [feature["progression_input_ids"] for feature in features] + if "progression_input_ids" in features[0].keys() + else None + ) + matrix = ( + [feature["matrix"] for feature in features] + if "matrix" in features[0].keys() + else None + ) + report_ids = ( + [feature["report_ids"] for feature in features] + if "report_ids" in features[0].keys() + else None + ) + is_temporal = ( + [feature["is_temporal"] for feature in features] + if "is_temporal" in features[0].keys() + else None + ) + observations = ( + [feature["observations"] for feature in features] + if "observations" in features[0].keys() + else None + ) + prior_observations = ( + [feature["prior_observations"] for feature in features] + if "prior_observations" in features[0].keys() + else None + ) + prior_entity_ids = ( + [feature["prior_entity_ids"] for feature in features] + if "prior_entity_ids" in features[0].keys() + else None + ) + temporal_image_paths = ( + [feature["temporal_image_path"] for feature in features] + if "temporal_image_path" in features[0].keys() + else None + ) + progressions = ( + [feature["progressions"] for feature in features] + if "progressions" in features[0].keys() + else None + ) + input_pixels = ( + [feature["input_pixels"] for feature in features] + if "input_pixels" in features[0].keys() + else None + ) + input_temporal_pixels = ( + [feature["input_temporal_pixels"] for feature in features] + if "input_temporal_pixels" in features[0].keys() + else None + ) + batch_outputs = {} + + if labels is not None: + batch_outputs["labels"] = [] + batch_outputs["gate_labels"] = [] + 
max_label_length = max(len(l) for l in labels) + if self.pad_to_multiple_of is not None: + max_label_length = ( + (max_label_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + for feature in features: + remainder = [self.label_pad_token_id] * ( + max_label_length - len(feature["labels"]) + ) + feature["labels"] = feature["labels"] + remainder + feature["gate_labels"] = feature["gate_labels"] + remainder + batch_outputs["labels"].append(feature["labels"]) + batch_outputs["gate_labels"].append(feature["gate_labels"]) + + if input_ids is not None: + batch_outputs["input_ids"] = [] + max_length = max(len(l) for l in input_ids) + if self.pad_to_multiple_of is not None: + max_length = ( + (max_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + for feature in features: + remainder = [self.tokenizer.pad_token_id] * ( + max_length - len(feature["input_ids"]) + ) + feature["input_ids"] = feature["input_ids"] + remainder + batch_outputs["input_ids"].append(feature["input_ids"]) + + if progression_input_ids is not None: + batch_outputs["progression_input_ids"] = [] + max_length = max(len(l) for l in progression_input_ids) + for feature in features: + remainder = [self.tokenizer.pad_token_id] * ( + max_length - len(feature["progression_input_ids"]) + ) + feature["progression_input_ids"] = ( + feature["progression_input_ids"] + remainder + ) + batch_outputs["progression_input_ids"].append( + feature["progression_input_ids"] + ) + + if observations is not None: + batch_outputs["observations"] = [] + for feature in features: + batch_outputs["observations"].append(feature["observations"]) + + if matrix is not None: + batch_outputs["node_mask"] = [] + batch_outputs["nodes"] = [] + batch_outputs["gather_index"] = [] + max_length = max(m.shape[-1] for m in matrix) + for feature in features: + batch_outputs["gather_index"].append(feature["gather_index"]) + + for i, m in enumerate(matrix): + feature = features[i] + diff = max_length - m.shape[-1] + m = np.pad( + m, + ((0, 0), (0, diff), (0, diff)), + mode="constant", + constant_values=0, + ) + feature["node_mask"] = feature["node_mask"] + [0] * diff + feature["nodes"] = feature["nodes"] + [-100] * diff + matrix[i] = m + batch_outputs["node_mask"].append(feature["node_mask"]) + batch_outputs["nodes"].append(feature["nodes"]) + + if progressions is not None: + batch_outputs["progressions"] = progressions + features = BatchEncoding(batch_outputs, tensor_type=return_tensors) + features["input_pixels"] = torch.cat(input_pixels, dim=0) + features["temporal_mask"] = torch.zeros((len(temporal_image_paths))) + for i, image_path in enumerate(temporal_image_paths): + if len(image_path) > 0: + features["temporal_mask"][i] = 1 + features["input_temporal_pixels"] = torch.cat( + input_temporal_pixels, dim=0) + features["matrix"] = torch.from_numpy(np.stack(matrix, axis=0)).float() + features["attention_mask"] = torch.ones_like(features["input_ids"]).masked_fill( + features["input_ids"] == self.tokenizer.pad_token_id, 0 + ) + features["progression_attention_mask"] = torch.ones_like( + features["progression_input_ids"] + ).masked_fill( + features["progression_input_ids"] == self.tokenizer.pad_token_id, 0 + ) + + if report_ids is not None: + features["report_ids"] = report_ids + features["is_temporal"] = is_temporal + features["prior_observations"] = prior_observations + features["prior_entity_ids"] = prior_entity_ids + + if ( + labels is not None + and self.model is not None + 
and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( + labels=features["labels"] + ) + features["decoder_input_ids"] = decoder_input_ids + return features + + def pad_sequence(self, seqs, padding_idx, max_len): + new_seqs = [] + for seq in seqs: + seq_len = len(seq) + diff = max_len - seq_len + new_seqs.append(seq + [padding_idx] * diff) + return new_seqs diff --git a/src_stage2/data_process_ende.py b/src_stage2/data_process_ende.py new file mode 100644 index 0000000..d3376ca --- /dev/null +++ b/src_stage2/data_process_ende.py @@ -0,0 +1,254 @@ +import re + +from tqdm import tqdm + +from tokenizer import Tokenizer + + +def clean_report_iu_xray(report): + report_cleaner = ( + lambda t: t.replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("1. ", "") + .replace(". 2. ", ". ") + .replace(". 3. ", ". ") + .replace(". 4. ", ". ") + .replace(". 5. ", ". ") + .replace(" 2. ", ". ") + .replace(" 3. ", ". ") + .replace(" 4. ", ". ") + .replace(" 5. ", ". ") + .strip() + .lower() + .split(". ") + ) + + def sent_cleaner(t): + return re.sub( + "[.,?;*!%^&_+():-\[\]{}]", + "", + t.replace('"', "") + .replace("/", "") + .replace("\\", "") + .replace("'", "") + .strip() + .lower(), + ) + + tokens = [ + sent_cleaner(sent) + for sent in report_cleaner(report) + if sent_cleaner(sent) != [] + ] + report = " . ".join(tokens) + " ." + return report + + +def clean_report_mimic_cxr(report): + report_cleaner = ( + lambda t: t.replace("\n", " ") + .replace("__", "_") + .replace("__", "_") + .replace("__", "_") + .replace("__", "_") + .replace("__", "_") + .replace("__", "_") + .replace("__", "_") + .replace(" ", " ") + .replace(" ", " ") + .replace(" ", " ") + .replace(" ", " ") + .replace(" ", " ") + .replace(" ", " ") + .replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("..", ".") + .replace("1. ", "") + .replace(". 2. ", ". ") + .replace(". 3. ", ". ") + .replace(". 4. ", ". ") + .replace(". 5. ", ". ") + .replace(" 2. ", ". ") + .replace(" 3. ", ". ") + .replace(" 4. ", ". ") + .replace(" 5. ", ". ") + .strip() + .lower() + .split(". ") + ) + + def sent_cleaner(t): + return re.sub( + "[.,?;*!%^&_+():-\[\]{}]", + "", + t.replace('"', "") + .replace("/", "") + .replace("\\", "") + .replace("'", "") + .strip() + .lower(), + ) + + tokens = [ + sent_cleaner(sent) + for sent in report_cleaner(report) + if sent_cleaner(sent) != [] + ] + report = " . ".join(tokens) + " ." 
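+ # illustrative example (not from the source data) of the cleaning above:
+ #   "1. No acute cardiopulmonary process.\n" -> "no acute cardiopulmonary process ."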
+ return report + + +def load_exemplar(annotation, text_tokenizer: Tokenizer, max_tgt_length): + id2exemplar = {} + progress = tqdm(annotation, desc="Loading Exemplars") + for sample in progress: + id2exemplar[sample["id"]] = [] + return id2exemplar + + +def jaccard_distance(a, b): + return len(a & b) / len(a | b) + + +def insert_plan(report, tokenizer, splan, gplan, observation): + report = tokenizer.clean_report(report) + sentences = [s for s in report.split(".") if len(s.strip()) > 0] + if len(sentences) == 0: + return [] + no_finding = observation[-1] + observation = observation[:-1] + positions = sorted(splan.keys(), key=lambda x: int(x)) + planed_observation = set() + observation_category = {o.split(":")[0] for o in observation} + for pos in positions: + planed_obs = splan[pos]["observation"] + clean_planed_obs = [] + for o in planed_obs: + c = o.split(":")[0] + if c not in observation_category: + continue + if o not in observation: + for new_o in observation: + if c in new_o: + o = new_o + clean_planed_obs.append(o) + break + else: + clean_planed_obs.append(o) + splan[pos]["observation"] = clean_planed_obs + planed_observation.update(splan[pos]["observation"]) + left_observation = [o for o in observation if o not in planed_observation] + + tokens = [] + existed_plan = set() + for pos in positions: + sentence_plan = ["[{}]".format(p) for p in splan[pos]["observation"]] + sentence_plan = [o for o in sentence_plan if o not in existed_plan] + + if len(sentence_plan) > 1 and gplan is not None: + sentence_plan = sorted( + sentence_plan, key=lambda x: gplan[x[1:-1].split(":")[0]] + ) + tokens.extend(sentence_plan) + tokens.extend(splan[pos]["sentence"].strip().split()) + + if len(left_observation) > 1: + left_observation = sorted( + left_observation, key=lambda x: gplan[x.split(":")[0]] + ) + left_observation = [no_finding] + left_observation + left_tokens = ["[{}]".format(o) for o in left_observation] + tokens = left_tokens + tokens + ids = [] + for token in tokens: + if token == " ": + continue + ids.append(tokenizer.get_id_by_token(token)) + ids = ids + [tokenizer.eos_token_id] + return ids + + +def construct_obs_aware_token(sentences, entity2id): + obs_aware_tokens = [] + for sentence_id in sentences: + sentence = sentences[sentence_id] + observations = sentence["observation"] + sentence = sentence["sentence"] + tokens = sentence.split() + for token in tokens: + if token == " ": + continue + obs_aware_token_ids = set() + if len(observations) == 0: + observations = ["NONE"] + for obs in observations: + obs_aware_token = obs + "-" + token + if obs_aware_token in entity2id: + obs_aware_token_ids.add(obs_aware_token) + obs_aware_tokens.append(obs_aware_token_ids) + return obs_aware_tokens + + +def process_examples( + examples, + max_tgt_length, + tokenizer, +): + progress = tqdm( + range(len(examples["id"])), + desc="Processing Samples", + ) + labels = [] + idxs = [] + image_paths = [] + temporal_image_paths = [] + temporal_reports = [] + temporal_entity_ids = [] + current_entity_ids = [] + temporal_predicates = [] + for index in progress: + report_id = examples["id"][index] + image_path = examples["image_path"][index] + report = tokenizer.encode(examples["report"][index]) + label = report[1:] + if len(label) > max_tgt_length: + label = label[: max_tgt_length - 1] + label[-1:] + temporal_image_path = examples["temporal_image_path"][index] + temporal_report = examples["temporal_report"][index] + if temporal_report is None: + temporal_report = "" + if len(temporal_report) == 0: + 
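+ # no prior report available (e.g. a first visit): keep an empty id list
+ # rather than encoding the empty string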
temporal_report = [] + else: + temporal_report = tokenizer.encode(temporal_report) + + if len(temporal_report) > max_tgt_length: + temporal_report = ( + temporal_report[: max_tgt_length - 1] + temporal_report[-1:] + ) + temporal_predicate = examples["temporal_predicate"][index] + labels.append(label) + idxs.append(report_id) + image_paths.append(image_path) + + temporal_image_paths.append(temporal_image_path) + temporal_reports.append(temporal_report) + temporal_predicates.append(temporal_predicate) + temporal_entity_ids.append(examples["temporal_entity"][index]) + current_entity_ids.append(examples["current_entity"][index]) + return ( + idxs, + image_paths, + temporal_image_paths, + temporal_entity_ids, + current_entity_ids, + temporal_predicates, + temporal_reports, + labels, + ) diff --git a/src_stage2/dataset_ende.py b/src_stage2/dataset_ende.py new file mode 100644 index 0000000..9009f61 --- /dev/null +++ b/src_stage2/dataset_ende.py @@ -0,0 +1,440 @@ +import json +import os +from collections import defaultdict + +import numpy as np +from torch.utils.data import Dataset +import torch +from data_arguments import DataTrainingArguments +from data_process_ende import process_examples +from tqdm import tqdm +from PIL import Image +import random + + +def load_images(root_path, image_paths): + images = {} + for image_path in tqdm(image_paths, desc="Loading Images"): + for img_path in image_path: + img_path = os.path.join(root_path, img_path) + image = Image.open(img_path).convert("RGB") + images[img_path] = image + return images + + +def extract_temporal_info( + samples, + ref_samples, + temporal_ids, + entity2id=None, + entity_label=None, +): + id2sample = {sample["id"]: sample for sample in samples} + if ref_samples is not None: + ref_id2sample = {sample["id"]: sample for sample in ref_samples} + for subject_id in temporal_ids: + object_id = temporal_ids[subject_id]["object_id"] + if object_id not in id2sample: + id2sample[object_id] = ref_id2sample[object_id] + + for sample in samples: + sample["temporal_image_path"] = [] + sample["temporal_entity"] = set() + sample["current_entity"] = set() + sample["temporal_predicate"] = [] + sample["temporal_id"] = None + sample["temporal_report"] = "" + + for subject_id in tqdm(temporal_ids, desc="Updating Temooral Info"): + predicate_object = temporal_ids[subject_id] + predicate = predicate_object["predicate"] + subject_example = id2sample[subject_id] + + object_id = predicate_object["object_id"] + if object_id not in id2sample: + print(object_id, "Not Found") + else: + object_example = id2sample[object_id] + subject_example["temporal_image_path"] = object_example["image_path"] + subject_example["temporal_report"] = object_example["report"] + if object_id in entity_label: + for e in entity_label[object_id]: + subject_example["temporal_entity"].add(e) + subject_example["temporal_predicate"] = predicate + return samples + + +class DatasetCustom(Dataset): + def __init__( + self, + data_args: DataTrainingArguments, + annotation, + ref_annotation, + temporal_ids, + split: str, + text_tokenizer, + tokenizer, + id2tags, + processor, + progression_graph, + observation_category, + transform=None, + keep_columns={ + "id", + "report", + "image_path", + "temporal_image_path", + "temporal_entity", + "current_entity", + "temporal_predicate", + "temporal_report", + }, + ) -> None: + super().__init__() + self.text_tokenizer = text_tokenizer + self.tokenizer = tokenizer + self.processor = processor + self.data_args = data_args + self.split = split + self.dataset = 
data_args.dataset + self.id2tags = id2tags + examples = {kc: [] for kc in keep_columns} + samples = annotation[split.replace("valid", "val")] + self.temporal_ids = temporal_ids[split.replace("valid", "val")] + ref_samples = None + if ref_annotation is not None: + ref_samples = ref_annotation[split.replace("valid", "val")] + self.temporal_collection = temporal_ids.keys() + ( + self.triples, + self.entity2id, + self.id2entity, + self.relation2id, + self.id2relation, + self.entity2subid, + ) = ( + progression_graph["triples"], + progression_graph["entity2id"], + progression_graph["id2entity"], + progression_graph["relation2id"], + progression_graph["id2relation"], + progression_graph["entity2subid"], + ) + with open( + f"./data/{data_args.graph_version}/%s/id2entity.json" % data_args.dataset, + "r", + encoding="utf-8", + ) as f: + self.id2entity_label = json.load(f) + + samples = extract_temporal_info( + samples, + ref_samples, + self.temporal_ids, + self.entity2id, + self.id2entity_label, + ) + for sample in samples: + for key in sample: + if key not in examples: + continue + examples[key].append(sample[key]) + for key in examples: + print(key, examples[key][:3]) + ( + idxs, + image_paths, + temporal_image_paths, + temporal_entity_ids, + current_entity_ids, + temporal_predicates, + temporal_reports, + labels, + ) = process_examples( + examples=examples, + max_tgt_length=data_args.max_tgt_length, + tokenizer=tokenizer, + ) + self.data = [ + { + "id": a, + "image_path": b, + "temporal_image_path": c, + "temporal_entity_ids": d, + "current_entity_ids": e, + "temporal_predicates": f, + "temporal_report": g, + "labels": h, + } + for a, b, c, d, e, f, g, h in zip( + idxs, + image_paths, + temporal_image_paths, + temporal_entity_ids, + current_entity_ids, + temporal_predicates, + temporal_reports, + labels, + ) + ] + self.all_index = list(range(len(self.data))) + + self.observation2id = {obs: idx for idx, obs in enumerate(observation_category)} + self.observation_category = observation_category + self.transform = transform + self.tokenizer = tokenizer + self.triples_ = defaultdict(list) + for (hid, rid), tids in self.triples.items(): + for i, tid in enumerate(tids): + self.triples_[(hid, tid)].append(rid) + + if self.split != "train": + path = data_args.stage1_model_name_or_path + if self.split == "test": + path = path + "results.json" + else: + path = path + data_args.stage1_eval_file + self.op_data = json.load(open(path, "r", encoding="utf-8")) + + def __getitem__(self, index): + idx = self.data[index]["id"] + labels = self.data[index]["labels"] + status2id = {"Better": 0, "Worse": 1, "No status change": 2} + progressions = [0, 0, 0] + for progression in self.data[index]["temporal_predicates"]: + staid = status2id[progression] + progressions[staid] = 1 + + # current radiograph + image_path = [ + os.path.join(self.data_args.image_path, a) + for a in self.data[index]["image_path"] + ] + pixel_value = [] + for img_path in image_path: + image = Image.open(img_path).convert("RGB") + if self.transform is not None: + image = self.transform(image) + image = self.processor(images=image, return_tensors="pt")["pixel_values"] + pixel_value.append(image) + pixel_value = torch.cat(pixel_value, dim=0) + + # prior radiograph + temporal_image_path = [ + os.path.join(self.data_args.image_path, a) + for a in self.data[index]["temporal_image_path"] + ] + pixel_value_temporal = torch.zeros_like(pixel_value) + for i, img_path in enumerate(temporal_image_path): + image = Image.open(img_path).convert("RGB") + if 
self.transform is not None: + image = self.transform(image) + image = self.processor(images=image, return_tensors="pt")["pixel_values"][0] + pixel_value_temporal[i] = image + + # load current observations + id2nodelabel = {0: ":Negative", 1: ":Positive"} + if self.split == "train": + current_observations = [ + self.observation_category[pos] + id2nodelabel[tag] + for pos, tag in enumerate(self.id2tags[idx]) + if tag != 2 + ] + else: + obs = self.op_data[idx]["obs_hyp"] + current_observations = obs + current_observations = sorted( + current_observations, + key=lambda x: self.observation_category.index(x.split(":")[0]), + ) + + # load prior observations + prior_observations = [] + if ( + idx in self.temporal_ids + and self.temporal_ids[idx]["object_id"] in self.id2tags + ): + prior_observations = [ + "pre_" + self.observation_category[pos] + id2nodelabel[tag] + for pos, tag in enumerate( + self.id2tags[self.temporal_ids[idx]["object_id"]] + ) + if tag != 2 + ] + observation_prompt_ids = [] + report_prompt_ids = [] + + if self.split == "train": + observation_prompt_ids = [ + self.tokenizer.token2idx[ + "[{}]".format(self.observation_category[pos] + id2nodelabel[tag]) + ] + for pos, tag in enumerate(self.id2tags[idx]) + if tag != 2 + ] + else: + observation_prompt_ids = [ + self.tokenizer.token2idx["[{}]".format(o)] + for o in sorted( + self.op_data[idx]["obs_hyp"], + key=lambda x: self.observation_category.index(x.split(":")[0]), + ) + ] + + # load prior report + if len(temporal_image_path) > 0: + report_prompt_ids = self.data[index]["temporal_report"] + + # insert [FiV] or [FoV] to distinguish first visits and follow-up visits + f_v = self.tokenizer.token2idx[ + "[First-Visit]" if len(temporal_image_path) == 0 else "[Follow-Up-Visit]" + ] + observation_prompt_ids = [f_v] + observation_prompt_ids + input_ids = observation_prompt_ids + size = len(input_ids) + progression_input_ids = report_prompt_ids + + if size == 0: + input_ids = [self.tokenizer.pad_token_id] + if len(progression_input_ids) == 0: + progression_input_ids = [self.tokenizer.pad_token_id] + prior_entity = self.data[index]["temporal_entity_ids"] + prior_entity_ids = [] + for e in prior_entity: + if e in self.entity2id: + prior_entity_ids.append(self.entity2id[e]) + + if self.split == "train": + progressions_ = self.data[index]["temporal_predicates"] + else: + progressions_ = [] + if "pro_hyp" in self.op_data[idx]: + progressions_ = self.op_data[idx]["pro_hyp"] + + # construct progression graph + graph_info = self.graph_construction( + prior_observations=prior_observations, + current_observations=current_observations, + prior_entity_ids=prior_entity_ids, + progressions=progressions_, + labels=labels if self.split == "train" else None, + ) + gate_labels = [0] * len(labels) + gather_index = graph_info["gather_index"] + if self.split == "train": + for lid, l in enumerate(labels[:-1]): + if l in graph_info["node_subids"].values(): + gate_labels[lid] = 1 + item = { + "image_path": image_path, + "temporal_image_path": temporal_image_path, + "input_pixels": pixel_value, + "input_temporal_pixels": pixel_value_temporal, + "labels": labels, + "input_ids": input_ids, + "progression_input_ids": progression_input_ids, + "progressions": progressions, + "split": self.split, + "observations": self.id2tags[idx], + "matrix": graph_info["matrix"], + "node_mask": graph_info["node_mask"], + "nodes": graph_info["nodes"], + "gather_index": gather_index, + "gate_labels": gate_labels, + } + if self.split != "train": + item["report_ids"] = idx + 
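+ # evaluation-only bookkeeping: the report id, a prior-study flag and the
+ # prior entities/observations are passed through the collator so that
+ # chexbert_eval can report temporal and non-temporal CE scores separately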
item["is_temporal"] = len(temporal_image_path) > 0 + item["prior_entity_ids"] = prior_entity_ids + item["prior_observations"] = prior_observations + return item + + def __len__(self): + return len(self.data) + + def graph_construction( + self, + prior_observations, + current_observations, + prior_entity_ids, + progressions, + labels=None, + ): + prior_observation_ids = { + self.entity2id[o] for o in prior_observations if o in self.entity2id + } + current_observation_ids = { + self.entity2id[o] for o in current_observations if o in self.entity2id + } + current_relation_ids = {self.relation2id[p] for p in progressions} + current_relation_ids.add(3) # S2O + candidate_entity_ids = set() + tem_ids = set() + + for (hid, rid), tids in self.triples.items(): + if hid in current_observation_ids and rid in current_relation_ids: + candidate_entity_ids.update(tids) + if rid != 3: + tem_ids.update(tids) + + nodes = sorted(prior_observation_ids.union(current_observation_ids)) + sorted( + set(prior_entity_ids).union(candidate_entity_ids) + ) + node2pos = {node: idx for idx, node in enumerate(nodes)} + matrix = np.zeros((len(self.id2relation), len(nodes), len(nodes))) + + # prior entity->prior observation + for eid in prior_entity_ids: + for oid in prior_observation_ids: + for rid in self.triples_[(eid, oid)]: + matrix[rid, node2pos[oid], node2pos[eid]] = 1 + + # prior observation->current observation + for pid in prior_observation_ids: + for cid in current_observation_ids: + for rid in self.triples_[(pid, cid)]: + matrix[rid, node2pos[cid], node2pos[pid]] = 1 + + # current observation->current observation + for cid in current_observation_ids: + for cid2 in current_observation_ids: + if cid == cid2: + continue + matrix[4, node2pos[cid], node2pos[cid2]] = 1 + matrix[4, node2pos[cid2], node2pos[cid]] = 1 + + # current observation->current entity + for oid in current_observation_ids: + for eid in candidate_entity_ids: + for rid in self.triples_[(oid, eid)]: + if rid in current_relation_ids: + matrix[rid, node2pos[eid], node2pos[oid]] = 1 + + node_subids = { + idx: self.entity2subid[self.id2entity[idx]] + for idx in candidate_entity_ids + if self.id2entity[idx] in self.entity2subid + } + gather_index = [0] * self.data_args.vocab_size + for idx, subid in node_subids.items(): + gather_index[subid] = node2pos[idx] + + node_mask = [0] * len(nodes) + + for idx in nodes: + if idx in candidate_entity_ids: + node_mask[node2pos[idx]] = 1 + + if idx in current_observation_ids: + node_mask[node2pos[idx]] = -1 + + nodes = sorted(nodes, key=lambda x: node2pos[x]) + + return { + "matrix": matrix, + "nodes": nodes, + "node_subids": node_subids, + "gather_index": gather_index, + "node_mask": node_mask, + "node2pos": node2pos, + } diff --git a/src_stage2/metrics.py b/src_stage2/metrics.py new file mode 100644 index 0000000..724ece9 --- /dev/null +++ b/src_stage2/metrics.py @@ -0,0 +1,39 @@ +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.cider.cider import Cider +from pycocoevalcap.meteor import Meteor +from pycocoevalcap.rouge import Rouge + + +def compute_scores(gts, res): + """ + Performs the MS COCO evaluation using the Python 3 implementation (https://github.com/salaniz/pycocoevalcap) + + :param gts: Dictionary with the image ids and their gold captions, + :param res: Dictionary with the image ids ant their generated captions + :print: Evaluation score (the mean of the scores of all the instances) for each measure + """ + 
+ # Set up scorers + scorers = [ + (Bleu(4), ["BLEU_1", "BLEU_2", "BLEU_3", "BLEU_4"]), + (Meteor(), "METEOR"), + (Rouge(), "ROUGE_L"), + ] + eval_res = {} + # Compute score for each metric + for scorer, method in scorers: + try: + score, scores = scorer.compute_score(gts, res, verbose=0) + # except TypeError: + except Exception: + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, m in zip(score, method): + eval_res[m] = sc + else: + eval_res[method] = score + return eval_res diff --git a/src_stage2/model_arguments.py b/src_stage2/model_arguments.py new file mode 100644 index 0000000..41769dc --- /dev/null +++ b/src_stage2/model_arguments.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass, field +from typing import Optional + +from transformers import MODEL_FOR_CAUSAL_LM_MAPPING + +MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + chexbert_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The plan model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + stage1_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The plan model checkpoint for weights initialization." + "Don't set if you want to train a model from scratch." + }, + ) + stage1_eval_file: Optional[str] = field(default=None) + model_type: Optional[str] = field( + default=None, + metadata={ + "help": "If training from scratch, pass a model type from the list: " + + ", ".join(MODEL_TYPES) + }, + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + test_model_name_or_path: Optional[str] = field(default=None) + fast_lr: float = field(default=1e-4) + num_beams: int = field(default=4) diff --git a/src_stage2/models/activations.py b/src_stage2/models/activations.py new file mode 100644 index 0000000..fd23c3d --- /dev/null +++ b/src_stage2/models/activations.py @@ -0,0 +1,184 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from packaging import version +from torch import Tensor, nn +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NewGELUActivation(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * + (input + 0.044715 * torch.pow(input, 3.0)))) + + +class GELUActivation(nn.Module): + """ + Original Implementation of the GELU activation function in Google BERT repo when initially created. For + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + def __init__(self, use_gelu_python: bool = False): + super().__init__() + if version.parse( + torch.__version__) < version.parse("1.4") or use_gelu_python: + self.act = self._gelu_python + else: + self.act = nn.functional.gelu + + def _gelu_python(self, input: Tensor) -> Tensor: + return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0))) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class FastGELUActivation(nn.Module): + """ + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs + """ + def forward(self, input: Tensor) -> Tensor: + return 0.5 * input * (1.0 + + torch.tanh(input * 0.7978845608 * + (1.0 + 0.044715 * input * input))) + + +class QuickGELUActivation(nn.Module): + """ + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs + """ + def forward(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(1.702 * input) + + +class ClippedGELUActivation(nn.Module): + """ + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to + https://arxiv.org/abs/2004.09602. + + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when + initially created. + + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). 
See https://arxiv.org/abs/1606.08415 + """ + def __init__(self, min: float, max: float): + if min > max: + raise ValueError( + f"min should be < max (got min: {min}, max: {max})") + + super().__init__() + self.min = min + self.max = max + + def forward(self, x: Tensor) -> Tensor: + return torch.clip(gelu(x), self.min, self.max) + + +class SiLUActivation(nn.Module): + """ + See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear + Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function + Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated + Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with + later. + """ + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.7"): + self.act = self._silu_python + else: + self.act = nn.functional.silu + + def _silu_python(self, input: Tensor) -> Tensor: + return input * torch.sigmoid(input) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class MishActivation(nn.Module): + """ + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also + visit the official repository for the paper: https://github.com/digantamisra98/Mish + """ + def __init__(self): + super().__init__() + if version.parse(torch.__version__) < version.parse("1.9"): + self.act = self._mish_python + else: + self.act = nn.functional.mish + + def _mish_python(self, input: Tensor) -> Tensor: + return input * torch.tanh(nn.functional.softplus(input)) + + def forward(self, input: Tensor) -> Tensor: + return self.act(input) + + +class LinearActivation(nn.Module): + """ + Applies the linear activation function, i.e. forwarding input directly to output. 
+ """ + def forward(self, input: Tensor) -> Tensor: + return input + + +ACT2FN = { + "relu": nn.ReLU(), + "silu": SiLUActivation(), + "swish": SiLUActivation(), + "gelu": GELUActivation(), + "tanh": nn.Tanh(), + "gelu_python": GELUActivation(use_gelu_python=True), + "gelu_new": NewGELUActivation(), + "gelu_fast": FastGELUActivation(), + "quick_gelu": QuickGELUActivation(), + "gelu_10": ClippedGELUActivation(-10, 10), + "mish": MishActivation(), + "linear": LinearActivation(), + "sigmoid": nn.Sigmoid(), +} + + +def get_activation(activation_string): + if activation_string in ACT2FN: + return ACT2FN[activation_string] + else: + raise KeyError( + f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}" + ) + + +# For backwards compatibility with: from activations import gelu_python +gelu_python = get_activation("gelu_python") +gelu_new = get_activation("gelu_new") +gelu = get_activation("gelu") +gelu_fast = get_activation("gelu_fast") +quick_gelu = get_activation("quick_gelu") +silu = get_activation("silu") +mish = get_activation("mish") +linear_act = get_activation("linear") diff --git a/src_stage2/models/modeling_bart.py b/src_stage2/models/modeling_bart.py new file mode 100644 index 0000000..d70680f --- /dev/null +++ b/src_stage2/models/modeling_bart.py @@ -0,0 +1,1575 @@ +from typing import Optional, Tuple, Dict, Any, Union, List +import torch +from torch import nn +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from transformers import ( + ViTConfig, + BartForCausalLM, + BartConfig, + BartPretrainedModel, +) +from transformers.modeling_outputs import ( + ModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutput, +) +from transformers.models.bart.configuration_bart import BartConfig +from transformers.models.bart.modeling_bart import ( + shift_tokens_right, + BartEncoder, + BartDecoderLayer, + BartDecoder, + BartModel, + _make_causal_mask, + _expand_mask, + BartAttention, +) +from transformers.utils import logging +from dataclasses import dataclass +import os +from transformers.file_utils import WEIGHTS_NAME +import random +from torch.nn import Embedding +from transformers.generation_utils import * + +logger = logging.get_logger(__name__) + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + gate: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ViTBartOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + pre_visual_last_hidden_state: torch.FloatTensor = None + progression_hidden_state: torch.FloatTensor = None + progression_attention_mask: torch.FloatTensor = None + observation_hidden_state: torch.FloatTensor = None + observation_attention_mask: torch.FloatTensor = None + node_hidden_state: torch.FloatTensor = None + observation_det_logits: torch.FloatTensor = None + observation_cls_logits: torch.FloatTensor = None + progression_logits: torch.FloatTensor = None + temporal_mask: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + + +@dataclass +class CausalModelOutput(ModelOutput): + encoder_loss: Optional[torch.FloatTensor] = 
None + decoder_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attention_mask: Optional[Tuple[torch.FloatTensor]] = None + encoder_visual_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_visual_attention_mask: Optional[Tuple[torch.FloatTensor]] = None + node_hidden_state: Optional[Tuple[torch.FloatTensor]] = None + pooler_output: Optional[Tuple[torch.FloatTensor]] = None + + +class BartEncoderCustom(BartEncoder): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input = input_ids + input_ids = input_ids.view(-1, input_ids.shape[-1]) + elif inputs_embeds is not None: + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # position_ids = ((attention_mask > 0).float().cumsum(dim=-1) - 1).long() + position_ids = (attention_mask.cumsum(dim=-1) - 1).long() + embed_pos = super(type(self.embed_positions), self.embed_positions).forward( + self.embed_positions.offset + position_ids + ) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + # expand attention_mask + if attention_mask is not None: + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and ( + dropout_probability < self.layerdrop + ): # skip the layer + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=( + head_mask[idx] if head_mask is not None else None + ), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class BartDecoderLayerCustom(BartDecoderLayer): + def __init__(self, config: BartConfig): + super().__init__(config) + self.encoder_visual_attn = BartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_visual_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.progression_gate = nn.Sequential( + nn.Linear(config.d_model, 1), + nn.Sigmoid(), + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + encoder_visual_hidden_states: Optional[torch.Tensor] = None, + encoder_visual_attention_mask: Optional[torch.Tensor] = None, + temporal_mask=None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. 
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + index = 2 + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + # print("observation", encoder_hidden_states.size()) + residual = hidden_states + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = ( + past_key_value[index : index + 2] + if past_key_value is not None + else None + ) + ( + hidden_states, + cross_attn_weights, + cross_attn_present_key_value, + ) = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + index += 2 + + cross_visual_attn_present_key_value = None + cross_visual_attn_weights = None + if encoder_visual_hidden_states is not None: + # print("progression", encoder_visual_hidden_states.size()) + residual = hidden_states + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_visual_attn_past_key_value = ( + past_key_value[index : index + 2] + if past_key_value is not None + else None + ) + ( + hidden_states, + cross_visual_attn_weights, + cross_visual_attn_present_key_value, + ) = self.encoder_visual_attn( + hidden_states=hidden_states, + key_value_states=encoder_visual_hidden_states, + attention_mask=encoder_visual_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_visual_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = self.encoder_visual_attn_layer_norm(hidden_states) + alpha = self.progression_gate(residual) * temporal_mask.unsqueeze( + -1 + ).unsqueeze(-1) + hidden_states = alpha * hidden_states + (1 - alpha) * residual + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_visual_attn_present_key_value + index += 2 + + # Fully Connected + residual = hidden_states + hidden_states = 
self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += ( + self_attn_weights, + cross_attn_weights, + cross_visual_attn_weights, + ) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BartDecoderCustom(BartDecoder): + def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config, embed_tokens) + self.layers = nn.ModuleList( + [BartDecoderLayerCustom(config) for _ in range(config.decoder_layers)] + ) + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_visual_hidden_states: Optional[torch.FloatTensor] = None, + encoder_visual_attention_mask: Optional[torch.LongTensor] = None, + temporal_mask=None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input = input_ids + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) * self.embed_scale + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + positions = positions.to(inputs_embeds.device) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + if ( + encoder_visual_hidden_states is 
not None + and encoder_visual_attention_mask is not None + ): + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_visual_attention_mask = _expand_mask( + encoder_visual_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + + hidden_states = inputs_embeds + positions + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = ( + () if (output_attentions and encoder_hidden_states is not None) else None + ) + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip( + [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] + ): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_visual_hidden_states=encoder_visual_hidden_states, + encoder_visual_attention_mask=encoder_visual_attention_mask, + temporal_mask=temporal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) 
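Editorial note on the decoder layer defined above: the progression cross-attention output is not simply added back to the stream. A sigmoid gate computed from the pre-attention states interpolates between the attended output and the residual, and the gate is multiplied by `temporal_mask` so that first-visit samples keep the residual untouched. A minimal, self-contained sketch of that fusion (the standalone names and setup are illustrative, not the module's API):

import torch
import torch.nn as nn

# Sketch of the progression-gated fusion in BartDecoderLayerCustom.
# Shapes mirror the layer code above.
d_model = 768
progression_gate = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())

def fuse_progression(residual, attended, temporal_mask):
    # residual, attended: (batch, tgt_len, d_model); temporal_mask: (batch,)
    alpha = progression_gate(residual)                          # (batch, tgt_len, 1)
    alpha = alpha * temporal_mask.unsqueeze(-1).unsqueeze(-1)   # zero the gate for first visits
    # gated interpolation replaces the usual residual addition for this branch
    return alpha * attended + (1 - alpha) * residual

residual = torch.randn(2, 5, d_model)
attended = torch.randn(2, 5, d_model)        # output of encoder_visual_attn + layer norm
temporal_mask = torch.tensor([1.0, 0.0])     # second sample has no prior image/report
fused = fuse_progression(residual, attended, temporal_mask)  # equals residual for sample 2
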
+ return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class BartForCausalLMCustom(BartForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model.decoder = BartDecoderCustom(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + encoder_visual_hidden_states: Optional[torch.FloatTensor] = None, + encoder_visual_attention_mask: Optional[torch.FloatTensor] = None, + temporal_mask=None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + encoder_visual_hidden_states=encoder_visual_hidden_states, + encoder_visual_attention_mask=encoder_visual_attention_mask, + temporal_mask=temporal_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + # logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class ViTEncoder(BartPretrainedModel): + def __init__(self, config, decoder_config, embed_tokens): + super().__init__(config) + self.observation_bart = BartEncoderCustom( + config=decoder_config, embed_tokens=embed_tokens + ) + self.progression_bart = BartEncoderCustom( + config=decoder_config, embed_tokens=embed_tokens + ) + # 0 for current image + # 1 for prior image + # 2 for prior report + # 3 for observation + + from models.rgcn import RGCN + + self.rgcn = RGCN(config) + self.post_init() + + from src_stage1.models.modeling_vit import VisualEncoder + + self.vit = VisualEncoder(config=config) + 
self.vit_config: ViTConfig = self.vit.visual_extractor.config + if config.stage1_model_name_or_path is not None: + print("***************************") + print("***************************") + print( + "Loading Stage 1 Pretrained ViT Model", config.stage1_model_name_or_path + ) + print("***************************") + print("***************************") + state_dict = torch.load( + os.path.join( + config.stage1_model_name_or_path, + WEIGHTS_NAME, # pytorch_model.bin + ), + map_location=self.device, + ) + self.vit.load_state_dict(state_dict, strict=True) + + def forward( + self, + input_ids=None, + attention_mask=None, + progression_input_ids=None, + progression_attention_mask=None, + input_pixels=None, + input_temporal_pixels=None, + temporal_mask=None, + observations=None, + progressions=None, + matrix=None, + nodes=None, + node_mask=None, + ): + loss = None + node_hidden_state = None + observation_det_logits = None + observation_cls_logits = None + progression_logits = None + progression_hidden_state = None + observation_hidden_state = None + prior_attention_mask = None + visual_attention_mask = None + + # spatiotemporal prediction + if self.config.is_temporal == 0: + temporal_mask = temporal_mask * 0 + if ( + input_pixels is not None + and input_temporal_pixels is not None + and temporal_mask is not None + ): + vit_outputs = self.vit( + input_pixels=input_pixels, + input_temporal_pixels=input_temporal_pixels, + temporal_mask=temporal_mask, + observations=observations, + progressions=progressions, + require_logits=False, + ) + pre_visual_last_hidden_state = vit_outputs.last_hidden_state + + last_hidden_state, prior_last_hidden_state = pre_visual_last_hidden_state + visual_attention_mask = torch.ones_like(last_hidden_state[:, :, 0]) + prior_attention_mask = torch.ones_like( + prior_last_hidden_state[:, :, 0] + ) * temporal_mask.unsqueeze(-1) + if ( + progression_input_ids is not None + and progression_attention_mask.sum() > 0 + ): + progression_input_embeds = self.progression_bart.embed_tokens( + progression_input_ids + ) + prior_last_hidden_state = torch.cat( + ( + prior_last_hidden_state, + progression_input_embeds, + ), + dim=1, + ) + prior_attention_mask = torch.cat( + (prior_attention_mask, progression_attention_mask), dim=-1 + ) + progression_hidden_state = self.progression_bart( + inputs_embeds=prior_last_hidden_state, + attention_mask=prior_attention_mask, + ).last_hidden_state + + if input_ids is not None and attention_mask.sum() > 0: + input_embeds = self.observation_bart.embed_tokens(input_ids) + last_hidden_state = torch.cat((last_hidden_state, input_embeds), dim=1) + visual_attention_mask = torch.cat( + (visual_attention_mask, attention_mask), dim=-1 + ) + observation_hidden_state = self.observation_bart( + inputs_embeds=last_hidden_state, + attention_mask=visual_attention_mask, + ).last_hidden_state + + # precise attribute modeling + node_hidden_state = self.rgcn( + nodes=nodes, + matrix=matrix, + ) + + return ViTBartOutput( + loss=loss, + pre_visual_last_hidden_state=pre_visual_last_hidden_state, + progression_hidden_state=progression_hidden_state, + progression_attention_mask=prior_attention_mask, + observation_hidden_state=observation_hidden_state, + observation_attention_mask=visual_attention_mask, + node_hidden_state=node_hidden_state, + observation_det_logits=observation_det_logits, + observation_cls_logits=observation_cls_logits, + progression_logits=progression_logits, + temporal_mask=temporal_mask, + pooler_output=vit_outputs.pooler_output, + ) + + +class 
ViTBartModel(BartPretrainedModel): + def __init__(self, plm_config: BartConfig, init_config: BartConfig): + super().__init__(plm_config) + decoder = BartForCausalLMCustom(init_config) + self.decoder = decoder.model.decoder + self.lm_head = decoder.lm_head + self.encoder = ViTEncoder( + plm_config, init_config, embed_tokens=self.decoder.embed_tokens + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.FloatTensor = None, + progression_input_ids: torch.LongTensor = None, + progression_attention_mask: torch.FloatTensor = None, + decoder_input_ids: torch.LongTensor = None, + decoder_attention_mask: torch.FloatTensor = None, + head_mask: torch.FloatTensor = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + input_pixels: torch.FloatTensor = None, + input_temporal_pixels: torch.FloatTensor = None, + temporal_mask: torch.FloatTensor = None, + matrix: torch.FloatTensor = None, + nodes: torch.LongTensor = None, + node_mask: torch.FloatTensor = None, + encoder_outputs: Optional[ModelOutput] = None, + labels: Optional[torch.LongTensor] = None, + observations: Optional[torch.FloatTensor] = None, + progressions: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + progression_input_ids=progression_input_ids, + progression_attention_mask=progression_attention_mask, + input_pixels=input_pixels, + input_temporal_pixels=input_temporal_pixels, + temporal_mask=temporal_mask, + observations=observations, + progressions=progressions, + matrix=matrix, + nodes=nodes, + node_mask=node_mask, + ) + encoder_visual_hidden_states = None + encoder_visual_attention_mask = None + if self.config.is_temporal == 1: + encoder_visual_hidden_states = encoder_outputs.progression_hidden_state + encoder_visual_attention_mask = encoder_outputs.progression_attention_mask + + decoder_outputs = self.decoder( + # self-attention + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + # observation-aware cross-attention + encoder_hidden_states=encoder_outputs.observation_hidden_state, + encoder_attention_mask=encoder_outputs.observation_attention_mask, + # progression-aware cross-attention + encoder_visual_hidden_states=encoder_visual_hidden_states, + encoder_visual_attention_mask=encoder_visual_attention_mask, + temporal_mask=encoder_outputs.temporal_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return CausalModelOutput( + encoder_loss=encoder_outputs.loss, + past_key_values=decoder_outputs.past_key_values, + last_hidden_state=decoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.observation_hidden_state, + encoder_attention_mask=encoder_outputs.observation_attention_mask, + encoder_visual_hidden_states=encoder_outputs.progression_hidden_state, + encoder_visual_attention_mask=encoder_outputs.progression_attention_mask, + node_hidden_state=encoder_outputs.node_hidden_state, + 
pooler_output=encoder_outputs.pooler_output, + ) + + +class PrRModule(nn.Module): + def __init__(self, config: BartConfig): + super().__init__() + self.obs_weight = nn.ModuleList( + [ + nn.Linear(config.d_model, config.d_model, bias=False) + for _ in range(config.num_relation - 1) + ] + ) + self.tok_weight = nn.ModuleList( + [ + nn.Linear(config.d_model, config.d_model, bias=False) + for _ in range(config.num_relation - 1) + ] + ) + self.self_weight = nn.Linear(config.d_model, config.d_model, bias=False) + self.act = torch.nn.Tanh() + self.config = config + + def forward( + self, + last_hidden_state, + node_hidden_state, + cls_hidden_state, + matrix, + node_mask, + nodes, + gate_labels=None, + ): + node_matrix = matrix[:, :-1] + value = node_hidden_state + value = value.transpose(1, 2) + obs_query = [weight_fn(last_hidden_state) for weight_fn in self.obs_weight] + tok_query = [weight_fn(last_hidden_state) for weight_fn in self.tok_weight] + obs_logits = torch.stack([q.bmm(value) for q in obs_query], dim=1) + tok_logits = torch.stack([q.bmm(value) for q in tok_query], dim=1) + prr_logits = tok_logits.unsqueeze(-1) + obs_logits.unsqueeze(-2) + prr_logits = self.act(prr_logits) * (node_matrix > 0).float().unsqueeze(2) + norm = ( + (node_matrix > 0) + .float() + .sum(dim=1) + .sum(dim=-1, keepdim=True) + .transpose(1, 2) + ) + norm = torch.max(norm, torch.ones_like(norm)) + prr_logits = prr_logits.sum(dim=1).sum(dim=-1) / norm + logits = (self.self_weight(last_hidden_state)).bmm(value) + logits = self.act(logits) + weight = logits + prr_logits * 2 + weight = weight.masked_fill(node_mask.unsqueeze(1) <= 0, -10000) + node_proba = torch.softmax(weight, dim=-1) + return node_proba, node_proba + + +class ViTBartForGeneration(BartPretrainedModel): + def __init__(self, encoder_config: BartConfig, decoder_config: BartConfig): + super().__init__(decoder_config) + self.config = decoder_config + self.main_input_name = "input_pixels" + self.model_parallel = False + self.prr_model = PrRModule(decoder_config) + # copy gate + self.controller = nn.Sequential( + nn.Linear(decoder_config.d_model, 1, bias=False), + nn.Sigmoid(), + ) + self.apply(self._init_weights) + # ViT Pretrained Model dose not need init weights + self.model = ViTBartModel(encoder_config, decoder_config) + self.lm_head = self.model.lm_head + self.tie_weights() + + def get_encoder(self): + return self.model.encoder + + def get_decoder(self): + return self.model.decoder + + def get_output_embeddings(self): + return self.model.decoder.embed_tokens + + def get_input_embeddings(self): + return self.model.encoder.observation_bart.embed_tokens + + def set_input_embeddings(self, value): + self.model.encoder.observation_bart.embed_tokens = value + self.model.encoder.progression_bart.embed_tokens = value + + def tie_weights(self): + return super().tie_weights() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.FloatTensor = None, + progression_input_ids: torch.LongTensor = None, + progression_attention_mask: torch.FloatTensor = None, + decoder_input_ids: torch.LongTensor = None, + decoder_attention_mask: torch.FloatTensor = None, + head_mask: torch.FloatTensor = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + input_pixels: torch.FloatTensor = None, + input_temporal_pixels: torch.FloatTensor = None, + temporal_mask: torch.FloatTensor = None, + matrix: torch.FloatTensor = None, + nodes: torch.LongTensor = None, + node_mask: torch.FloatTensor = None, + gather_index: torch.LongTensor = None, + 
gate_labels: torch.FloatTensor = None, + labels: Optional[torch.LongTensor] = None, + observations: Optional[torch.FloatTensor] = None, + encoder_outputs: Optional[ModelOutput] = None, + progressions: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + progression_input_ids=progression_input_ids, + progression_attention_mask=progression_attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + input_pixels=input_pixels, + input_temporal_pixels=input_temporal_pixels, + temporal_mask=temporal_mask, + encoder_outputs=encoder_outputs, + matrix=matrix, + nodes=nodes, + node_mask=node_mask, + labels=labels, + observations=observations, + progressions=progressions, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = outputs.last_hidden_state + lm_logits = self.lm_head(last_hidden_state) + + # Progression Reasoning (RrR) + gate, proba = self.prr( + lm_logits=lm_logits, + outputs=outputs, + gather_index=gather_index, + node_mask=node_mask, + matrix=matrix, + gate_labels=gate_labels, + nodes=nodes, + ) + loss = None + if labels is not None: + loss = self.prr_loss( + gate=gate, + gate_labels=gate_labels, + proba=proba, + labels=labels, + ) + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=proba, + past_key_values=outputs.past_key_values, + ) + + def prr( + self, + lm_logits, + outputs, + gather_index, + node_mask, + matrix, + gate_labels=None, + nodes=None, + ): + node_proba, node_weight = self.prr_model( + last_hidden_state=outputs.last_hidden_state, + node_hidden_state=outputs.node_hidden_state, + cls_hidden_state=outputs.pooler_output, + matrix=matrix, + node_mask=node_mask, + nodes=nodes, + gate_labels=gate_labels, + ) + node_proba_vocab = node_proba.gather( + -1, gather_index.unsqueeze(1).expand_as(lm_logits) + ) + # 0 represents observation + node_proba_vocab.masked_fill_(gather_index.unsqueeze(1) == 0, 0) + + gate_rep = outputs.last_hidden_state + gate = self.controller(gate_rep) + gate_mask = (node_mask.sum(dim=-1, keepdim=True) > 0).float().unsqueeze(1) + gate = gate * gate_mask + proba_vocab = torch.softmax(lm_logits, dim=-1) + proba = (1.0 - gate) * proba_vocab + gate * node_proba_vocab + proba = proba.clamp(min=1e-5, max=1.0 - 1e-5) + return gate, proba + + def prr_loss(self, gate, gate_labels, proba, labels): + loss_fct = nn.NLLLoss() + loss = loss_fct( + input=proba.log().view(-1, proba.size(-1)), + target=labels.view(-1), + ) + gate = gate.clamp(min=1e-5, max=1.0 - 1e-5) + gate_mask = gate_labels != -100 + gate_labels = gate_labels.masked_fill(~gate_mask, 0) + gate = gate.squeeze(-1) + pointer_loss = ( + -(gate_labels * gate.log() + (1.0 - gate_labels) * (1 - gate).log()) + * gate_mask + ).mean() + if gate_mask.sum() > 0: + loss = loss + pointer_loss * self.config.lambda_ + return loss + + @staticmethod + def _expand_inputs_for_generation( + 
input_ids: torch.LongTensor, # decoder_input_ids + expand_size: int = 1, + is_encoder_decoder: bool = False, + encoder_outputs: ModelOutput = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + expanded_return_idx = ( + torch.arange(input_ids.shape[0]) + .view(-1, 1) + .repeat(1, expand_size) + .view(-1) + .to(input_ids.device) + ) + input_ids = input_ids.index_select(0, expanded_return_idx) + + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = token_type_ids.index_select( + 0, expanded_return_idx + ) + if "temporal_mask" in model_kwargs: + temporal_mask = model_kwargs["temporal_mask"] + model_kwargs["temporal_mask"] = temporal_mask.index_select( + 0, expanded_return_idx + ) + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs[ + "decoder_attention_mask" + ] = decoder_attention_mask.index_select(0, expanded_return_idx) + if ( + "attention_mask" in model_kwargs + and model_kwargs["attention_mask"] is not None + ): + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = attention_mask.index_select( + 0, expanded_return_idx + ) + if "node_mask" in model_kwargs: + node_mask = model_kwargs["node_mask"] + model_kwargs["node_mask"] = node_mask.index_select(0, expanded_return_idx) + + if "gather_index" in model_kwargs: + gather_index = model_kwargs["gather_index"] + model_kwargs["gather_index"] = gather_index.index_select( + 0, expanded_return_idx + ) + + if "matrix" in model_kwargs: + matrix = model_kwargs["matrix"] + model_kwargs["matrix"] = matrix.index_select(0, expanded_return_idx) + + if is_encoder_decoder: + if encoder_outputs is None: + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." 
+ ) + if ( + "last_hidden_state" in encoder_outputs + and encoder_outputs["last_hidden_state"] is not None + ): + encoder_outputs["last_hidden_state"] = encoder_outputs[ + "last_hidden_state" + ].index_select(0, expanded_return_idx) + if ( + "visual_last_hidden_state" in encoder_outputs + and encoder_outputs["visual_last_hidden_state"] is not None + ): + encoder_outputs["visual_last_hidden_state"] = encoder_outputs[ + "visual_last_hidden_state" + ].index_select(0, expanded_return_idx) + if ( + "visual_attention_mask" in encoder_outputs + and encoder_outputs["visual_attention_mask"] is not None + ): + encoder_outputs["visual_attention_mask"] = encoder_outputs[ + "visual_attention_mask" + ].index_select(0, expanded_return_idx) + if ( + "node_hidden_state" in encoder_outputs + and encoder_outputs["node_hidden_state"] is not None + ): + encoder_outputs["node_hidden_state"] = encoder_outputs[ + "node_hidden_state" + ].index_select(0, expanded_return_idx) + if ( + "pooler_output" in encoder_outputs + and encoder_outputs["pooler_output"] is not None + ): + encoder_outputs["pooler_output"] = encoder_outputs[ + "pooler_output" + ].index_select(0, expanded_return_idx) + if ( + "progression_hidden_state" in encoder_outputs + and encoder_outputs["progression_hidden_state"] is not None + ): + encoder_outputs["progression_hidden_state"] = encoder_outputs[ + "progression_hidden_state" + ].index_select(0, expanded_return_idx) + encoder_outputs["progression_attention_mask"] = encoder_outputs[ + "progression_attention_mask" + ].index_select(0, expanded_return_idx) + if ( + "observation_hidden_state" in encoder_outputs + and encoder_outputs["observation_hidden_state"] is not None + ): + encoder_outputs["observation_hidden_state"] = encoder_outputs[ + "observation_hidden_state" + ].index_select(0, expanded_return_idx) + encoder_outputs["observation_attention_mask"] = encoder_outputs[ + "observation_attention_mask" + ].index_select(0, expanded_return_idx) + encoder_outputs["temporal_mask"] = encoder_outputs[ + "temporal_mask" + ].index_select(0, expanded_return_idx) + model_kwargs["encoder_outputs"] = encoder_outputs + return input_ids, model_kwargs + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past[:2] + ) + + layer_past[2:], + ) + return reordered_past + + def prepare_inputs_for_generation( + self, + # attention_mask, + decoder_input_ids, + decoder_attention_mask=None, + past=None, # substitute to `past` in transformers==4.15.0 + temporal_mask=None, + head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + node_mask=None, + nodes=None, + gather_index=None, + matrix=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + return { + "input_ids": None, # encoder_outputs is defined. 
input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "attention_mask": kwargs.get("attention_mask", None), + "decoder_input_ids": decoder_input_ids, + "head_mask": head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "temporal_mask": temporal_mask, + # "decoder_attention_mask": decoder_attention_mask, + # change this to avoid caching (presumably for debugging) + "use_cache": use_cache, + "node_mask": node_mask, + "nodes": nodes, + "gather_index": gather_index, + "matrix": matrix, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + def beam_search( + self, + input_ids: torch.LongTensor, + beam_scorer: BeamScorer, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[BeamSearchOutput, torch.LongTensor]: + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + if len(stopping_criteria) == 0: + warnings.warn( + "You don't have defined any stopping_criteria, this will likely loop forever", + UserWarning, + ) + pad_token_id = ( + pad_token_id if pad_token_id is not None else self.config.pad_token_id + ) + eos_token_id = ( + eos_token_id if eos_token_id is not None else self.config.eos_token_id + ) + output_scores = ( + output_scores if output_scores is not None else self.config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.config.return_dict_in_generate + ) + + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 
+ ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) + if (return_dict_in_generate and output_scores) + else None + ) + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None + ) + + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. + beam_scores = torch.zeros( + (batch_size, num_beams), dtype=torch.float, device=input_ids.device + ) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + this_peer_finished = False # used by synced_gpus only + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor( + 0.0 if this_peer_finished else 1.0 + ).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + cur_len = cur_len + 1 + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # NOTICE major revision of beam_search + next_token_scores = next_token_logits.log() + + next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores = next_token_scores_processed + beam_scores[ + :, None + ].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores_processed,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view( + batch_size, num_beams * vocab_size + ) + + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True + ) + + next_indices = torch_int_div(next_tokens, vocab_size) + next_tokens = next_tokens % vocab_size + + # stateless + beam_outputs = beam_scorer.process( + input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids = torch.cat( + [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1 + ) + + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + if model_kwargs["past"] is not None: + model_kwargs["past"] = self._reorder_cache( + model_kwargs["past"], beam_idx + ) + + if return_dict_in_generate and output_scores: + beam_indices = tuple( + ( + beam_indices[beam_idx[i]] + (beam_idx[i],) + for i in range(len(beam_indices)) + ) + ) + + # increase cur_len + cur_len = cur_len + 1 + + if beam_scorer.is_done or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + ) + + if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + + if self.config.is_encoder_decoder: + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=sequence_outputs["beam_indices"], + encoder_attentions=encoder_attentions, + 
encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], + scores=scores, + beam_indices=sequence_outputs["beam_indices"], + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return sequence_outputs["sequences"] diff --git a/src_stage2/models/rgcn.py b/src_stage2/models/rgcn.py new file mode 100644 index 0000000..4eca65c --- /dev/null +++ b/src_stage2/models/rgcn.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + + +class RGCNLayer(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.dropout = self.config.dropout + self.transformations = nn.ModuleList( + [ + nn.Linear(config.d_model, config.d_model, bias=False) + for _ in range(config.num_relation) + ] + ) + self.self_transformations = nn.Linear( + config.d_model, config.d_model, bias=False + ) + + def forward(self, x, matrix): + self_x = self.self_transformations(x) + flatten_matrix = matrix.view(-1, matrix.size(-2), matrix.size(-1)) + progression_x = torch.stack([trans(x) + for trans in self.transformations], dim=1) + flatten_progression_x = progression_x.view( + -1, progression_x.size(-2), progression_x.size(-1) + ) + flatten_neigh_x = flatten_matrix.bmm(flatten_progression_x) + neigh_x = flatten_neigh_x.view( + self_x.size(0), - + 1, flatten_neigh_x.size(-2), flatten_neigh_x.size(-1) + ).sum(dim=1) + norm = matrix.sum(dim=1).sum(dim=-1, keepdim=True) + norm = torch.max(norm, torch.ones_like(norm)) + neigh_x = neigh_x / norm + x = self_x + neigh_x + x = nn.functional.dropout(x, p=self.dropout, training=self.training) + return torch.relu(x) + + +class RGCN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.config = config + self.embed_nodes = nn.Embedding( + config.num_node + 1, config.d_model, padding_idx=config.num_node + ) + self.layers = nn.ModuleList( + [RGCNLayer(config) for _ in range(config.num_rgcnlayer)] + ) + self.apply(self._init_weights) + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def forward(self, nodes, matrix): + matrix = (matrix > 0).float() + nodes = nodes.masked_fill(nodes == -100, self.config.num_node).long() + x = self.embed_nodes(nodes) + for layer in self.layers: + x = layer(x, matrix) + return x diff --git a/src_stage2/optimizer.py b/src_stage2/optimizer.py new file mode 100644 index 0000000..f45829f --- /dev/null +++ b/src_stage2/optimizer.py @@ -0,0 +1,65 @@ +import torch.nn as nn + +from transformers import AdamW +from transformers.trainer_pt_utils import get_parameter_names + +# from torch.optim import AdamW + + +def create_optimizer(model, args, fast_lr=1e-4): + """ + fast_lr: for newly inited model + """ + decay_parameters = get_parameter_names(model, [nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + + fast_params = [] + for n, _ in model.named_parameters(): + if not n.startswith("model.encoder.vit"): + fast_params.append(n) + optimizer_grouped_parameters 
= [ + { + "params": [ + p + for n, p in model.named_parameters() + if n in decay_parameters and n not in fast_params + ], + "weight_decay": args.weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if n not in decay_parameters and n not in fast_params + ], + "weight_decay": 0.0, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if n in fast_params and n in decay_parameters + ], + "lr": fast_lr, + "weight_decay": args.weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if n in fast_params and n not in decay_parameters + ], + "lr": fast_lr, + "weight_decay": 0.0, + }, + ] + optimizer_kwargs = { + "lr": args.learning_rate, + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + optimizer = AdamW( + optimizer_grouped_parameters, + **optimizer_kwargs, + ) + return optimizer diff --git a/src_stage2/run_ende.py b/src_stage2/run_ende.py new file mode 100644 index 0000000..b20fe90 --- /dev/null +++ b/src_stage2/run_ende.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python +# coding=utf-8 +import json +import logging +import os +import sys + +import datasets +import torch +from torchvision import transforms +import transformers +from transformers import ( + DataCollatorForSeq2Seq, + HfArgumentParser, + Seq2SeqTrainingArguments, + set_seed, + BertTokenizer, + BartTokenizer, + BartConfig, +) +from transformers.file_utils import WEIGHTS_NAME +from transformers.trainer_utils import get_last_checkpoint +from radgraph import F1RadGraph +from data_collator_ende import DataCollatorForEnDe as DataCollatorForSeq2Seq +from dataset_ende import DatasetCustom +from model_arguments import ModelArguments +from seq2seqtrainer_metrics_ende import Seq2SeqTrainerGenMetrics +from train_eval_ende_full import train +from transformers import ViTFeatureExtractor +from chexbert_eval import compute_ce_metric, load_chexbert, build_progression_graph +import copy +from sklearn.exceptions import UndefinedMetricWarning +import warnings +from src_stage2.models.modeling_bart import ViTBartForGeneration + +sys.path.append("../") +from src_stage1.data_arguments import DataTrainingArguments + +warnings.filterwarnings( + action="ignore", category=UndefinedMetricWarning, module="sklearn" +) + +logger = logging.getLogger(__name__) + + +def main(): + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments) + ) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. 
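An aside on `create_optimizer` in src_stage2/optimizer.py above: it builds four AdamW parameter groups so that parameters outside the pretrained visual backbone (`model.encoder.vit`) train at the faster `fast_lr`, while weight decay is applied only to parameters that are neither biases nor LayerNorm weights. The sketch below reproduces that grouping idea; the `ndim == 1` no-decay test and the explicit base `lr` are simplifications, not the repository's exact logic:

import torch.nn as nn
from torch.optim import AdamW  # the patch imports AdamW from transformers instead

def grouped_params(model, base_lr=5e-5, fast_lr=1e-4, weight_decay=0.01):
    def is_fast(name):                      # newly initialised (non-ViT) modules
        return not name.startswith("model.encoder.vit")

    def no_decay(name, param):              # biases / LayerNorm weights (approximation)
        return name.endswith("bias") or param.ndim == 1

    groups = []
    for fast in (False, True):
        for decay in (True, False):
            groups.append({
                "params": [
                    p for n, p in model.named_parameters()
                    if is_fast(n) == fast and no_decay(n, p) != decay
                ],
                "lr": fast_lr if fast else base_lr,
                "weight_decay": weight_decay if decay else 0.0,
            })
    return groups

# toy usage: every parameter of this dummy module counts as "fast"
dummy = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
optimizer = AdamW(grouped_params(dummy), betas=(0.9, 0.999), eps=1e-8)
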
+ last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + Seq2SeqTrainer = Seq2SeqTrainerGenMetrics + + from tokenizer import Tokenizer + + data_args.dataset = ( + "mimic_abn" if "mimic_abn" in data_args.annotation_file else "mimic_cxr" + ) + + logger.info("***************************") + logger.info("***************************") + logger.info(data_args) + logger.info("***************************") + logger.info("***************************") + + logger.info("***************************") + logger.info("***************************") + logger.info(model_args) + logger.info("***************************") + logger.info("***************************") + + # load necessary data + ref_annotation = None + if data_args.miss_annotation_file is not None: + with open(data_args.miss_annotation_file, "r", encoding="utf-8") as f: + ref_annotation = json.load(f) + with open(data_args.annotation_file, "r", encoding="utf-8") as f: + annotation = json.load(f) + + # temporal information + with open(data_args.history, "r", encoding="utf-8") as f: + temporal_ids = json.load(f) + + data_args.threshold = 3 if data_args.dataset == "mimic_abn" else 10 + # ngram labels + train_idxs = {sample["id"] for sample in annotation["train"]} + # observation labels + id2tags, observation_category, observation_weight = Tokenizer.load_tag2ids( + data_args.chexbert_label, + need_header=True, + train_idxs=train_idxs, + ) + checkpoint = "GanjinZero/biobart-base" + bart_tokenizer = BartTokenizer.from_pretrained(checkpoint) + tokenizer = Tokenizer(data_args, observation_category) + + progression_graph = build_progression_graph( + progression_triples=json.load( + open(data_args.progression_graph, "r", encoding="utf-8") + ), + observations=observation_category, + topk_entity=data_args.topk, + tokenizer=tokenizer, + ) + tokenizer.id2entity = progression_graph["id2entity"] + chexbert = load_chexbert(model_args.chexbert_model_name_or_path) + bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + f1radgraph = F1RadGraph(reward_level="partial") + + config = BartConfig.from_pretrained(checkpoint) + config.num_observation = len(observation_category) + config.num_progression = 3 + config.num_rgcnlayer = 3 + config.num_relation = len(progression_graph["relation2id"]) + # config.num_entity = len(progression_graph["entity2id"]) + config.num_node = len(progression_graph["entity2id"]) + config.observation_category = observation_category + config.alpha = data_args.alpha + config.beta = data_args.beta + config.observation_weight = observation_weight + config.pretrained_visual_extractor = "google/vit-base-patch16-224-in21k" + config.topk = data_args.topk + processor = ViTFeatureExtractor.from_pretrained(config.pretrained_visual_extractor) + + config.add_cross_attention = True + + 
config.is_temporal = 1 + config.is_stage1_pretrained = int(data_args.is_stage1_pretrained) + + config.stage1_model_name_or_path = model_args.stage1_model_name_or_path + if int(data_args.is_stage1_pretrained) == 0: + config.stage1_model_name_or_path = None + config.decoder_model_name_or_path = checkpoint + config.num_path = 16 * 16 + 1 + config.lambda_ = data_args.lambda_ + config.id2entity = progression_graph["id2entity"] + encoder_config = config + decoder_config = copy.deepcopy(config) + + decoder_config.vocab_size = len(tokenizer.token2idx) + decoder_config.decoder_layers = 3 + decoder_config.d_model = 768 + decoder_config.decoder_ffn_dim = 768 + decoder_config.decoder_attention_heads = 8 + decoder_config.encoder_layers = 3 + decoder_config.d_model = 768 + decoder_config.encoder_ffn_dim = 768 + decoder_config.encoder_attention_heads = 8 + decoder_config.activation_function = "relu" + decoder_config.decoder_start_token_id = tokenizer.bos_token_id + decoder_config.eos_token_id = tokenizer.eos_token_id + decoder_config.bos_token_id = tokenizer.bos_token_id + decoder_config.decoder_start_token_id = tokenizer.bos_token_id + decoder_config.pad_token_id = tokenizer.pad_token_id + data_args.vocab_size = decoder_config.vocab_size + model = ViTBartForGeneration( + encoder_config=encoder_config, + decoder_config=decoder_config, + ) + model.observation_category = observation_category + model.id2entity = progression_graph["id2entity"] + data_args.vocab_size = len(tokenizer.token2idx) + data_args.stage1_model_name_or_path = model_args.stage1_model_name_or_path + data_args.stage1_eval_file = model_args.stage1_eval_file + + logger.info("***************************") + logger.info("***** Model Structure *****") + logger.info(model) + logger.info("***************************") + logger.info("***************************") + train_dataset = eval_dataset = test_dataset = None + + if data_args.debug_model: + debug_data_size = 16 + for key in temporal_ids: + ref_ids = {report["id"] for report in ref_annotation[key]} + subject_ids = list(temporal_ids[key].keys())[:debug_data_size] + temporal_ids[key] = { + subject_id: temporal_ids[key][subject_id] for subject_id in subject_ids + } + ids = set(subject_ids) + annotation[key] = [ + ann + for ann in annotation[key] + if ann["id"] in ids + and temporal_ids[key][ann["id"]]["object_id"] in ref_ids + ] + if training_args.do_train: + transform = None + train_dataset = DatasetCustom( + data_args=data_args, + annotation=annotation, + ref_annotation=ref_annotation, + temporal_ids=temporal_ids, + split="train", + id2tags=id2tags, + processor=processor, + text_tokenizer=bart_tokenizer, + tokenizer=tokenizer, + progression_graph=progression_graph, + observation_category=observation_category, + transform=transform, + ) + eval_dataset = DatasetCustom( + data_args=data_args, + annotation=annotation, + ref_annotation=ref_annotation, + temporal_ids=temporal_ids, + split="valid", + id2tags=id2tags, + processor=processor, + text_tokenizer=bart_tokenizer, + tokenizer=tokenizer, + progression_graph=progression_graph, + observation_category=observation_category, + transform=None, + ) + if training_args.do_predict: + test_dataset = DatasetCustom( + data_args=data_args, + annotation=annotation, + ref_annotation=ref_annotation, + temporal_ids=temporal_ids, + split="test", + id2tags=id2tags, + processor=processor, + text_tokenizer=bart_tokenizer, + tokenizer=tokenizer, + progression_graph=progression_graph, + observation_category=observation_category, + transform=None, + ) + data_collator 
= DataCollatorForSeq2Seq( + tokenizer=tokenizer, + model=model, + padding=True, + max_length=data_args.max_context_length, + pad_to_multiple_of=8, + ) + + training_args.max_tgt_length = data_args.max_tgt_length + training_args.num_beams = model_args.num_beams + training_args.fast_lr = model_args.fast_lr + training_args.remove_unused_columns = False + data_args.max_steps = training_args.max_steps + + from transformers import EarlyStoppingCallback + + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + data_collator=data_collator, + callbacks=[ + EarlyStoppingCallback( + early_stopping_patience=5 if data_args.dataset == "mimic_cxr" else 3, + ) + ], + ) + trainer.data_args = data_args + trainer.chexbert = chexbert + trainer.bert_tokenizer = bert_tokenizer + trainer.f1radgraph = f1radgraph + trainer.compute_ce_metric = compute_ce_metric + trainer.tokenizer = bart_tokenizer + trainer.decoder_tokenizer = tokenizer + + if training_args.do_train: + logger.info("*** Train ***") + train( + training_args, + data_args, + last_checkpoint, + trainer, + train_dataset, + ) + + # Prediction + if training_args.do_predict: + logger.info("*** Test ***") + if model_args.test_model_name_or_path is not None: + logger.info( + "*** Test: Loading %s ***" % (model_args.test_model_name_or_path) + ) + state_dict = torch.load( + os.path.join( + model_args.test_model_name_or_path, + WEIGHTS_NAME, # pytorch_model.bin + ), + map_location="cpu", + ) + model.load_state_dict(state_dict, strict=False) + model = model.cuda() + from train_eval_ende_full import eval_text + + print(model_args.num_beams) + eval_text( + max_tgt_length=data_args.max_tgt_length, + model=model, + test_dataset=trainer.get_test_dataloader(test_dataset), + output_path=training_args.output_dir, + num_beams=model_args.num_beams, + compute_ce_metric=compute_ce_metric, + chexbert=chexbert, + bert_tokenizer=bert_tokenizer, + f1radgraph=f1radgraph, + tokenizer=bart_tokenizer, + decoder_tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/src_stage2/seq2seqtrainer_metrics_ende.py b/src_stage2/seq2seqtrainer_metrics_ende.py new file mode 100644 index 0000000..2f31818 --- /dev/null +++ b/src_stage2/seq2seqtrainer_metrics_ende.py @@ -0,0 +1,99 @@ +import collections +from typing import List, Optional + +from torch.utils.data import DataLoader +from transformers import Seq2SeqTrainer +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import logging + +from optimizer import create_optimizer +from train_eval_ende_full import eval_text + +logger = logging.get_logger(__name__) + + +class Seq2SeqTrainerGenMetrics(Seq2SeqTrainer): + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. + + Works both with or without labels. 
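+
+        Unlike the default loop, this override delegates decoding and scoring to
+        :obj:`eval_text` (NLG metrics plus CheXbert/F1RadGraph clinical metrics) and
+        returns an :obj:`EvalLoopOutput` whose metric keys carry the
+        ``{metric_key_prefix}_`` prefix.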
+ """ + prediction_loss_only = ( + prediction_loss_only + if prediction_loss_only is not None + else self.args.prediction_loss_only + ) + + model = self._wrap_model(self.model, training=False) + + # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while + # ``train`` is running, halve it first and then put on device + if not self.is_in_train and self.args.fp16_full_eval: + model = model.half().to(self.args.device) + + batch_size = dataloader.batch_size + + logger.info(f"***** Running {description} *****") + if isinstance(dataloader.dataset, collections.abc.Sized): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + model.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = dataloader.dataset + + if self.args.past_index >= 0: + self._past = None + + metrics = eval_text( + max_tgt_length=self.args.max_tgt_length, + model=self.model, + tokenizer=self.tokenizer, + test_dataset=dataloader, + output_path=self.args.output_dir, + result_file_name="results_eval_step_%d.txt" % (self.state.global_step), + reference_file_name="reference_eval_step_%d.txt" % (self.state.global_step), + prediction_file_name="prediction_eval_step_%d.txt" + % (self.state.global_step), + num_beams=self.args.num_beams, + compute_ce_metric=self.compute_ce_metric, + chexbert=self.chexbert, + bert_tokenizer=self.bert_tokenizer, + f1radgraph=self.f1radgraph, + decoder_tokenizer=self.decoder_tokenizer, + ) + + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput( + predictions=None, + label_ids=None, + metrics=metrics, + num_samples=len(eval_dataset), + ) + + def create_optimizer(self): + print("Create Optimizer with Different Learning Rate") + print("Slow Learning Rate\t%0.5f" % self.args.learning_rate) + print("Fast Learning Rate\t%0.5f" % self.args.fast_lr) + print("Weight Decay\t%0.5f" % self.args.weight_decay) + self.optimizer = create_optimizer( + model=self.model, + args=self.args, + fast_lr=self.args.fast_lr, + ) + return self.optimizer diff --git a/src_stage2/tokenizer.py b/src_stage2/tokenizer.py new file mode 100644 index 0000000..e75f646 --- /dev/null +++ b/src_stage2/tokenizer.py @@ -0,0 +1,255 @@ +import copy +import json +import re +from collections import Counter, defaultdict + +import pandas as pd +from transformers.tokenization_utils import PreTrainedTokenizer +import os +import pickle + + +class Tokenizer: + def __init__(self, config, observation_category=None) -> None: + self.model_input_names = ["nodes"] + self.padding_side = "right" + self.ann_path = config.annotation_file + self.threshold = config.threshold + self.dataset = config.dataset + if self.dataset == "iu_xray": + self.clean_report = Tokenizer.clean_report_iu_xray + else: + self.clean_report = Tokenizer.clean_report_mimic_cxr + print(self.clean_report) + self.ann = json.loads(open(self.ann_path, "r").read()) + self.token2idx, self.idx2token, self.special_tokens = self.create_vocabulary( + observation_category + ) + self.bos_token_id = self.eos_token_id = self.decoder_start_token_id = 0 + self.pad_token_id = 1 + self.unk_token_id = 2 + + def create_vocabulary(self, observation_category=None): + total_tokens = [] + for example in self.ann["train"]: + tokens = self.clean_report(example["report"]).split() + for token in tokens: + 
+    def create_vocabulary(self, observation_category=None):
+        total_tokens = []
+        for example in self.ann["train"]:
+            tokens = self.clean_report(example["report"]).split()
+            for token in tokens:
+                total_tokens.append(token)
+
+        counter = Counter(total_tokens)
+        vocab = [k for k, v in counter.items() if v >= self.threshold and k != " "]
+        vocab.sort()
+        special_tokens = ["[CLS]", "[PAD]", "[UNK]"]
+        for observation in observation_category:
+            special_tokens.append("[{}:Positive]".format(observation))
+            special_tokens.append("[{}:Negative]".format(observation))
+        special_tokens.extend(["[First-Visit]", "[Follow-Up-Visit]"])
+        special_tokens.extend(["[Better]", "[Worse]", "[Stable]"])
+        vocab = special_tokens + vocab
+        token2idx, idx2token = {}, {}
+        for idx, token in enumerate(vocab):
+            token2idx[token] = idx
+            idx2token[idx] = token
+        # return token2idx, idx2token, special_tokens[:2] + special_tokens[-3:]
+        return token2idx, idx2token, special_tokens[:2] + special_tokens[3:]
+
+    @staticmethod
+    def clean_report_iu_xray(report):
+        def report_cleaner(t):
+            return (
+                t.replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("1. ", "")
+                .replace(". 2. ", ". ")
+                .replace(". 3. ", ". ")
+                .replace(". 4. ", ". ")
+                .replace(". 5. ", ". ")
+                .replace(" 2. ", ". ")
+                .replace(" 3. ", ". ")
+                .replace(" 4. ", ". ")
+                .replace(" 5. ", ". ")
+                .strip()
+                .lower()
+                .split(". ")
+            )
+
+        def sent_cleaner(t):
+            return re.sub(
+                "[.,?;*!%^&_+():-\[\]{}]",
+                "",
+                t.replace('"', "")
+                .replace("/", "")
+                .replace("\\", "")
+                .replace("'", "")
+                .strip()
+                .lower(),
+            )
+
+        tokens = [
+            sent_cleaner(sent).strip() + " ."
+            for sent in report_cleaner(report)
+            if len(sent_cleaner(sent).strip()) > 0
+        ]
+        report = " ".join(tokens)
+        return report
+
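+    # Same cleaning as above, but it also collapses the underscore runs and repeated
+    # whitespace left by MIMIC-CXR de-identification before splitting into sentences.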
+    @staticmethod
+    def clean_report_mimic_cxr(report):
+        def report_cleaner(t):
+            return (
+                t.replace("\n", " ")
+                .replace("__", "_")
+                .replace("__", "_")
+                .replace("__", "_")
+                .replace("__", "_")
+                .replace("__", "_")
+                .replace("__", "_")
+                .replace("__", "_")
+                .replace("  ", " ")
+                .replace("  ", " ")
+                .replace("  ", " ")
+                .replace("  ", " ")
+                .replace("  ", " ")
+                .replace("  ", " ")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("..", ".")
+                .replace("1. ", "")
+                .replace(". 2. ", ". ")
+                .replace(". 3. ", ". ")
+                .replace(". 4. ", ". ")
+                .replace(". 5. ", ". ")
+                .replace(" 2. ", ". ")
+                .replace(" 3. ", ". ")
+                .replace(" 4. ", ". ")
+                .replace(" 5. ", ". ")
+                .strip()
+                .lower()
+                .split(". ")
+            )
+
+        def sent_cleaner(t):
+            return re.sub(
+                "[.,?;*!%^&_+():-\[\]{}]",
+                "",
+                t.replace('"', "")
+                .replace("/", "")
+                .replace("\\", "")
+                .replace("'", "")
+                .lower()
+                .strip(),
+            )
+
+        tokens = [
+            sent_cleaner(sent).strip() + " ."
+            for sent in report_cleaner(report)
+            if len(sent_cleaner(sent).strip()) > 0
+        ]
+        report = " ".join(tokens)
+        return report
+
+    @staticmethod
+    def load_tag2ids(
+        tag_path,
+        train_idxs=None,
+        need_header=False,
+    ):
+        cached_path = tag_path + ".pkl"
+        if os.path.exists(cached_path):
+            with open(cached_path, "rb") as f:
+                tags = pickle.load(f)
+        else:
+            tags = pd.read_csv(tag_path)
+            with open(cached_path, "wb") as f:
+                pickle.dump(tags, file=f)
+        tags = tags.replace(-1, 1).fillna(2)
+        diseases = list(tags)[2:]
+        id2tags = defaultdict(list)
+        weight = [0] * len(diseases)
+        count = [0] * len(diseases)
+        for i in range(len(tags)):
+            tag = tags.iloc[i]
+            idx = tag[1]
+            id2tags[idx] = list(tag[2:].values)
+            if train_idxs is not None and idx in train_idxs:
+                weight = [
+                    w + v if v in (0, 1) else w for w, v in zip(weight, id2tags[idx])
+                ]
+                count = [
+                    c + 1 if v in (0, 1) else c for c, v in zip(count, id2tags[idx])
+                ]
+
+        weight = [(c - w) / max(c, 1) for w, c in zip(weight, count)]
+        min_weight = 0.25
+        max_weight = 0.75
+        weight = [max(min_weight, min(max_weight, w)) for w in weight]
+        if not need_header:
+            return id2tags, weight
+        else:
+            return id2tags, diseases, weight
+
+    def get_token_by_id(self, id):
+        return self.idx2token[id]
+
+    def get_id_by_token(self, token):
+        if token not in self.token2idx:
+            return self.token2idx["[UNK]"]
+        return self.token2idx[token]
+
+    def get_vocab_size(self):
+        return len(self.token2idx)
+
+    def __call__(self, report):
+        tokens = self.clean_report(report).split()
+        ids = []
+        for token in tokens:
+            ids.append(self.get_id_by_token(token))
+        ids = [self.decoder_start_token_id] + ids + [self.eos_token_id]
+        return ids
+
+    def encode(
+        self,
+        report,
+        add_special_tokens=True,
+    ):
+        ids = []
+        tokens = self.clean_report(report).split()
+        for token in tokens:
+            if token == " ":
+                continue
+            ids.append(self.get_id_by_token(token))
+        if add_special_tokens:
+            ids = [self.decoder_start_token_id] + ids + [self.eos_token_id]
+        return ids
+
+    def decode(self, ids, skip_special_tokens=True, separator=" "):
+        txt = []
+        for i, idx in enumerate(ids):
+            if idx not in self.idx2token:
+                idx = self.unk_token_id
+            token = self.idx2token[idx]
+            if skip_special_tokens and token in self.special_tokens:
+                continue
+            txt.append(token)
+        return separator.join(txt)
+
+    def batch_decode(self, ids_batch, skip_special_tokens=True, separator=" "):
+        out = []
+        for ids in ids_batch:
+            out.append(
+                self.decode(
+                    ids,
+                    skip_special_tokens=skip_special_tokens,
+                    separator=separator,
+                )
+            )
+        return out
+
+    def save_pretrained(self, save_directory):
+        return ""
diff --git a/src_stage2/train_eval_ende_full.py b/src_stage2/train_eval_ende_full.py
new file mode 100644
index 0000000..a1a816b
--- /dev/null
+++ b/src_stage2/train_eval_ende_full.py
@@ -0,0 +1,275 @@
+import os
+
+import torch
+from tqdm import tqdm
+
+from metrics import compute_scores
+import json
+from sklearn.metrics import precision_recall_fscore_support
+import numpy as np
+from collections import defaultdict
+from radgraph import F1RadGraph
+from src_stage2.models.modeling_bart import ViTBartForGeneration
+from chexbert_eval import CONDITIONS
+
+
+def pad_strings(strs):
+    max_len = max([len(s) for s in strs])
+    return [s + " " * (max_len - len(s)) for s in strs]
+
+
+def train(training_args, data_args, last_checkpoint, trainer, train_dataset):
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        checkpoint = last_checkpoint
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+    trainer.save_model()  # Saves the tokenizer too for easy upload
+
+    metrics = train_result.metrics
+
+    max_train_samples = (
+        data_args.max_train_samples
+        if data_args.max_train_samples is not None
+        else len(train_dataset)
+    )
+    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+    trainer.log_metrics("train", metrics)
+    trainer.save_metrics("train", metrics)
+    trainer.save_state()
+
+
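+# Decodes reports with beam search and scores them against the references
+# (NLG metrics via compute_scores, CheXbert-based CE metrics via compute_ce_metric).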
+def eval_text(
+    max_tgt_length: int,
+    model: ViTBartForGeneration,
+    tokenizer,
+    test_dataset,
+    output_path: str,
+    result_file_name: str = "results.txt",
+    reference_file_name: str = "references.txt",
+    prediction_file_name: str = "predictions.txt",
+    num_beams=None,
+    compute_ce_metric=None,
+    chexbert=None,
+    bert_tokenizer=None,
+    f1radgraph=None,
+    decoder_tokenizer=None,
+):
+    model.eval()
+
+    max_length = max_tgt_length
+    print("******************")
+    print("Text generation max length", max_length)
+    print("******************")
+
+    # for all report
+    predictions = []
+    multi_predictions = []
+    references = []
+    temporal_references = []
+    pre_nodes = []
+    ref_nodes = []
+    report_ids = []
+
+    # for temporal_report
+    predictions_with_temporal = []
+    references_with_temporal = []
+    is_temporals = []
+    test_progress = tqdm(
+        test_dataset,
+        desc="Evaluating Model (Report Generation)",
+    )
+    if num_beams is None:
+        num_beams = 1
+
+    print("******************")
+    print("Beam Size", num_beams)
+    print("******************")
+
+    with torch.no_grad():
+        for i, batch in enumerate(test_progress):
+            max_length = max_tgt_length
+            min_length = 2
+            encoder_outputs = model.get_encoder()(
+                input_pixels=batch["input_pixels"].cuda(),
+                input_temporal_pixels=batch["input_temporal_pixels"].cuda(),
+                input_ids=batch["input_ids"].cuda(),
+                attention_mask=batch["attention_mask"].cuda(),
+                progression_input_ids=batch["progression_input_ids"].cuda(),
+                progression_attention_mask=batch["progression_attention_mask"].cuda(),
+                temporal_mask=batch["temporal_mask"].cuda(),
+                matrix=batch["matrix"].cuda(),
+                nodes=batch["nodes"].cuda(),
+                node_mask=batch["node_mask"].cuda(),
+            )
+
+            model_inputs = {
+                "attention_mask": batch["attention_mask"].cuda(),
+                "temporal_mask": batch["temporal_mask"].cuda(),
+                "input_pixels": batch["input_pixels"].cuda(),
+                "node_mask": batch["node_mask"].cuda(),
+                "gather_index": batch["gather_index"].cuda(),
+                "matrix": batch["matrix"].cuda(),
+                "nodes": batch["nodes"].cuda(),
+                "num_beams": num_beams,
+                "max_length": max_length,
+                "min_length": min_length,
+                "decoder_start_token_id": model.config.decoder_start_token_id,
+                "bos_token_id": model.config.bos_token_id,
+                "eos_token_id": model.config.eos_token_id,
+                "pad_token_id": model.config.pad_token_id,
+                "early_stopping": True,
+                "return_dict_in_generate": True,
+                "encoder_outputs": encoder_outputs,
+                "num_return_sequences": num_beams,
+            }
+            outputs = model.generate(**model_inputs)
+            output_sequences = outputs["sequences"]
+            multi_prediction = decoder_tokenizer.batch_decode(
+                output_sequences.tolist(),
+                skip_special_tokens=True,
+            )
+            prediction = [
+                p for pi, p in enumerate(multi_prediction) if (pi % num_beams) == 0
+            ]
+            labels = batch["labels"].masked_fill(
+                batch["labels"] == -100,
+                tokenizer.pad_token_id,
+            )
+            reference = decoder_tokenizer.batch_decode(
+                labels.tolist(),
+                skip_special_tokens=True,
+            )
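+            # Map graph node ids back to entity names, skipping padding (-100) and
+            # keeping only names without ":" or containing "-".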
+            node = [
+                {
+                    decoder_tokenizer.id2entity[n_]
+                    for n_ in n
+                    if n_ != -100
+                    and (
+                        ":" not in decoder_tokenizer.id2entity[n_]
+                        or "-" in decoder_tokenizer.id2entity[n_]
+                    )
+                }
+                for n in batch["nodes"].tolist()
+            ]
+            selected_pre_node = []
+            selected_ref_node = []
+            for p, r, n in zip(prediction, reference, node):
+                tokens = p.split()
+                selected = []
+                token2node = defaultdict(list)
+                for a in n:
+                    tok = a.split("-")[-1]
+                    token2node[tok].append(a)
+                for t in tokens:
+                    if t in token2node:
+                        selected.extend(token2node[t])
+                selected_pre_node.append(list(set(selected)))
+
+                tokens = r.split()
+                selected = []
+                for t in tokens:
+                    if t in token2node:
+                        selected.extend(token2node[t])
+                selected_ref_node.append(list(set(selected)))
+            if batch["progression_input_ids"] is not None:
+                temporal_reference = decoder_tokenizer.batch_decode(
+                    batch["progression_input_ids"].tolist(),
+                    skip_special_tokens=True,
+                )
+            else:
+                temporal_reference = ["Empty" for _ in range(len(prediction))]
+            prediction = [z.strip() for z in prediction]
+            reference = [z.strip() for z in reference]
+            predictions.extend(prediction)
+            references.extend(reference)
+            temporal_references.extend(temporal_reference)
+            pre_nodes.extend(selected_pre_node)
+            ref_nodes.extend(selected_ref_node)
+            report_ids.extend(batch["report_ids"])
+
+            for pi in range(0, len(multi_prediction), num_beams):
+                ps = [z.strip() for z in multi_prediction[pi : pi + num_beams]]
+                multi_predictions.append(ps)
+
+            predictions_with_temporal.extend(
+                [
+                    pre
+                    for is_temporal, pre in zip(batch["is_temporal"], prediction)
+                    if is_temporal
+                ]
+            )
+            references_with_temporal.extend(
+                [
+                    ref
+                    for is_temporal, ref in zip(batch["is_temporal"], reference)
+                    if is_temporal
+                ]
+            )
+            is_temporals.extend(batch["is_temporal"])
+    assert len(references) == len(predictions), "Prediction Num != Reference Num"
+
+    ce_scores = [0, 0, 0]
+    with torch.no_grad():
+        (
+            _,
+            _,
+            ce_scores,
+            temporal_ce_scores,
+            macro_ce_scores,
+            macro_temporal_ce_scores,
+            tem_scores,
+        ) = compute_ce_metric(
+            references=references,
+            hypotheses=predictions,
+            is_temporals=is_temporals,
+            chexbert=chexbert,
+            bert_tokenizer=bert_tokenizer,
+        )
+    print("--------------------------------------------------------------")
+    print(
+        "Binary CE Score\t\t\t Prec. %0.4f\tRec. %0.4f\tF1 %0.4f"
+        % (ce_scores[0], ce_scores[1], ce_scores[2])
+    )
+    print(
+        "Binary Temporal CE Score\t Prec. %0.4f\tRec. %0.4f\tF1 %0.4f"
+        % (temporal_ce_scores[0], temporal_ce_scores[1], temporal_ce_scores[2])
+    )
+    print(
+        "Macro CE Score\t\t\t Prec. %0.4f\tRec. %0.4f\tF1 %0.4f"
+        % (macro_ce_scores[0], macro_ce_scores[1], macro_ce_scores[2])
+    )
+    print(
+        "Macro Temporal CE Score\t\t Prec. %0.4f\tRec. %0.4f\tF1 %0.4f"
+        % (
+            macro_temporal_ce_scores[0],
+            macro_temporal_ce_scores[1],
+            macro_temporal_ce_scores[2],
+        )
+    )
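+    # TEM scores are also produced by compute_ce_metric and printed in the same
+    # Prec./Rec./F1 layout as the CE scores above.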
+    print(
+        "TEM Score\t\t\t Prec. %0.4f\tRec. %0.4f\tF1 %0.4f"
+        % (tem_scores[0], tem_scores[1], tem_scores[2])
+    )
+    print("--------------------------------------------------------------")
+    print("--------------------------------------------------------------")
+    for i in range(5):
+        print("Sample Prediction\t%d:" % i, predictions[i])
+        print("Sample Reference\t%d:" % i, references[i])
+    print("--------------------------------------------------------------")
+    bleu_scores = compute_scores(
+        gts={index: [gt] for index, gt in enumerate(references)},
+        res={index: [re] for index, re in enumerate(predictions)},
+    )
+    for score in bleu_scores:
+        print("%s\t%0.4f" % (score, bleu_scores[score]))
+    bleu_scores_with_temporal = compute_scores(
+        gts={index: [gt] for index, gt in enumerate(references_with_temporal)},
+        res={index: [re] for index, re in enumerate(predictions_with_temporal)},
+    )
+    for score in bleu_scores_with_temporal:
+        print("temporal_%s\t%0.4f" % (score, bleu_scores_with_temporal[score]))
+    print("--------------------------------------------------------------")
+    return bleu_scores