From 82920ec1222d2736a43f1c67d933451aa82282ed Mon Sep 17 00:00:00 2001 From: Lawhy Date: Wed, 20 Nov 2024 00:47:10 +0000 Subject: [PATCH] update --- scripts/hierarchy_retrain.py | 4 ++-- src/hierarchy_transformers/datasets/load.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/hierarchy_retrain.py b/scripts/hierarchy_retrain.py index 2e9b3bf..8309266 100644 --- a/scripts/hierarchy_retrain.py +++ b/scripts/hierarchy_retrain.py @@ -57,7 +57,7 @@ def main(config_file: str, gpu_id: int): if config.loss.cluster.weight > 0.0: if config.apply_triplet_loss: - cluster_loss = ClusteringTripletLoss(model.manifold, config.loss.cluster.margin) + cluster_loss = HyperbolicClusteringLoss(model.manifold, config.loss.cluster.margin) else: cluster_loss = ClusteringConstrastiveLoss( model.manifold, config.loss.cluster.positive_margin, config.loss.cluster.margin @@ -65,7 +65,7 @@ def main(config_file: str, gpu_id: int): losses.append((config.loss.cluster.weight, cluster_loss)) if config.loss.centri.weight > 0.0: - centri_loss_class = CentripetalTripletLoss if config.apply_triplet_loss else CentripetalContrastiveLoss + centri_loss_class = HyperbolicCentripetalLoss if config.apply_triplet_loss else CentripetalContrastiveLoss centri_loss = centri_loss_class(model.manifold, model.embed_dim, config.loss.centri.margin) losses.append((config.loss.centri.weight, centri_loss)) diff --git a/src/hierarchy_transformers/datasets/load.py b/src/hierarchy_transformers/datasets/load.py index 28fa175..b202f96 100644 --- a/src/hierarchy_transformers/datasets/load.py +++ b/src/hierarchy_transformers/datasets/load.py @@ -86,10 +86,11 @@ def load_zenodo_dataset( if entity_to_index: entity_lexicon = entity_to_index - for split in dataset.keys(): + for split, examples in dataset.items(): + # list comprehension is faster than nested for-loop due to C implementation dataset[split] = [ transformed - for example in tqdm(dataset[split], desc=f"Map ({split})", unit="example", leave=True) + for example in tqdm(examples, desc=f"Map ({split})", unit="example", leave=True) for transformed in transform(example, negative_type, entity_lexicon) ]