Revert "Update Project Page"
This reverts commit 14be609.
wjhou committed Jan 2, 2024
1 parent 14be609 commit faeb71e
Showing 38 changed files with 137 additions and 1,616 deletions.
7 changes: 4 additions & 3 deletions src_stage1/data_arguments.py
@@ -54,9 +54,6 @@ class DataTrainingArguments:
)
chexbert_label: Optional[str] = field(default=None)
debug_model: Optional[bool] = field(default=False)
max_context_length: Optional[int] = field(
default=256,
)
max_tgt_length: Optional[int] = field(
default=64,
)
@@ -106,5 +103,9 @@ class DataTrainingArguments:
)
alpha: Optional[float] = field(default=3)
beta: Optional[float] = field(default=3)
wo_op: Optional[int] = field(default=1)
wo_obs: Optional[int] = field(default=1)
wo_pro: Optional[int] = field(default=1)
wo_prr: Optional[int] = field(default=1)
topk: Optional[int] = field(default=10)
lambda_: Optional[float] = field(default=0.5)
3 changes: 0 additions & 3 deletions src_stage1/dataset_ende.py
@@ -1,14 +1,11 @@
import os

from torch.utils.data import Dataset
from torchvision import transforms
import torch
from data_arguments import DataTrainingArguments
from data_process_ende import process_examples
from tokenizer import Tokenizer
from tqdm import tqdm
from PIL import Image
from transformers import GPT2Tokenizer


def load_images(root_path, image_paths):
32 changes: 0 additions & 32 deletions src_stage1/extract_report.py

This file was deleted.

104 changes: 104 additions & 0 deletions src_stage1/graph_construction/pmi_observation_entity.py
@@ -0,0 +1,104 @@
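# Score (observation, entity) co-occurrence with pointwise mutual information
# and write the pairs above the threshold to obs2sem.json.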
from tqdm import tqdm
import pandas as pd
import argparse
import json
import math
from collections import defaultdict
import os
from nltk.corpus import stopwords
import numpy as np
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from tokenizer import Tokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, required=True, help="name of the dataset")
parser.add_argument("--chexbert_label", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--pmi_threshold", type=float, required=True)
parser.add_argument("--min_count", type=int, default=5, help="minimum frequency for an entity token to be kept")

config = parser.parse_args()

print("dataset: ", config.dataset)
clean_fn = Tokenizer.clean_report_mimic_cxr
print(clean_fn)
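# id2observation maps each report id to its CheXbert observation labels;
# observation_category holds the label names from the file header.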
id2observation, observation_category, _ = Tokenizer.load_tag2ids(
config.chexbert_label, need_header=True
)
print(len(id2observation), observation_category)

sem_stat = json.load(
open(
os.path.join(config.output_dir, config.dataset, "sem_stat.json"),
"r",
encoding="utf-8",
)
)
tem_stat = json.load(
open(
os.path.join(config.output_dir, config.dataset, "tem_stat.json"),
"r",
encoding="utf-8",
)
)

swords = stopwords.words("english")
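# Entity candidates per observation: drop stopwords, tokens starting with "_",
# tokens seen fewer than min_count times, and tokens that also appear in tem_stat.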
sem_stat_keep = {
k: {
subk
for subk, subv in v.items()
if subk not in swords
and not subk.startswith("_")
and subv >= config.min_count
and subk not in tem_stat
}
for k, v in sem_stat.items()
}


with open(
os.path.join(config.output_dir, config.dataset, "id2entity.json"),
"r",
encoding="utf-8",
) as f:
id2entity = json.load(f)
observation_stat = defaultdict(int)
observation_ngram_stat = defaultdict(int)
observation_ngram_norm = defaultdict(int)

sem_stat_all = sem_stat["ALL"]
p_y_norm = sum(sem_stat_all.values())
sem_stat.pop("ALL")
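# p(entity | observation): per-observation counts normalized over the kept entity set;
# p(entity) further below is the marginal estimated from the pooled "ALL" counts.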
p_y_x = {}
for obs in sem_stat:
p_y_x[obs] = {
k: v / sum(sem_stat[obs].values())
for k, v in sem_stat[obs].items()
if k in sem_stat_keep[obs]
}

p_y = {k: v / p_y_norm for k, v in sem_stat_all.items()}

pmi = {}
k = 1
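# PMI(observation, entity) = log2( p(entity | observation) / p(entity) );
# only pairs scoring above pmi_threshold are kept, e.g. p(ent | obs) = 0.75
# with p(ent) = 0.10 gives log2(7.5) ≈ 2.91.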
for observation in p_y_x:
for ent in p_y_x[observation]:
if ent not in p_y:
continue
pmi_xy = math.log(p_y_x[observation][ent] / p_y[ent], 2)
if pmi_xy <= config.pmi_threshold:
continue
pmi[(observation, ent)] = pmi_xy

new_pairs = {}
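# Flatten the (observation, entity) tuple keys into "observation@entity" strings
# so the scores can be serialized as JSON below.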
for key in pmi:
new_pairs["@".join(key)] = pmi[key]

with open(
os.path.join(config.output_dir, config.dataset, "obs2sem.json"),
"w",
encoding="utf-8",
) as f:
json.dump(new_pairs, f, ensure_ascii=False, indent=4)
207 changes: 0 additions & 207 deletions src_stage1/graph_construction/pmi_observation_ngram.py

This file was deleted.

