Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Lawhy committed Nov 20, 2024
1 parent ed7c6cc commit 82920ec
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions scripts/hierarchy_retrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def main(config_file: str, gpu_id: int):

if config.loss.cluster.weight > 0.0:
if config.apply_triplet_loss:
cluster_loss = ClusteringTripletLoss(model.manifold, config.loss.cluster.margin)
cluster_loss = HyperbolicClusteringLoss(model.manifold, config.loss.cluster.margin)
else:
cluster_loss = ClusteringConstrastiveLoss(
model.manifold, config.loss.cluster.positive_margin, config.loss.cluster.margin
)
losses.append((config.loss.cluster.weight, cluster_loss))

if config.loss.centri.weight > 0.0:
centri_loss_class = CentripetalTripletLoss if config.apply_triplet_loss else CentripetalContrastiveLoss
centri_loss_class = HyperbolicCentripetalLoss if config.apply_triplet_loss else CentripetalContrastiveLoss
centri_loss = centri_loss_class(model.manifold, model.embed_dim, config.loss.centri.margin)
losses.append((config.loss.centri.weight, centri_loss))

Expand Down
5 changes: 3 additions & 2 deletions src/hierarchy_transformers/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ def load_zenodo_dataset(
if entity_to_index:
entity_lexicon = entity_to_index

for split in dataset.keys():
for split, examples in dataset.items():
# list comprehension is faster than nested for-loop due to C implementation
dataset[split] = [
transformed
for example in tqdm(dataset[split], desc=f"Map ({split})", unit="example", leave=True)
for example in tqdm(examples, desc=f"Map ({split})", unit="example", leave=True)
for transformed in transform(example, negative_type, entity_lexicon)
]

Expand Down

0 comments on commit 82920ec

Please sign in to comment.