diff --git a/src_stage1/data_arguments.py b/src_stage1/data_arguments.py
index 1e13504..9f63d37 100644
--- a/src_stage1/data_arguments.py
+++ b/src_stage1/data_arguments.py
@@ -54,9 +54,6 @@ class DataTrainingArguments:
     )
     chexbert_label: Optional[str] = field(default=None)
     debug_model: Optional[bool] = field(default=False)
-    max_context_length: Optional[int] = field(
-        default=256,
-    )
     max_tgt_length: Optional[int] = field(
         default=64,
     )
@@ -106,5 +103,9 @@ class DataTrainingArguments:
     )
     alpha: Optional[float] = field(default=3)
     beta: Optional[float] = field(default=3)
+    wo_op: Optional[int] = field(default=1)
+    wo_obs: Optional[int] = field(default=1)
+    wo_pro: Optional[int] = field(default=1)
+    wo_prr: Optional[int] = field(default=1)
     topk: Optional[int] = field(default=10)
     lambda_: Optional[float] = field(default=0.5)
diff --git a/src_stage1/dataset_ende.py b/src_stage1/dataset_ende.py
index bcf11d1..4107e99 100644
--- a/src_stage1/dataset_ende.py
+++ b/src_stage1/dataset_ende.py
@@ -1,14 +1,11 @@
 import os
 from torch.utils.data import Dataset
-from torchvision import transforms
 import torch
 from data_arguments import DataTrainingArguments
 from data_process_ende import process_examples
-from tokenizer import Tokenizer
 from tqdm import tqdm
 from PIL import Image
-from transformers import GPT2Tokenizer
 
 
 def load_images(root_path, image_paths):
diff --git a/src_stage1/extract_report.py b/src_stage1/extract_report.py
deleted file mode 100644
index 23e3b99..0000000
--- a/src_stage1/extract_report.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import json
-import sys
-from tokenizer import Tokenizer
-import os
-from tqdm import tqdm
-
-input_path = sys.argv[1]
-output_path = sys.argv[2]
-
-if not os.path.exists(output_path):
-    os.mkdir(output_path)
-
-print("**************************")
-print("input_path: ", input_path)
-print("output_path: ", output_path)
-print("**************************")
-
-clean_fn = Tokenizer.clean_report_mimic_cxr
-reports = {}
-with open(input_path, "r", encoding="utf-8") as f:
-    data = json.load(f)["train"]
-    for report in tqdm(data, desc="Extracting reports"):
-        reports[report['id'] + ".txt"] = clean_fn(report['report'])
-
-for idx in tqdm(reports, desc="Writing reports"):
-    with open(os.path.join(output_path, idx), "w", encoding="utf-8") as f:
-        report = reports[idx]
-        f.write(report + '\n')
-
-with open(os.path.join(output_path, "filenames.txt"), "w", encoding="utf-8") as f:
-    for idx in reports:
-        f.write(os.path.join(output_path, idx) + '\n')
diff --git a/src_stage1/graph_construction/pmi_observation_entity.py b/src_stage1/graph_construction/pmi_observation_entity.py
new file mode 100644
index 0000000..84f0ba3
--- /dev/null
+++ b/src_stage1/graph_construction/pmi_observation_entity.py
@@ -0,0 +1,104 @@
+from tqdm import tqdm
+import pandas as pd
+import argparse
+import json
+import math
+from collections import defaultdict
+import os
+from nltk.corpus import stopwords
+import numpy as np
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from tokenizer import Tokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset", type=str, required=True, help="the name of dataset")
+parser.add_argument("--chexbert_label", type=str, required=True)
+parser.add_argument("--output_dir", type=str, required=True)
+parser.add_argument("--pmi_threshold", type=float, required=True)
+parser.add_argument("--min_count", type=int, default=5, help="min_count")
+
+config = parser.parse_args()
+
+print("dataset: ", config.dataset)
+clean_fn = 
Tokenizer.clean_report_mimic_cxr +print(clean_fn) +id2observation, observation_category, _ = Tokenizer.load_tag2ids( + config.chexbert_label, need_header=True +) +print(len(id2observation), observation_category) + +sem_stat = json.load( + open( + os.path.join(config.output_dir, config.dataset, "sem_stat.json"), + "r", + encoding="utf-8", + ) +) +tem_stat = json.load( + open( + os.path.join(config.output_dir, config.dataset, "tem_stat.json"), + "r", + encoding="utf-8", + ) +) + +swords = stopwords.words("english") +sem_stat_keep = { + k: { + subk + for subk, subv in v.items() + if subk not in swords + and not subk.startswith("_") + and subv >= config.min_count + and subk not in tem_stat + } + for k, v in sem_stat.items() +} + + +with open( + os.path.join(config.output_dir, config.dataset, "id2entity.json"), + "r", + encoding="utf-8", +) as f: + id2entity = json.load(f) +observation_stat = defaultdict(int) +observation_ngram_stat = defaultdict(int) +observation_ngram_norm = defaultdict(int) + +sem_stat_all = sem_stat["ALL"] +p_y_norm = sum(sem_stat_all.values()) +sem_stat.pop("ALL") +p_y_x = {} +for obs in sem_stat: + p_y_x[obs] = { + k: v / sum(sem_stat[obs].values()) + for k, v in sem_stat[obs].items() + if k in sem_stat_keep[obs] + } + +p_y = {k: v / p_y_norm for k, v in sem_stat_all.items()} + +pmi = {} +k = 1 +for observation in p_y_x: + for ent in p_y_x[observation]: + if ent not in p_y: + continue + pmi_xy = math.log(p_y_x[observation][ent] / p_y[ent], 2) + if pmi_xy <= config.pmi_threshold: + continue + pmi[(observation, ent)] = pmi_xy + +new_pairs = {} +for key in pmi: + new_pairs["@".join(key)] = pmi[key] + +with open( + os.path.join(config.output_dir, config.dataset, "obs2sem.json"), + "w", + encoding="utf-8", +) as f: + json.dump(new_pairs, f, ensure_ascii=False, indent=4) diff --git a/src_stage1/graph_construction/pmi_observation_ngram.py b/src_stage1/graph_construction/pmi_observation_ngram.py deleted file mode 100644 index 5d29703..0000000 --- a/src_stage1/graph_construction/pmi_observation_ngram.py +++ /dev/null @@ -1,207 +0,0 @@ -from tqdm import tqdm -import pandas as pd -import argparse -import json -import math -from collections import defaultdict -import os -from nltk.corpus import stopwords -import numpy as np -import sys - -sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from tokenizer import Tokenizer - -parser = argparse.ArgumentParser() -parser.add_argument("--dataset", type=str, required=True, help="the name of dataset") -parser.add_argument("--chexbert_label", type=str, required=True, help="the output path") -parser.add_argument("--output_dir", type=str, required=True, help="the output path") -parser.add_argument( - "--pmi_threshold", type=float, required=True, help="the output path" -) -parser.add_argument("--min_count", type=int, default=5, help="min_count") -parser.add_argument("--max_count", type=int, default=1000, help="max_count") - -config = parser.parse_args() - -print("dataset: ", config.dataset) -clean_fn = Tokenizer.clean_report_mimic_cxr -print(clean_fn) -id2observation, observation_category, _ = Tokenizer.load_tag2ids( - config.chexbert_label, need_header=True -) -print(len(id2observation), observation_category) - -sem_stat = json.load( - open( - os.path.join(config.output_dir, config.dataset, "sem_stat.json"), - "r", - encoding="utf-8", - ) -) -tem_stat = json.load( - open( - os.path.join(config.output_dir, config.dataset, "tem_stat.json"), - "r", - encoding="utf-8", - ) -) - -swords = stopwords.words("english") -sem_stat_keep = { - 
k: { - subk - for subk, subv in v.items() - if subk not in swords - and not subk.startswith("_") - and subv >= config.min_count - and subk not in tem_stat - } - for k, v in sem_stat.items() -} - - -with open( - os.path.join(config.output_dir, config.dataset, "id2entity.json"), - "r", - encoding="utf-8", -) as f: - id2entity = json.load(f) -observation_stat = defaultdict(int) -observation_ngram_stat = defaultdict(int) -observation_ngram_norm = defaultdict(int) - -sem_stat_all = sem_stat["ALL"] -p_y_norm = sum(sem_stat_all.values()) -sem_stat.pop("ALL") -p_y_x = {} -for obs in sem_stat: - p_y_x[obs] = { - k: v / sum(sem_stat[obs].values()) - for k, v in sem_stat[obs].items() - if k in sem_stat_keep[obs] - # if k in sem_stat_keep["ALL"] - # if v / sum(sem_stat[obs].values()) > 5e-4 - } - -# p_x_y = {} -# for obs in sem_stat: -# p_x_y[obs] = { -# k: v / sem_stat_all[k] -# for k, v in sem_stat[obs].items() -# if k in sem_stat_keep[obs] -# } - -p_y = {k: v / p_y_norm for k, v in sem_stat_all.items()} - -pmi = {} -k = 1 -rich_entity = defaultdict(list) -max_pmi = defaultdict(float) -min_pmi = {} -avg_pmi = defaultdict(list) -for observation in p_y_x: - for ent in p_y_x[observation]: - if ent not in p_y: - continue - pmi_xy = math.log(p_y_x[observation][ent] / p_y[ent], 2) - if pmi_xy <= config.pmi_threshold: - continue - pmi[(observation, ent)] = pmi_xy - max_pmi[observation] = max(max_pmi[observation], pmi_xy) - if observation not in min_pmi: - min_pmi[observation] = pmi_xy - else: - min_pmi[observation] = min(min_pmi[observation], pmi_xy) - avg_pmi[observation].append(pmi_xy) -# p_y_x_keep = defaultdict(list) -# for observation in p_y_x: -# proba = sorted(p_y_x[observation].items(), key=lambda x: x[1], reverse=True) -# acc_proba = 0 -# for ent, p in proba: -# if acc_proba < 0.2: -# acc_proba += p -# print(observation, ent, p) -# else: -# p_y_x_keep[observation].append(ent) - -# rich_entity = {k for k, v in rich_entity.items() if len(set([z for z in v if "No Finding" not in z])) > 2} - -# except Exception as err: -# print("Error", err, xy) -# print(p_y_x) -gloabl_pmi, global_count = 0, 0 -for k, v in avg_pmi.items(): - gloabl_pmi += sum(v) - global_count += len(v) -gloabl_pmi = gloabl_pmi / global_count -median_pmi = {k: np.median(v) for k, v in avg_pmi.items()} -avg_pmi = {k: sum(v) / len(v) for k, v in avg_pmi.items()} -for observation in max_pmi: - print( - "Spatial Max/Min/Avg/Median PMI", - observation, - round(max_pmi[observation], 3), - round(min_pmi[observation], 3), - round(avg_pmi[observation], 3), - round(median_pmi[observation], 3), - ) -# new_pmi = {} -# for xy in pmi: -# observation, ent = xy -# # if pmi[xy] >= 1 and pmi[xy] >= 0.75 * max_pmi[observation]: -# # if pmi[xy] >= gloabl_pmi and pmi[xy] >= 0.5 * max_pmi[observation]: -# # if pmi[xy] >= median_pmi[observation] and pmi[xy] >= 1: -# # if pmi[xy] >= 1: -# # if pmi[xy] >= 0.75 * max_pmi[observation]: -# # if pmi[xy] >= avg_pmi[observation]: -# if pmi[xy] >= 0.5: -# # and ent in p_y_x_keep[observation]: -# # if pmi[xy] >= 0.75 * max_pmi[observation]: -# # if pmi[xy] >= median_pmi[observation]: -# new_pmi[xy] = pmi[xy] -# pmi = new_pmi - -# sorted_pmi = sorted(pmi.items(), key=lambda x: sem_stat[x[0][0]][x[0][1]], reverse=True) -# sorted_pmi = sorted(pmi.items(), key=lambda x: x[1], reverse=True) -# pmi = {} -# tmp = defaultdict(list) -# saved_entity = {"Positive": set(), "Negative": set()} -# for key, value in sorted_pmi: -# observation, ent = key -# # if "No Finding" in observation: -# # continue -# if ent not in 
saved_entity[observation.split(":")[1]]: -# tmp[observation].append((ent, value)) -# saved_entity[observation.split(":")[1]].add(ent) - -# # for key, value in sorted_pmi: -# # observation, ent = key -# # if "No Finding" not in observation: -# # continue -# # query = observation.split(":")[1] -# # if query == "Positive": -# # query = "Negative" -# # else: -# # query = "Positive" -# # if ent not in saved_entity[query]: -# # tmp[observation].append((ent, value)) -# # saved_entity[query].add(ent) - - -# for key in tmp: -# for ent, value in tmp[key]: -# pmi[(key, ent)] = value - -new_pairs = {} -for key in pmi: - # if key[1] in rich_entity: - # continue - new_pairs["@".join(key)] = pmi[key] - -with open( - os.path.join(config.output_dir, config.dataset, "obs2sem.json"), - "w", - encoding="utf-8", -) as f: - json.dump(new_pairs, f, ensure_ascii=False, indent=4) diff --git a/src_stage1/graph_construction/pmi_progression_ngram.py b/src_stage1/graph_construction/pmi_progression_entity.py similarity index 52% rename from src_stage1/graph_construction/pmi_progression_ngram.py rename to src_stage1/graph_construction/pmi_progression_entity.py index d0707eb..d15b124 100644 --- a/src_stage1/graph_construction/pmi_progression_ngram.py +++ b/src_stage1/graph_construction/pmi_progression_entity.py @@ -14,16 +14,13 @@ parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, required=True, help="the name of dataset") -parser.add_argument("--chexbert_label", type=str, required=True, help="the output path") -parser.add_argument("--output_dir", type=str, required=True, help="the output path") +parser.add_argument("--chexbert_label", type=str, required=True) +parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--min_count", type=int, default=5, help="min_count") -parser.add_argument( - "--pmi_threshold", type=float, required=True, help="the output path" -) -parser.add_argument( - "--temporal_id_dir", type=str, required=True, help="the output path" -) +parser.add_argument("--pmi_threshold", type=float, required=True) +parser.add_argument("--temporal_id_dir", type=str, required=True) config = parser.parse_args() + print("dataset: ", config.dataset) clean_fn = Tokenizer.clean_report_mimic_cxr print(clean_fn) @@ -66,18 +63,12 @@ def tag2obs(x, y): for idx in temporal_id: if len(temporal_id[idx]["predicate"]) == 0: continue - # observations = [ - # a for a, b in zip(observation_category, id2observation[idx]) if b == 1 - # ] observations = [ tag2obs(x, y) for x, y in zip(id2observation[idx], observation_category) if x != 2 ] progressions = set(temporal_id[idx]["predicate"]) - # for observation in temporal_id[idx]["predicate"]: - # status = temporal_id[idx]["predicate"][observation] - # progressions.update(status) observation2progression = [] for observation in observations: for progression in progressions: @@ -103,21 +94,16 @@ def tag2obs(x, y): continue for pro in progressions: progression_ngram_stat[(pro, ent)] += 1 - # p_xy_norm += 1 progression_ngram_norm[pro] += 1 p_y_norm = sum(sem_stat["ALL"].values()) swords = stopwords.words("english") p_xy = { - x[0]: x[1] / progression_ngram_norm[x[0][0]] - for x in progression_ngram_stat.items() + x[0]: x[1] / progression_ngram_norm[x[0][0]] for x in progression_ngram_stat.items() } p_x = {x[0]: x[1] / p_x_norm for x in progression_stat.items()} p_y = {x[0]: x[1] / p_y_norm for x in tem_stat.items()} pmi = {} k = 1 -max_pmi = defaultdict(float) -avg_pmi = defaultdict(list) -min_pmi = {} for xy in p_xy: observation, ent = xy 
if "No Finding" in observation or "Support Device" in observation: @@ -127,45 +113,9 @@ def tag2obs(x, y): if pmi_xy <= config.pmi_threshold: continue pmi[xy] = pmi_xy - max_pmi[observation] = max(max_pmi[observation], pmi_xy) - avg_pmi[observation].append(pmi_xy) - if observation not in min_pmi: - min_pmi[observation] = pmi_xy - else: - min_pmi[observation] = min(min_pmi[observation], pmi_xy) except Exception as err: print("Error", err, xy) -gloabl_pmi, global_count = 0, 0 -for k, v in avg_pmi.items(): - gloabl_pmi += sum(v) - global_count += len(v) -gloabl_pmi = gloabl_pmi / global_count -median_pmi = {k: np.median(v) for k, v in avg_pmi.items()} -avg_pmi = {k: sum(v) / len(v) for k, v in avg_pmi.items()} -for observation in max_pmi: - print( - "Temporal Max/Min/Avg/Median PMI", - observation, - round(max_pmi[observation], 3), - round(min_pmi[observation], 3), - round(avg_pmi[observation], 3), - round(median_pmi[observation], 3), - ) -# new_pmi = {} -# for xy in pmi: -# observation, ent = xy -# # if pmi[xy] >= 1 and pmi[xy] >= 0.75 * max_pmi[observation]: -# # if pmi[xy] >= gloabl_pmi and pmi[xy] >= 0.5 * max_pmi[observation]: -# # if pmi[xy] >= median_pmi[observation] and pmi[xy] >= 1: -# # if pmi[xy] >= avg_pmi[observation]: -# # if pmi[xy] >= 1: -# # if pmi[xy] >= 1: -# if pmi[xy] >= 0.75 * max_pmi[observation]: -# # if pmi[xy] >= avg_pmi[observation]: -# new_pmi[xy] = pmi[xy] -# pmi = new_pmi - new_pairs = {} for key in pmi: new_pairs["@".join(key)] = pmi[key] @@ -187,55 +137,15 @@ def tag2obs(x, y): obs_pmi = defaultdict(list) new_pmi = defaultdict(list) -# max_obs_val = defaultdict(float) -# for key, val in obs2sem.items(): -# obs, entity = key.split("@") -# max_obs_val[obs] = max(max_obs_val[obs], val) - -# obs2sem = sorted(obs2sem.items(), key=lambda x: x[1], reverse=True) -# saved_entity = {"Positive":defaultdict(int), "Negative": defaultdict(int)} -# entity_stat = {} -# for key, val in obs2sem.items(): -# obs, entity = key.split("@") -# obs_wo_status = obs.split(":")[0] -# if obs_wo_status not in entity_stat: -# entity_stat[obs_wo_status] = defaultdict(int) -# entity_stat[obs_wo_status][entity] += 1 -# # status = "Negative" if "Negative" in obs else "Positive" -# # if saved_entity[status][entity] >= 2: -# # continue for key, val in obs2sem.items(): obs, entity = key.split("@") - # obs_wo_status = obs.split(":")[0] - # if entity_stat[obs_wo_status][entity] >= 2: - # continue obs_pmi[obs].append((entity, val)) - # saved_entity[status][entity] += 1 for key, val in new_pairs.items(): pro, entity = key.split("@") new_pmi[pro].append((entity, val)) -# new_pmi = {k: sorted(v, key=lambda x: x[1], reverse=True) for k, v in new_pmi.items()} obs_pmi.update(new_pmi) -# for obs in obs_pmi: -# if "Positive" in obs: -# pos_entity = {} -# neg_entity = {} -# neg_obs = obs.replace("Positive", "Negative") -# pos_val = {x[0]: x[1] for x in obs_pmi[obs]} -# neg_val = {x[0]: x[1] for x in obs_pmi[neg_obs]} -# for entity in pos_val: -# if entity not in neg_val or pos_val[entity] > neg_val[entity]: -# pos_entity[entity] = pos_val[entity] -# else: -# neg_entity[entity] = neg_val[entity] -# neg_entity.update( -# {k: v for k, v in neg_val.items() if k not in pos_entity}) -# obs_pmi[obs] = [(k, v) for k, v in pos_entity.items()] -# obs_pmi[neg_obs] = [(k, v) for k, v in neg_entity.items()] - - obs_pmi = {k: sorted(v, key=lambda x: x[1], reverse=True) for k, v in obs_pmi.items()} obs_pmi = {k: [x[0] for x in v] for k, v in obs_pmi.items()} diff --git a/src_stage1/graph_construction/pmi_ngram.py 
b/src_stage1/graph_construction/prepare_stat.py similarity index 72% rename from src_stage1/graph_construction/pmi_ngram.py rename to src_stage1/graph_construction/prepare_stat.py index b11b6b4..0e48f4a 100644 --- a/src_stage1/graph_construction/pmi_ngram.py +++ b/src_stage1/graph_construction/prepare_stat.py @@ -4,7 +4,7 @@ import argparse import os from constants import TEM_KEYWORDS -from nltk.corpus import stopwords, wordnet +from nltk.corpus import stopwords import sys sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -13,8 +13,8 @@ parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, required=True, help="the name of dataset") parser.add_argument("--output_dir", type=str, required=True, help="the output path") -parser.add_argument("--min_count", type=int, default=5, help="min_count") -parser.add_argument("--chexbert_label", type=str, required=True, help="the output path") +parser.add_argument("--chexbert_label", type=str, required=True) +parser.add_argument("--radgraph_dir", type=str, required=True) def tag2obs(x, y): @@ -45,7 +45,6 @@ def filter_fn(s): config = parser.parse_args() dataset = config.dataset - min_count = config.min_count min_frequent = 3 if "mimic_abn" in dataset else 10 print("dataset: ", dataset) @@ -73,11 +72,7 @@ def filter_fn(s): key_tuple = defaultdict(int) window = 0 if not os.path.exists(sem_path): - with open( - "/home/wenjun/repo/report_gen/physionet.org/files/radgraph/1.0.0/MIMIC-CXR_graphs.json", - "r", - encoding="utf-8", - ) as f: + with open(config.radgraph_dir, "r", encoding="utf-8") as f: radgraph = json.load(f) for key in radgraph: if key in collect_ids: @@ -87,7 +82,6 @@ def filter_fn(s): continue for relation in entity["relations"]: if "modify" in relation or "located_at" in relation: - # if "modify" in relation: k1 = [ z.strip().lower() for z in entity["tokens"].split() @@ -114,7 +108,7 @@ def filter_fn(s): f.write(a + "\n") for b in key_tuple: - f.write("-modify-".join(b) + "," + str(key_tuple[b]) + "\n") + f.write("-rel-".join(b) + "," + str(key_tuple[b]) + "\n") else: with open(sem_path, "r", encoding="utf-8") as f: window = int(f.readline().strip()) @@ -122,20 +116,16 @@ def filter_fn(s): line = line.strip() if len(line) == 0: continue - if "-modify-" in line: + if "-rel-" in line: line = line.split(",") - key_tuple[tuple(line[0].split("-modify-"))] = int(line[1]) + key_tuple[tuple(line[0].split("-rel-"))] = int(line[1]) else: spatial_keywords.add(line) spatial_keywords = set() for k1, k2 in key_tuple: - if key_tuple[(k1, k2)] <= (50 if "mimic_abn" in dataset else 50): - # if key_tuple[(k1, k2)] <= 50: + if key_tuple[(k1, k2)] <= 50: continue - # if k1 not in TEM_KEYWORDS and key_tuple[(k1, k2)] >= (50 if "mimic_abn" in dataset else 200): - # if key_tuple[(k1, k2)] >= (50 if "mimic_abn" in dataset else 200): - # spatial_keywords.add(k1) if k1 not in TEM_KEYWORDS: spatial_keywords.add(k1) if k2 not in TEM_KEYWORDS: @@ -200,9 +190,7 @@ def filter_fn(s): observation = [ o for o in sentences[pos]["observation"] if o in observations ] - # if len(observation) == 0: - if True: - observation.append(no_finding) + observation.append(no_finding) for obs in observations: if obs not in sem_stat_all: sem_stat_all[obs] = defaultdict(int) @@ -212,50 +200,6 @@ def filter_fn(s): else: sem_stat_all[obs]["_" + token] += 1 - # max_stat = defaultdict(int) - # for obs in sem_stat_all: - # # if "No Finding" in obs: - # # continue - # flip_obs = ( - # obs.replace("Positive", "Negative") - # if "Positive" in obs - # else 
obs.replace("Negative", "Positive") - # ) - # stat = {} - # for k, v in sem_stat_all[obs].items(): - # if k.startswith("_"): - # continue - # count = v + sem_stat_all[flip_obs].get(k, 0) - # max_stat[k] = max(max_stat[k], count) - # new_sem_stat_all = {} - # for obs in sem_stat_all: - # if "No Finding" in obs: - # continue - # flip_obs = ( - # obs.replace("Positive", "Negative") - # if "Positive" in obs - # else obs.replace("Negative", "Positive") - # ) - # stat = {} - # for k, v in sem_stat_all[obs].items(): - # count = v + sem_stat_all[flip_obs].get(k, 0) - # # count2 = sem_stat_all["No Finding:Negative"].get(k, 0) + sem_stat_all[ - # # "No Finding:Positive" - # # ].get(k, 0) - # # if not k.startswith("_") and count <= sem_stat[k] // len(observation_category) * 2: - # # if not k.startswith("_") and (count <= count2 * 0.5): - # if not k.startswith("_") and count <= max_stat[k] * 0.5: - # # print(obs, k, count, max_stat[k] * 0.25, count2 * 0.75) - # # if not k.startswith("_") and count <= count2: - # k = "_" + k - # stat[k] = v + sem_stat_all[obs].get(k, 0) - # if k not in stat: - # stat[k] = v - # new_sem_stat_all[obs] = stat - # new_sem_stat_all["No Finding:Positive"] = sem_stat_all["No Finding:Positive"] - # new_sem_stat_all["No Finding:Negative"] = sem_stat_all["No Finding:Negative"] - # sem_stat_all = new_sem_stat_all - sem_stat_all["ALL"] = sem_stat sem_stat_all = { k: { diff --git a/src_stage1/graph_construction/run_mimic_abn.sh b/src_stage1/graph_construction/run_mimic_abn.sh index a2b2471..909f410 100644 --- a/src_stage1/graph_construction/run_mimic_abn.sh +++ b/src_stage1/graph_construction/run_mimic_abn.sh @@ -1,37 +1,36 @@ -version="20230901" +version="2023xxxx" dataset="mimic_abn" +radgraph_dir="radgraph/1.0.0/MIMIC-CXR_graphs.json" min_count=200 -max_count=5000000 -chexbert_label="../CheXbert/src/data/$dataset/id2tag_ref_64.csv" +chexbert_label="./CheXbert/$dataset/id2tag.csv" output_dir="data/$version/" temporal_id_dir="../$dataset/temporal_ids.json" mkdir -p $output_dir$dataset echo "================================================================" -echo "Step1: running python src_stage1/graph_construction/pmi_ngram.py" +echo "Step1: running python src_stage1/graph_construction/prepare_stat.py" echo "================================================================" -python src_stage1/graph_construction/pmi_ngram.py \ +python src_stage1/graph_construction/prepare_stat.py \ --dataset $dataset \ --chexbert_label $chexbert_label \ --output_dir $output_dir \ - --min_count $min_count + --radgraph_dir $radgraph_dir echo "================================================================" -echo "Step 2: running python src_stage1/graph_construction/pmi_observation_ngram.py" +echo "Step 2: running python src_stage1/graph_construction/pmi_observation_entity.py" echo "================================================================" -python src_stage1/graph_construction/pmi_observation_ngram.py \ +python src_stage1/graph_construction/pmi_observation_entity.py \ --dataset $dataset \ --chexbert_label $chexbert_label \ --output_dir $output_dir \ --pmi_threshold 0 \ - --min_count $min_count \ - --max_count $max_count + --min_count $min_count echo "================================================================" -echo "Step 3: running python src_stage1/graph_construction/pmi_progression_ngram.py" +echo "Step 3: running python src_stage1/graph_construction/pmi_progression_entity.py" echo "================================================================" -python 
src_stage1/graph_construction/pmi_progression_ngram.py \ +python src_stage1/graph_construction/pmi_progression_entity.py \ --dataset $dataset \ --chexbert_label $chexbert_label \ --output_dir $output_dir \ diff --git a/src_stage1/models/__pycache__/activations.cpython-39.pyc b/src_stage1/models/__pycache__/activations.cpython-39.pyc deleted file mode 100644 index 13ef079..0000000 Binary files a/src_stage1/models/__pycache__/activations.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/generation_utils.cpython-39.pyc b/src_stage1/models/__pycache__/generation_utils.cpython-39.pyc deleted file mode 100644 index 88765d4..0000000 Binary files a/src_stage1/models/__pycache__/generation_utils.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/ive.cpython-39.pyc b/src_stage1/models/__pycache__/ive.cpython-39.pyc deleted file mode 100644 index 98d3854..0000000 Binary files a/src_stage1/models/__pycache__/ive.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/layers.cpython-39.pyc b/src_stage1/models/__pycache__/layers.cpython-39.pyc deleted file mode 100644 index 646b64e..0000000 Binary files a/src_stage1/models/__pycache__/layers.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/masked_resnet.cpython-39.pyc b/src_stage1/models/__pycache__/masked_resnet.cpython-39.pyc deleted file mode 100644 index c7d1445..0000000 Binary files a/src_stage1/models/__pycache__/masked_resnet.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart.cpython-39.pyc deleted file mode 100644 index 7f19700..0000000 Binary files a/src_stage1/models/__pycache__/modeling_bart.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart_drl.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart_drl.cpython-39.pyc deleted file mode 100644 index ac2ea48..0000000 Binary files a/src_stage1/models/__pycache__/modeling_bart_drl.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart_drl_base.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart_drl_base.cpython-39.pyc deleted file mode 100644 index 1c703d3..0000000 Binary files a/src_stage1/models/__pycache__/modeling_bart_drl_base.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart_drl_outlined.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart_drl_outlined.cpython-39.pyc deleted file mode 100644 index ba75e22..0000000 Binary files a/src_stage1/models/__pycache__/modeling_bart_drl_outlined.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart_outline_aware.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart_outline_aware.cpython-39.pyc deleted file mode 100644 index 95bca87..0000000 Binary files a/src_stage1/models/__pycache__/modeling_bart_outline_aware.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart_outlined.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart_outlined.cpython-39.pyc deleted file mode 100644 index 4821581..0000000 Binary files a/src_stage1/models/__pycache__/modeling_bart_outlined.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_bart_variant.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_bart_variant.cpython-39.pyc deleted file mode 100644 index d4ae478..0000000 Binary files 
a/src_stage1/models/__pycache__/modeling_bart_variant.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_beit.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_beit.cpython-39.pyc deleted file mode 100644 index f9ea321..0000000 Binary files a/src_stage1/models/__pycache__/modeling_beit.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_gpt2.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_gpt2.cpython-39.pyc deleted file mode 100644 index 9e2b1b6..0000000 Binary files a/src_stage1/models/__pycache__/modeling_gpt2.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_guided_bart.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_guided_bart.cpython-39.pyc deleted file mode 100644 index eca5107..0000000 Binary files a/src_stage1/models/__pycache__/modeling_guided_bart.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/modeling_vit.cpython-39.pyc b/src_stage1/models/__pycache__/modeling_vit.cpython-39.pyc deleted file mode 100644 index 5fc88e2..0000000 Binary files a/src_stage1/models/__pycache__/modeling_vit.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/vae.cpython-39.pyc b/src_stage1/models/__pycache__/vae.cpython-39.pyc deleted file mode 100644 index 4cfe76f..0000000 Binary files a/src_stage1/models/__pycache__/vae.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/vision_encoder_decoder.cpython-39.pyc b/src_stage1/models/__pycache__/vision_encoder_decoder.cpython-39.pyc deleted file mode 100644 index 1e12d07..0000000 Binary files a/src_stage1/models/__pycache__/vision_encoder_decoder.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/__pycache__/von_mises_fisher.cpython-39.pyc b/src_stage1/models/__pycache__/von_mises_fisher.cpython-39.pyc deleted file mode 100644 index fca8f1e..0000000 Binary files a/src_stage1/models/__pycache__/von_mises_fisher.cpython-39.pyc and /dev/null differ diff --git a/src_stage1/models/modeling_gpt2.py b/src_stage1/models/modeling_gpt2.py deleted file mode 100644 index 77c2a7f..0000000 --- a/src_stage1/models/modeling_gpt2.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import Optional, Tuple, Dict, Any - -import torch -import torch.nn as nn -from transformers import GPT2LMHeadModel -import torchvision.models as models -from transformers import PreTrainedModel, ViTModel -from transformers.modeling_outputs import ModelOutput -from dataclasses import dataclass - - -@dataclass -class VisualOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None - observation_logits: torch.FloatTensor = None - progression_logits: torch.FloatTensor = None - - -class VisualExtractor(nn.Module): - def __init__(self, visual_extractor): - super(VisualExtractor, self).__init__() - model = getattr(models, visual_extractor)(pretrained=True) - # num_fts = model.fc.in_features - # model.fc = nn.Linear(num_fts, 512, bias=False) - # medclip_state_dict = torch.load( - # "../CLIP/pretrained/medclip-resnet/clip_resnet50.bin" - # ) - # model.load_state_dict(medclip_state_dict.state_dict()) - modules = list(model.children()) - self.model = nn.Sequential(*modules[:-2]) - - def forward(self, images): - patch_feats = self.model(images) - batch_size, feat_size, _, _ = patch_feats.shape - patch_feats = patch_feats.reshape( - batch_size, - feat_size, - -1, - ).permute(0, 2, 1) - return patch_feats - - -class VisualEncoder(PreTrainedModel): - def __init__( - self, - 
config, - visual_extractor, - ): - super().__init__(config) - visual_extractor_name, d_visual = visual_extractor - # self.visual_extractor = VisualExtractor(visual_extractor_name) - self.visual_extractor = ViTModel.from_pretrained( - "google/vit-base-patch16-224-in21k" - ) - self.observation_cls = nn.Linear( - self.visual_extractor.config.hidden_size, config.num_observation - ) - - def encode_image(self, input_pixels, observations=None): - observation_hidden_states = None - attention_mask = None - visual_outputs = self.visual_extractor(input_pixels) - pooler_output = visual_outputs.pooler_output - # observation_attn_weight = torch.softmax( - # self.observation_attn(image_hidden_states), dim=1 - # ) - # observation_hidden_states = observation_attn_weight.permute(0, 2, 1).bmm( - # image_hidden_states - # ) - # observation_logits = self.observation_cls(image_hidden_states.mean(dim=1)) - observation_logits = self.observation_cls(pooler_output) - # observation_hidden_states = torch.stack( - # [ - # self.observation_transformations[i](image_hidden_states) - # for i in range(self.config.num_observation) - # ], - # dim=1, - # ).mean(dim=2) - # image_hidden_states = self.feature_space_transformation_nn(image_hidden_states) - - # if observations is not None: - # observation_mask = observations - # else: - # observation_mask = (observation_logits > 0).float() - # attention_mask = torch.cat( - # (torch.ones_like(image_hidden_states[..., 0]), observation_mask), dim=-1 - # ) - # image_hidden_states = torch.cat( - # (image_hidden_states, observation_hidden_states), dim=1 - # ) - return observation_hidden_states, attention_mask, observation_logits - - def forward( - self, - input_pixels: torch.FloatTensor = None, - input_temporal_pixels: torch.FloatTensor = None, - temporal_mask: torch.FloatTensor = None, - observations: Optional[torch.FloatTensor] = None, - progressions: Optional[torch.FloatTensor] = None, - ): - observation_logits = None - progression_logits = None - ( - obs_hidden_states, - _, - observation_logits, - ) = self.encode_image(input_pixels, observations) - # ( - # prior_obs_hidden_states, - # _, - # _, - # ) = self.encode_image(input_temporal_pixels, observations) - - # progression_hidden_states = torch.cat( - # (obs_hidden_states, prior_obs_hidden_states), dim=-1 - # )[:, :-2] - # progression_logits = self.progression_cls(progression_hidden_states) - - loss = None - if observations is not None: - weight = torch.ones_like(observations) + self.config.alpha * observations - loss_fct = nn.BCEWithLogitsLoss(weight=weight.view(-1)) - loss = loss_fct( - observation_logits.view(-1), - observations.view(-1), - ) - - if progressions is not None and False: - loss_fct = nn.CrossEntropyLoss() - progression_loss = loss_fct( - progression_logits.view(-1, self.config.num_progression), - progressions.view(-1), - ) - loss = loss + progression_loss - - return VisualOutput( - loss=loss, - observation_logits=observation_logits, - progression_logits=progression_logits, - ) diff --git a/src_stage1/models/modeling_gpt2_20230602.py b/src_stage1/models/modeling_gpt2_20230602.py deleted file mode 100644 index 4f3c112..0000000 --- a/src_stage1/models/modeling_gpt2_20230602.py +++ /dev/null @@ -1,154 +0,0 @@ -from typing import Optional, Tuple, Dict, Any - -import torch -import torch.nn as nn -from transformers import GPT2LMHeadModel -import torchvision.models as models -from transformers import PreTrainedModel -from transformers.modeling_outputs import ModelOutput -from dataclasses import dataclass - - -@dataclass 
-class VisualOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None - observation_logits: torch.FloatTensor = None - progression_logits: torch.FloatTensor = None - - -class VisualExtractor(nn.Module): - def __init__(self, visual_extractor): - super(VisualExtractor, self).__init__() - model = getattr(models, visual_extractor)(pretrained=True) - # num_fts = model.fc.in_features - # model.fc = nn.Linear(num_fts, 512, bias=False) - # medclip_state_dict = torch.load( - # "../CLIP/pretrained/medclip-resnet/clip_resnet50.bin" - # ) - # model.load_state_dict(medclip_state_dict.state_dict()) - modules = list(model.children()) - self.model = nn.Sequential(*modules[:-2]) - - def forward(self, images): - patch_feats = self.model(images) - batch_size, feat_size, _, _ = patch_feats.shape - patch_feats = patch_feats.reshape( - batch_size, - feat_size, - -1, - ).permute(0, 2, 1) - return patch_feats - - -class VisualEncoder(PreTrainedModel): - def __init__( - self, - config, - visual_extractor, - ): - super().__init__(config) - visual_extractor_name, d_visual = visual_extractor - self.visual_extractor = VisualExtractor(visual_extractor_name) - - self.feature_space_transformation_nn = nn.Sequential( - nn.Linear(in_features=d_visual, out_features=self.config.n_embd), - nn.ReLU(), - nn.Dropout(self.config.resid_pdrop), - ) - self.observation_transformations = nn.ModuleList( - [ - nn.Sequential( - nn.Linear(d_visual, self.config.n_embd), - nn.ReLU(), - nn.Dropout(self.config.resid_pdrop), - ) - for _ in range(self.config.num_observation) - ] - ) - # self.observation_attn = nn.Linear(self.config.n_embd, self.config.num_observation) - self.observation_cls = nn.Linear(d_visual, self.config.num_observation) - self.progression_cls = nn.Sequential( - nn.Linear(self.config.n_embd * 2, self.config.n_embd), - nn.ReLU(), - nn.Dropout(self.config.resid_pdrop), - nn.Linear(self.config.n_embd, self.config.num_progression), - ) - - def encode_image(self, input_pixels, observations=None): - image_hidden_states = self.visual_extractor(input_pixels) - # observation_attn_weight = torch.softmax( - # self.observation_attn(image_hidden_states), dim=1 - # ) - # observation_hidden_states = observation_attn_weight.permute(0, 2, 1).bmm( - # image_hidden_states - # ) - observation_logits = self.observation_cls(image_hidden_states.mean(dim=1)) - observation_hidden_states = torch.stack( - [ - self.observation_transformations[i](image_hidden_states) - for i in range(self.config.num_observation) - ], - dim=1, - ).mean(dim=2) - image_hidden_states = self.feature_space_transformation_nn(image_hidden_states) - - if observations is not None: - observation_mask = observations - else: - observation_mask = (observation_logits > 0).float() - attention_mask = torch.cat( - (torch.ones_like(image_hidden_states[..., 0]), observation_mask), dim=-1 - ) - image_hidden_states = torch.cat( - (image_hidden_states, observation_hidden_states), dim=1 - ) - return observation_hidden_states, attention_mask, observation_logits - - def forward( - self, - input_pixels: torch.FloatTensor = None, - input_temporal_pixels: torch.FloatTensor = None, - temporal_mask: torch.FloatTensor = None, - observations: Optional[torch.FloatTensor] = None, - progressions: Optional[torch.FloatTensor] = None, - ): - observation_logits = None - progression_logits = None - ( - obs_hidden_states, - _, - observation_logits, - ) = self.encode_image(input_pixels, observations) - ( - prior_obs_hidden_states, - _, - _, - ) = self.encode_image(input_temporal_pixels, observations) - - 
# progression_hidden_states = torch.cat( - # (obs_hidden_states, prior_obs_hidden_states), dim=-1 - # )[:, :-2] - # progression_logits = self.progression_cls(progression_hidden_states) - - loss = None - if observations is not None: - weight = torch.ones_like(observations) + self.config.alpha * observations - loss_fct = nn.BCEWithLogitsLoss(weight=weight.view(-1)) - loss = loss_fct( - observation_logits.view(-1), - observations.view(-1), - ) - - if progressions is not None and False: - loss_fct = nn.CrossEntropyLoss() - progression_loss = loss_fct( - progression_logits.view(-1, self.config.num_progression), - progressions.view(-1), - ) - loss = loss + progression_loss - - return VisualOutput( - loss=loss, - observation_logits=observation_logits, - progression_logits=progression_logits, - ) diff --git a/src_stage1/models/modeling_vit copy.py b/src_stage1/models/modeling_vit copy.py deleted file mode 100644 index aee9047..0000000 --- a/src_stage1/models/modeling_vit copy.py +++ /dev/null @@ -1,176 +0,0 @@ -from typing import Optional, Tuple, Dict, Any - -import torch -import torch.nn as nn -from transformers import GPT2LMHeadModel -import torchvision.models as models -from transformers import PreTrainedModel, ViTModel -from transformers.modeling_outputs import ModelOutput -from dataclasses import dataclass -from transformers import ViTConfig - - -@dataclass -class VisualOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None - observation_loss: Optional[torch.FloatTensor] = None - progression_loss: Optional[torch.FloatTensor] = None - observation_det_logits: torch.FloatTensor = None - observation_cls_logits: torch.FloatTensor = None - progression_logits: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - last_hidden_state: torch.FloatTensor = None - - -class VisualEncoder(PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.visual_config = ViTConfig.from_pretrained( - config.pretrained_visual_extractor - ) - self.observation_det = nn.Linear( - self.config.hidden_size * 2, config.num_observation - 1 - ) - self.observation_cls = nn.Linear( - self.config.hidden_size * 2, config.num_observation - ) - self.progression_cls = nn.Linear( - self.config.hidden_size * 2, config.num_progression - ) - self.post_init() - self.visual_extractor = ViTModel.from_pretrained( - config.pretrained_visual_extractor - ) - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def encode_image(self, input_pixels, require_logits=True): - observation_det_logits = None - observation_cls_logits = None - visual_outputs = self.visual_extractor(input_pixels) - pooler_output = visual_outputs.pooler_output - last_hidden_state = visual_outputs.last_hidden_state - if require_logits: - observation_det_logits = self.observation_det(pooler_output) - observation_cls_logits = self.observation_cls(pooler_output) - return ( - pooler_output, - last_hidden_state, - observation_det_logits, - observation_cls_logits, - ) - - def forward( - self, - input_pixels: torch.FloatTensor = None, - input_temporal_pixels: torch.FloatTensor = None, - temporal_mask: torch.FloatTensor = None, - observations: Optional[torch.FloatTensor] = None, - progressions: 
Optional[torch.FloatTensor] = None, - entity_labels: Optional[torch.FloatTensor] = None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - require_logits=True, - ): - progression_logits = None - ( - pooler_output, - last_hidden_state, - observation_det_logits, - observation_cls_logits, - ) = self.encode_image( - input_pixels, - # require_logits=require_logits, - require_logits=False, - ) - if temporal_mask.sum() > 0: - ( - prior_pooler_output, - prior_last_hidden_state, - _, - _, - ) = self.encode_image( - input_temporal_pixels, - require_logits=False, - ) - else: - prior_pooler_output = torch.zeros_like(pooler_output) - prior_last_hidden_state = torch.zeros_like(last_hidden_state) - prior_pooler_output = prior_pooler_output * temporal_mask.unsqueeze(-1) - prior_last_hidden_state = prior_last_hidden_state * temporal_mask.unsqueeze( - -1 - ).unsqueeze(-1) - - pooler_output = torch.cat((pooler_output, prior_pooler_output), dim=-1) - observation_det_logits = self.observation_det(pooler_output) - observation_cls_logits = self.observation_cls(pooler_output) - progression_logits = self.progression_cls(pooler_output) - # if require_logits: - # progression_pooler_output = torch.cat( - # (pooler_output, prior_pooler_output), dim=-1 - # ) - # progression_logits = self.progression_cls(progression_pooler_output) - - loss = None - observation_loss = None - progression_loss = None - if observations is not None: - observations_det = (observations != 2).float() - observations_cls = (observations == 1).float() - weight = ( - torch.ones_like(observations_det[:, :-1]) - + self.config.alpha * observations_det[:, :-1] - ) - loss_fct = nn.BCEWithLogitsLoss(weight=weight.view(-1)) - loss = loss_fct( - observation_det_logits.view(-1), - observations_det[:, :-1].reshape(-1), - ) - - observation_cls_loss = self.bceloss_with_mask( - observation_cls_logits, - observations_cls, - mask=observations_det, - ) - loss = loss + observation_cls_loss - observation_loss = loss - if progressions is not None: - num_label = progressions.size(-1) - mask = temporal_mask.unsqueeze(-1).expand(-1, num_label) - progression_loss = self.bceloss_with_mask( - progression_logits, progressions.float(), mask - ) - if loss is None: - loss = progression_loss - else: - loss = loss + progression_loss - return VisualOutput( - loss=loss, - observation_loss=observation_loss, - progression_loss=progression_loss, - observation_det_logits=observation_det_logits, - observation_cls_logits=observation_cls_logits, - progression_logits=progression_logits, - last_hidden_state=(last_hidden_state, prior_last_hidden_state), - ) - - def bceloss_with_mask(self, logits, labels, mask, weight=None): - # if weight is not None: - # weight = torch.ones_like(labels) + weight * labels - loss_fct = nn.BCEWithLogitsLoss(reduction="none", weight=weight) - loss = loss_fct(logits, labels) - # norm = mask.sum() - # norm = torch.max(norm, torch.ones_like(norm)) - # loss = (loss * mask).sum() / norm - loss = (loss * mask).mean() - return loss diff --git a/src_stage1/models/modeling_vit_20230609.py b/src_stage1/models/modeling_vit_20230609.py deleted file mode 100644 index 89865df..0000000 --- a/src_stage1/models/modeling_vit_20230609.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import Optional, Tuple, Dict, Any - -import torch -import torch.nn as nn -from transformers import GPT2LMHeadModel -import torchvision.models as models -from transformers import PreTrainedModel, ViTModel -from transformers.modeling_outputs import ModelOutput -from dataclasses 
import dataclass - - -@dataclass -class VisualOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None - observation_logits: torch.FloatTensor = None - progression_logits: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - last_hidden_state: torch.FloatTensor = None - - -class VisualExtractor(nn.Module): - def __init__(self, visual_extractor): - super(VisualExtractor, self).__init__() - model = getattr(models, visual_extractor)(pretrained=True) - # num_fts = model.fc.in_features - # model.fc = nn.Linear(num_fts, 512, bias=False) - # medclip_state_dict = torch.load( - # "../CLIP/pretrained/medclip-resnet/clip_resnet50.bin" - # ) - # model.load_state_dict(medclip_state_dict.state_dict()) - modules = list(model.children()) - self.model = nn.Sequential(*modules[:-2]) - - def forward(self, images): - patch_feats = self.model(images) - batch_size, feat_size, _, _ = patch_feats.shape - patch_feats = patch_feats.reshape( - batch_size, - feat_size, - -1, - ).permute(0, 2, 1) - return patch_feats - - -class VisualEncoder(PreTrainedModel): - def __init__( - self, - config, - ): - super().__init__(config) - self.visual_extractor = ViTModel.from_pretrained( - config.pretrained_visual_extractor - ) - self.observation_cls = nn.Linear( - self.visual_extractor.config.hidden_size, config.num_observation - ) - self.progression_cls = nn.Linear( - self.visual_extractor.config.hidden_size * 2, config.num_progression - ) - - def encode_image(self, input_pixels): - visual_outputs = self.visual_extractor(input_pixels) - pooler_output = visual_outputs.pooler_output - last_hidden_state = visual_outputs.last_hidden_state - observation_logits = self.observation_cls(pooler_output) - return pooler_output, last_hidden_state, observation_logits - - def forward( - self, - input_pixels: torch.FloatTensor = None, - input_temporal_pixels: torch.FloatTensor = None, - temporal_mask: torch.FloatTensor = None, - observations: Optional[torch.FloatTensor] = None, - progressions: Optional[torch.FloatTensor] = None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): - observation_logits = None - progression_logits = None - ( - pooler_output, - last_hidden_state, - observation_logits, - ) = self.encode_image(input_pixels) - ( - prior_pooler_output, - prior_last_hidden_state, - _, - ) = self.encode_image(input_temporal_pixels) - - progression_pooler_output = torch.cat( - (pooler_output, prior_pooler_output), dim=-1 - ) - progression_logits = self.progression_cls(progression_pooler_output) - - loss = None - if observations is not None: - weight = torch.ones_like(observations) + self.config.alpha * observations - loss_fct = nn.BCEWithLogitsLoss(weight=weight.view(-1)) - loss = loss_fct( - observation_logits.view(-1), - observations.view(-1), - ) - - if progressions is not None: - progression_loss = self.bceloss_with_mask( - progression_logits, progressions.float(), temporal_mask - ) - if loss is None: - loss = progression_loss - else: - loss = loss + progression_loss - - return VisualOutput( - loss=loss, - observation_logits=observation_logits, - progression_logits=progression_logits, - last_hidden_state=(last_hidden_state, prior_last_hidden_state), - ) - - def bceloss_with_mask(self, logits, labels, mask, weight=None): - if weight is not None: - weight = torch.ones_like(labels) + weight * labels - loss_fct = nn.BCEWithLogitsLoss(reduction="none", weight=weight) - num_label = labels.size(-1) - mask = mask.unsqueeze(-1).expand(-1, num_label) - loss = loss_fct(logits, labels) - norm = 
mask.sum() - norm = torch.max(norm, torch.ones_like(norm)) - loss = (loss * mask).sum() / norm - return loss diff --git a/src_stage1/models/modeling_vit_20230628.py b/src_stage1/models/modeling_vit_20230628.py deleted file mode 100644 index 918ab4d..0000000 --- a/src_stage1/models/modeling_vit_20230628.py +++ /dev/null @@ -1,167 +0,0 @@ -from typing import Optional, Tuple, Dict, Any - -import torch -import torch.nn as nn -from transformers import GPT2LMHeadModel -import torchvision.models as models -from transformers import PreTrainedModel, ViTModel -from transformers.modeling_outputs import ModelOutput -from dataclasses import dataclass -from transformers import ViTConfig - - -@dataclass -class VisualOutput(ModelOutput): - loss: Optional[torch.FloatTensor] = None - observation_loss: Optional[torch.FloatTensor] = None - progression_loss: Optional[torch.FloatTensor] = None - observation_det_logits: torch.FloatTensor = None - observation_cls_logits: torch.FloatTensor = None - progression_logits: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - last_hidden_state: torch.FloatTensor = None - - -class VisualEncoder(PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.visual_config = ViTConfig.from_pretrained( - config.pretrained_visual_extractor - ) - self.observation_det = nn.Linear( - self.config.hidden_size, config.num_observation - 1 - ) - self.observation_cls = nn.Linear( - self.config.hidden_size, config.num_observation - ) - self.progression_cls = nn.Linear( - self.config.hidden_size * 2, config.num_progression - ) - self.post_init() - self.visual_extractor = ViTModel.from_pretrained( - config.pretrained_visual_extractor - ) - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def encode_image(self, input_pixels, require_logits=True): - observation_det_logits = None - observation_cls_logits = None - visual_outputs = self.visual_extractor(input_pixels) - pooler_output = visual_outputs.pooler_output - last_hidden_state = visual_outputs.last_hidden_state - if require_logits: - observation_det_logits = self.observation_det(pooler_output) - observation_cls_logits = self.observation_cls(pooler_output) - return ( - pooler_output, - last_hidden_state, - observation_det_logits, - observation_cls_logits, - ) - - def forward( - self, - input_pixels: torch.FloatTensor = None, - input_temporal_pixels: torch.FloatTensor = None, - temporal_mask: torch.FloatTensor = None, - observations: Optional[torch.FloatTensor] = None, - progressions: Optional[torch.FloatTensor] = None, - entity_labels: Optional[torch.FloatTensor] = None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - require_logits=True, - ): - progression_logits = None - ( - pooler_output, - last_hidden_state, - observation_det_logits, - observation_cls_logits, - ) = self.encode_image( - input_pixels, - require_logits=require_logits, - ) - if temporal_mask.sum() > 0: - ( - prior_pooler_output, - prior_last_hidden_state, - _, - _, - ) = self.encode_image( - input_temporal_pixels, - require_logits=False, - ) - else: - prior_pooler_output = torch.zeros_like(pooler_output) - prior_last_hidden_state = torch.zeros_like(last_hidden_state) - - if 
require_logits: - progression_pooler_output = torch.cat( - (pooler_output, prior_pooler_output), dim=-1 - ) - progression_logits = self.progression_cls(progression_pooler_output) - - loss = None - observation_loss = None - progression_loss = None - if observations is not None: - observations_det = (observations != 2).float() - observations_cls = (observations == 1).float() - weight = ( - torch.ones_like(observations_det[:, :-1]) - + self.config.alpha * observations_det[:, :-1] - ) - loss_fct = nn.BCEWithLogitsLoss(weight=weight.view(-1)) - loss = loss_fct( - observation_det_logits.view(-1), - observations_det[:, :-1].reshape(-1), - ) - - observation_cls_loss = self.bceloss_with_mask( - observation_cls_logits, - observations_cls, - mask=observations_det, - ) - loss = loss + observation_cls_loss - observation_loss = loss - if progressions is not None: - num_label = progressions.size(-1) - mask = temporal_mask.unsqueeze(-1).expand(-1, num_label) - progression_loss = self.bceloss_with_mask( - progression_logits, progressions.float(), mask - ) - if loss is None: - loss = progression_loss - else: - loss = loss + progression_loss - return VisualOutput( - loss=loss, - observation_loss=observation_loss, - progression_loss=progression_loss, - observation_det_logits=observation_det_logits, - observation_cls_logits=observation_cls_logits, - progression_logits=progression_logits, - last_hidden_state=(last_hidden_state, prior_last_hidden_state), - ) - - def bceloss_with_mask(self, logits, labels, mask, weight=None): - # if weight is not None: - # weight = torch.ones_like(labels) + weight * labels - loss_fct = nn.BCEWithLogitsLoss(reduction="none", weight=weight) - loss = loss_fct(logits, labels) - # norm = mask.sum() - # norm = torch.max(norm, torch.ones_like(norm)) - # loss = (loss * mask).sum() / norm - loss = (loss * mask).mean() - return loss diff --git a/src_stage1/models/vae.py b/src_stage1/models/vae.py deleted file mode 100644 index dc9c484..0000000 --- a/src_stage1/models/vae.py +++ /dev/null @@ -1,343 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from src.models.modeling_bart import shift_tokens_right -from torch.autograd import Variable - -from models.activations import ACT2FN -from models.layers import SinkhornDistance -from models.von_mises_fisher import VonMisesFisher - - -def word_avg(encoder_hidden_states, attention_mask): - if attention_mask is not None: - sum_vecs = (encoder_hidden_states * - attention_mask.unsqueeze(-1)).sum(1) - avg_vecs = sum_vecs / attention_mask.sum(1, keepdim=True) - else: - avg_vecs = encoder_hidden_states.mean(1) - return avg_vecs - - -def kl_loss(mean, var): - """ - KL(p||N(0,1)) - """ - return -0.5 * torch.mean(torch.mean(1 + var - mean.pow(2) - var.exp(), 1)) - - -class VecDecoder(nn.Module): - def __init__(self, config): - super(VecDecoder, self).__init__() - self.config = config - bow_head = [] - for i in range(config.bow_layers): - if i == 0: - inp_dim = config.d_mean - out_dim = config.d_mean_var - elif i == config.bow_layers - 1: - inp_dim = config.d_mean_var - out_dim = config.d_model - else: - inp_dim = out_dim = config.d_mean_var - layer = nn.Linear(inp_dim, out_dim) - bow_head.append(layer) - bow_head.append(nn.Dropout(config.dropout)) - if i < config.bow_layers - 1: - bow_head.append(ACT2FN[config.activation_function]) - else: - bow_head.append(nn.Tanh()) - self.bow_head = nn.Sequential(*bow_head) - - def forward(self, z, labels=None): - vae_loss = None - if labels is not None: - loss_fct = nn.MSELoss() - logits = 
self.bow_head(z) - vae_loss = loss_fct( - logits, - labels, - ) - return vae_loss - - -class BoWDecoder(nn.Module): - def __init__(self, config, init_inp_dim=None): - super(BoWDecoder, self).__init__() - self.config = config - bow_head = [] - for i in range(config.bow_layers): - if i == 0: - inp_dim = config.d_mean if init_inp_dim is None else init_inp_dim - out_dim = config.d_mean_var - elif i == config.bow_layers - 1: - inp_dim = config.d_mean_var - out_dim = config.vocab_size - else: - inp_dim = out_dim = config.d_mean_var - layer = nn.Linear(inp_dim, out_dim) - bow_head.append(layer) - bow_head.append(nn.Dropout(config.dropout)) - if i < config.bow_layers - 1: - bow_head.append(ACT2FN[config.activation_function]) - self.bow_head = nn.Sequential(*bow_head) - - def forward(self, z, labels=None): - vae_loss = None - if labels is not None: - seq_len = labels.size(1) - loss_fct = nn.CrossEntropyLoss() - logits = self.bow_head(z) - expanded_logits = logits.unsqueeze(1).expand(-1, seq_len, -1) - vae_loss = loss_fct( - expanded_logits.reshape(-1, self.config.vocab_size), - labels.view(-1), - ) - return vae_loss - - -class RNNDecoder(nn.Module): - def __init__(self, config): - super(RNNDecoder, self).__init__() - self.config = config - self.embed = nn.Embedding( - num_embeddings=config.vocab_size, - embedding_dim=config.d_model, - padding_idx=config.pad_token_id, - ) - self.decoder = nn.GRU( - input_size=config.d_model, - hidden_size=config.d_mean, - num_layers=1, - batch_first=True, - ) - self.lm_head = nn.Linear(config.d_mean, config.vocab_size) - - def forward(self, init_hidden, labels): - input_ids = shift_tokens_right( - input_ids=labels, - pad_token_id=self.config.pad_token_id, - decoder_start_token_id=self.config.bos_token_id, - ) - embed_tokens = self.embed(input_ids) - hidden_states, _ = self.decoder(embed_tokens, init_hidden.unsqueeze(0)) - logits = self.lm_head(hidden_states) - loss_fct = nn.CrossEntropyLoss() - lm_loss = loss_fct( - logits.view(-1, self.config.vocab_size), - labels.view(-1), - ) - return lm_loss - - -class VAE(nn.Module): - def __init__(self, config, is_sem=False): - super(VAE, self).__init__() - self.mean = nn.Linear(config.d_model, config.d_mean) - d_var = config.d_var if not is_sem else 1 - self.var = nn.Linear(config.d_model, d_var) - self.decoder = None - - def process( - self, - encoder_hidden_states, - attention_mask, - is_sem=False, - ): - mean_hidden_states = self.mean(encoder_hidden_states) - if is_sem: - mean_hidden_states = mean_hidden_states / mean_hidden_states.norm( - dim=-1, keepdim=True) - mean_state = word_avg(mean_hidden_states, attention_mask) - var_hidden_states = self.var(encoder_hidden_states) - if is_sem: - var_hidden_states = F.softplus(var_hidden_states) + 100 - var_state = word_avg(var_hidden_states, attention_mask) - return mean_state, var_state - - def forward( - self, - encoder_hidden_states, - attention_mask, - labels=None, - ): - mean_hidden_states, var_hidden_states = self.process( - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - ) - z = self.sample_gaussian( - mean=mean_hidden_states, - var=var_hidden_states, - ) - # assert self.decoder is None, "Please initialize decoder" - rl = self.decoder(z, labels) - kl = kl_loss(mean_hidden_states, var_hidden_states) - return rl, kl, mean_hidden_states - - def sample_gaussian(self, mean, var): - sample = mean + torch.exp(0.5 * var) * Variable( - var.data.new(var.size()).normal_()) - return sample - - -class TextVAE(VAE): - def __init__(self, config): - 
-class TextVAE(VAE):
-    def __init__(self, config):
-        super(TextVAE, self).__init__(config)
-        self.decoder = BoWDecoder(config)
-        # self.decoder = RNNDecoder(config)
-
-
-class VGVAE(nn.Module):
-    def __init__(self, config):
-        super(VGVAE, self).__init__()
-        self.semantic_vae = VAE(config, is_sem=True)
-        self.syntactic_vae = VAE(config)
-        self.decoder = BoWDecoder(config, init_inp_dim=config.d_mean * 2)
-        self.pos_decoder = nn.Linear(
-            config.d_var + config.d_model,
-            config.max_position_embeddings,
-        )
-        self.max_position = config.max_position_embeddings
-
-    def wpl_fn(self, hidden_states, labels):
-        # word position loss
-        max_position = labels.size(1)
-        pos_labels = torch.arange(
-            0,
-            max_position,
-            device=hidden_states.device,
-        ).masked_fill(labels == -100, -100)
-        pos_hidden_states = self.pos_decoder(hidden_states)
-        loss_fct = nn.CrossEntropyLoss()
-        wpl = loss_fct(
-            pos_hidden_states.view(-1, self.max_position),
-            pos_labels.view(-1),
-        )
-        return wpl
-
-    def forward(
-        self,
-        encoder_hidden_states,
-        attention_mask,
-        labels=None,
-    ):
-        sem_mean, sem_var = self.semantic_vae.process(
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            is_sem=True,
-        )
-        syn_mean, syn_var = self.syntactic_vae.process(
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            is_sem=False,
-        )
-        sem_dist = VonMisesFisher(sem_mean, sem_var)
-        sem_z = sem_dist.rsample()
-        syn_z = self.syntactic_vae.sample_gaussian(syn_mean, syn_var)
-        z = torch.cat((sem_z, syn_z), dim=-1)
-        rl = self.decoder(z, labels)
-        sem_kl = sem_dist.kl_div().mean()
-        syn_kl = kl_loss(syn_mean, syn_var)
-        wpl = self.wpl_fn(
-            torch.cat(
-                (encoder_hidden_states, syn_z.unsqueeze(1).expand(
-                    -1,
-                    encoder_hidden_states.size(1),
-                    -1,
-                )),
-                dim=-1,
-            ),
-            labels,
-        )
-        return rl, syn_kl, wpl, sem_kl, sem_mean
-
-
-class VisualVAE(VAE):
-    def __init__(self, config):
-        super(VisualVAE, self).__init__(config)
-        self.decoder = VecDecoder(config)
-
-    def forward(self, encoder_hidden_states, labels):
-        avg_hidden_states = word_avg(encoder_hidden_states, None)
-        mean_hidden_states = self.mean(avg_hidden_states)
-        var_hidden_states = self.var(avg_hidden_states)
-        z = self.sample_gaussian(
-            mean=mean_hidden_states,
-            var=var_hidden_states,
-        )
-        rl = self.decoder(z, labels)
-        kl = kl_loss(mean_hidden_states, var_hidden_states)
-        return rl, kl, mean_hidden_states
-
-
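The `DisentanglementModel` that follows (also deleted) ties the text and visual VAEs together with a triplet-style ranking term: the visual latent should be closer, in cosine similarity, to the latent of its paired report than to that of a mismatched report. A minimal sketch of that term with made-up tensors (dimensions are illustrative only):

```python
import torch
import torch.nn.functional as F

margin = 1.0
anchor = torch.randn(8, 256)   # visual latent means
pos = torch.randn(8, 256)      # latents of the paired reports
neg = torch.randn(8, 256)      # latents of mismatched reports

pos_cos = F.cosine_similarity(pos, anchor)      # shape (8,)
neg_cos = F.cosine_similarity(neg, anchor)
dl = F.relu(margin - pos_cos + neg_cos).mean()  # 0 once pos_cos >= neg_cos + margin
print(dl)
```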
-class DisentanglementModel(nn.Module):
-    def __init__(self, config):
-        super(DisentanglementModel, self).__init__()
-        self.text_vae = TextVAE(config)
-        # self.text_vae = VGVAE(config)
-        self.visual_vae = VisualVAE(config)
-        self.margin = 1.
-
-    def forward(
-        self,
-        visual_hidden_states,
-        pos_hidden_states,
-        neg_hidden_states,
-        pos_attention_mask,
-        neg_attention_mask,
-        visual_labels,
-        pos_labels,
-        neg_labels,
-    ):
-        pos_text_rl, pos_sem_kl, pos_states = self.text_vae(
-            pos_hidden_states,
-            pos_attention_mask,
-            pos_labels,
-        )
-        neg_text_rl, neg_sem_kl, neg_states = self.text_vae(
-            neg_hidden_states,
-            neg_attention_mask,
-            neg_labels,
-        )
-        # pos_text_rl, pos_syn_kl, pos_wpl, pos_sem_kl, pos_states = self.text_vae(
-        #     pos_hidden_states.detach(),
-        #     pos_attention_mask,
-        #     pos_labels,
-        # )
-        # neg_text_rl, neg_syn_kl, neg_wpl, neg_sem_kl, neg_states = self.text_vae(
-        #     neg_hidden_states.detach(),
-        #     neg_attention_mask,
-        #     neg_labels,
-        # )
-        visual_rl, visual_kl, anchor_states = self.visual_vae(
-            visual_hidden_states,
-            visual_labels.mean(dim=1).detach(),
-        )
-        pos_cos = F.cosine_similarity(pos_states, anchor_states)
-        neg_cos = F.cosine_similarity(neg_states, anchor_states)
-        dl = F.relu(self.margin - pos_cos + neg_cos).mean()
-        # pos_loss = self.margin - F.mse_loss(pos_states, anchor_states)
-        # neg_loss = self.margin - F.mse_loss(neg_states, anchor_states)
-        # dl = 0.5 * F.relu(pos_loss).mean() + 0.5 * F.relu(neg_loss).mean()
-        text_beta = 1e-3
-        visual_beta = 1e-3
-        # text_beta = visual_beta = 1.
-        # print("pos_text_rl", pos_text_rl.item())
-        print("pos_text_kl", pos_sem_kl.item())
-        # print("neg_text_rl", neg_text_rl.item())
-        print("neg_text_kl", neg_sem_kl.item())
-        # print("visual_rl", visual_rl.item())
-        print("visual_kl", visual_kl.item())
-        # print("dl", dl.item())
-        loss = 0
-        # positive report vae
-        loss = loss + pos_text_rl + text_beta * pos_sem_kl #+ text_beta * pos_syn_kl + pos_wpl
-        # negative report vae
-        loss = loss + neg_text_rl + text_beta * neg_sem_kl #+ text_beta * neg_syn_kl + neg_wpl
-        # image vae
-        loss = loss + visual_rl + visual_beta * visual_kl
-        # discriminative loss
-        # sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None)
-        # dl, P, C = sinkhorn(pos_states, anchor_states)
-        loss = loss + dl
-        return loss
diff --git a/src_stage1/optimizer.py b/src_stage1/optimizer.py
index f27bd74..744d296 100644
--- a/src_stage1/optimizer.py
+++ b/src_stage1/optimizer.py
@@ -13,12 +13,7 @@ def create_optimizer(model, args, fast_lr=1e-4):
     decay_parameters = [name for name in decay_parameters if "bias" not in name]
 
     fast_params = []
-    # for n, _ in model.named_parameters():
-    #     if "visual_encoder" not in n and "Bert" not in n:
-    #         fast_params.append(n)
-
    for n, _ in model.named_parameters():
-        # if not n.startswith("model.decoder"):
         if not n.startswith("gpt_with_lm_head"):
             fast_params.append(n)
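The `optimizer.py` hunk above keeps the behaviour where every parameter whose name does not start with `gpt_with_lm_head` is collected into `fast_params`, presumably so those parameters can be trained with the larger `fast_lr`. The actual grouping happens outside this hunk, so the sketch below is an assumption about how such two-speed parameter groups are typically wired up, not the repository's code:

```python
import torch

def build_param_groups(model, base_lr=1e-5, fast_lr=1e-4):
    # hypothetical two-speed grouping: pretrained LM slow, everything else fast
    fast, slow = [], []
    for n, p in model.named_parameters():
        (slow if n.startswith("gpt_with_lm_head") else fast).append(p)
    return [
        {"params": slow, "lr": base_lr},
        {"params": fast, "lr": fast_lr},
    ]

# stand-in model so the sketch is runnable; the real model lives in src_stage1/models
model = torch.nn.ModuleDict({
    "gpt_with_lm_head": torch.nn.Linear(4, 4),
    "visual_encoder": torch.nn.Linear(4, 4),
})
optimizer = torch.optim.AdamW(build_param_groups(model))
```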
diff --git a/src_stage1/run_ende.py b/src_stage1/run_ende.py
index a7b4cb7..721e4f6 100644
--- a/src_stage1/run_ende.py
+++ b/src_stage1/run_ende.py
@@ -133,7 +133,6 @@ def main():
         train_idxs=train_idxs,
     )
 
-    from models.modeling_vit import VisualEncoder
     checkpoint = "GanjinZero/biobart-base"
diff --git a/src_stage1/seq2seqtrainer_metrics_ende.py b/src_stage1/seq2seqtrainer_metrics_ende.py
index 859079a..5fbc206 100644
--- a/src_stage1/seq2seqtrainer_metrics_ende.py
+++ b/src_stage1/seq2seqtrainer_metrics_ende.py
@@ -50,6 +50,7 @@ def evaluation_loop(
 
         model.eval()
 
+        # Only for single GPU training
         self.callback_handler.eval_dataloader = dataloader
         # Do this before wrapping.
         eval_dataset = dataloader.dataset
diff --git a/src_stage1/tokenizer.py b/src_stage1/tokenizer.py
index 2486140..1d41ae8 100644
--- a/src_stage1/tokenizer.py
+++ b/src_stage1/tokenizer.py
@@ -1,10 +1,8 @@
-import copy
 import json
 import re
 from collections import Counter, defaultdict
 
 import pandas as pd
-from transformers.tokenization_utils import PreTrainedTokenizer
 import os
 import pickle
@@ -16,10 +14,7 @@ def __init__(self, config):
         self.ann_path = config.annotation_file
         self.threshold = config.threshold
         self.dataset = config.dataset
-        if self.dataset == "iu_xray":
-            self.clean_report = Tokenizer.clean_report_iu_xray
-        else:
-            self.clean_report = Tokenizer.clean_report_mimic_cxr
+        self.clean_report = Tokenizer.clean_report_mimic_cxr
         print(self.clean_report)
         self.ann = json.loads(open(self.ann_path, "r").read())
         self.token2idx, self.idx2token, self.special_tokens = self.create_vocabulary()
@@ -45,47 +40,6 @@ def create_vocabulary(self):
             idx2token[idx] = token
         return token2idx, idx2token, special_tokens[:-1]
 
-    @staticmethod
-    def clean_report_iu_xray(report):
-        def report_cleaner(t):
-            return (
-                t.replace("..", ".")
-                .replace("..", ".")
-                .replace("..", ".")
-                .replace("1. ", "")
-                .replace(". 2. ", ". ")
-                .replace(". 3. ", ". ")
-                .replace(". 4. ", ". ")
-                .replace(". 5. ", ". ")
-                .replace(" 2. ", ". ")
-                .replace(" 3. ", ". ")
-                .replace(" 4. ", ". ")
-                .replace(" 5. ", ". ")
-                .strip()
-                .lower()
-                .split(". ")
-            )
-
-        def sent_cleaner(t):
-            return re.sub(
-                "[.,?;*!%^&_+():-\[\]{}]",
-                "",
-                t.replace('"', "")
-                .replace("/", "")
-                .replace("\\", "")
-                .replace("'", "")
-                .strip()
-                .lower(),
-            )
-
-        tokens = [
-            sent_cleaner(sent).strip() + " ."
-            for sent in report_cleaner(report)
-            if len(sent_cleaner(sent).strip()) > 0
-        ]
-        report = " ".join(tokens)
-        return report
-
     @staticmethod
     def clean_report_mimic_cxr(report):
         def report_cleaner(t):
@@ -160,8 +114,6 @@ def load_tag2ids(
         tags = pd.read_csv(tag_path)
         with open(cached_path, "wb") as f:
             pickle.dump(tags, file=f)
-        # tags = tags.fillna(0).replace(-1, 1)
-        # tags = tags.replace(-1, 1).replace(0, 1).fillna(0)
         tags = tags.replace(-1, 1).fillna(2)
         diseases = list(tags)[2:]
         id2tags = defaultdict(list)
@@ -240,33 +192,3 @@ def batch_decode(self, ids_batch, skip_special_tokens=True, separator=" "):
 
     def save_pretrained(self, save_directory):
         return ""
-
-    def update_progression_tokens(self, observations, statuses):
-        token_id = len(self.token2idx)
-        for obs in observations:
-            for status in statuses:
-                token = f"[{obs}_{status}]"
-                if token not in self.token2idx:
-                    self.token2idx[token] = token_id
-                    self.idx2token[token_id] = token
-                    self.special_tokens.append(token)
-                    token_id += 1
-        self.token2idx["[PRO]"] = token_id
-        self.idx2token[token_id] = "[PRO]"
-        self.special_tokens.append("[PRO]")
-
-    def search_progression_token(self, observation, status):
-        return self.token2idx[f"[{observation}_{status}]"]
-
-
-class TagTokenizer:
-    def __init__(self, header) -> None:
-        self.head2id = {head: idx for idx, head in enumerate(header)}
-
-    def encode(self, tags):
-        tag_ids = []
-        for tag in tags:
-            tag_ids.append(self.head2id[tag])
-        if len(tag_ids) == 0:
-            tag_ids = [len(self.head2id) + 1]
-        return tag_ids
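The `load_tag2ids` hunk above keeps the mapping `tags.replace(-1, 1).fillna(2)` and only drops the commented-out alternatives. Under that convention, uncertain CheXbert mentions (-1) are folded into positive (1) and unlabeled cells become 2, which is why the model code earlier in this diff treats `observations != 2` as "mentioned" and `observations == 1` as "positive". A small illustration (the column names are invented for the example):

```python
import pandas as pd

tags = pd.DataFrame({
    "Edema":     [1,   -1, None],   # positive, uncertain, not mentioned
    "Pneumonia": [0, None,   -1],
})
tags = tags.replace(-1, 1).fillna(2)
print(tags)
#    Edema  Pneumonia
# 0    1.0        0.0
# 1    1.0        2.0
# 2    2.0        1.0
```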
diff --git a/src_stage1/train_eval_ende_full.py b/src_stage1/train_eval_ende_full.py
index 4a6d625..fec0e87 100644
--- a/src_stage1/train_eval_ende_full.py
+++ b/src_stage1/train_eval_ende_full.py
@@ -303,4 +303,4 @@ def get_pred(a, b):
         encoding="utf-8",
     ) as f:
         json.dump(output_data, f, ensure_ascii=False, indent=4)
-    return {"eval_BLEU_4": target}
+    return {"eval_macro_f1": target}
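With this last change, the value returned to the trainer for model selection is a macro-averaged F1 (`eval_macro_f1`) rather than BLEU-4. Macro F1 is simply the unweighted mean of per-class F1 scores; exactly how stage 1 aggregates it over the observation labels is outside this hunk, but the metric itself looks like:

```python
from sklearn.metrics import f1_score

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 2, 2, 2, 1, 0]
# unweighted mean of per-class F1, so infrequent classes count as much as common ones
print(f1_score(y_true, y_pred, average="macro"))
```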