Feat/zshot version (#17)
* 🎨 Improved the structure of setup and init.
* ✏️ Fixed minor typos and formatting in the evaluator.
* ✅ Updated the evaluation tests to work with the latest version of evaluate.
* 🐛 Fixed a bug when importing the version.
marmg authored Oct 17, 2022
1 parent 92cfc30 commit 76e3a32
Showing 6 changed files with 20 additions and 13 deletions.
3 changes: 3 additions & 0 deletions setup.cfg
@@ -1,2 +1,5 @@
 [egg_info]
 tag_svn_revision = true
+
+[metadata]
+version = attr: zshot.__version__
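
This addition single-sources the package version: the attr: directive tells setuptools to read version from the zshot.__version__ attribute defined in zshot/__init__.py, so the hard-coded string in setup.py (removed in the next file) is no longer needed. A minimal sketch of what the directive amounts to at build time; the helper read_version is illustrative, not part of this commit:

import importlib


def read_version(package: str, attr: str = "__version__") -> str:
    # Import the package and read its version attribute, roughly what
    # setuptools does to resolve "version = attr: zshot.__version__".
    module = importlib.import_module(package)
    return getattr(module, attr)


print(read_version("zshot"))  # "0.0.3" once this commit is installed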
2 changes: 0 additions & 2 deletions setup.py
@@ -4,10 +4,8 @@
 this_directory = Path(__file__).parent
 long_description = (this_directory / "README.md").read_text()
 
-version = '0.0.2'
 
 setup(name='zshot',
-      version=version,
       description="Zero and Few shot named entity recognition",
       long_description_content_type='text/markdown',
       long_description=long_description,
2 changes: 2 additions & 0 deletions zshot/__init__.py
@@ -1,2 +1,4 @@
 from zshot.zshot import MentionsExtractor, Linker, Zshot, PipelineConfig # noqa: F401
 from zshot.utils.displacy import displacy # noqa: F401
+
+__version__ = '0.0.3'
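
With __version__ defined at the package top level, the version is importable at runtime and is the same value setuptools now reads through the attr: directive in setup.cfg. A quick check, assuming zshot 0.0.3 is installed:

import zshot

print(zshot.__version__)  # expected: 0.0.3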
3 changes: 2 additions & 1 deletion zshot/evaluation/evaluator.py
@@ -34,7 +34,8 @@ def prepare_pipeline(
         feature_extractor=None, # noqa: F821
         device: int = None,
     ):
-        pipe = super(TokenClassificationEvaluator, self).prepare_pipeline(model_or_pipeline, tokenizer, feature_extractor, device)
+        pipe = super(TokenClassificationEvaluator, self).prepare_pipeline(model_or_pipeline, tokenizer,
+                                                                          feature_extractor, device)
         return pipe


6 changes: 4 additions & 2 deletions zshot/evaluation/zshot_evaluate.py
@@ -5,13 +5,12 @@
 from prettytable import PrettyTable
 
 from zshot.evaluation import load_medmentions, load_ontonotes
-from zshot.evaluation.dataset.dataset import DatasetWithEntities
 from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
 
 
 def evaluate(nlp: spacy.language.Language,
-             datasets: Union[DatasetWithEntities, List[DatasetWithEntities]],
+             datasets: Union[str, List[str]],
              splits: Optional[Union[str, List[str]]] = None,
              metric: Optional[Union[str, EvaluationModule]] = None,
              batch_size: Optional[int] = 16) -> str:
@@ -31,6 +30,9 @@ def evaluate(nlp: spacy.language.Language,
     if type(splits) == str:
         splits = [splits]
 
+    if type(datasets) == str:
+        datasets = [datasets]
+
     result = {}
     field_names = ["Metric"]
     for dataset_name in datasets:
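
evaluate now takes dataset names rather than pre-loaded DatasetWithEntities objects, and the new guard normalizes a bare string into a one-element list so both call styles behave the same. A hedged usage sketch: the blank pipeline is only a placeholder for a properly configured zshot pipeline, and the dataset name "ontonotes" is an assumption based on the load_ontonotes import above:

import spacy

from zshot.evaluation.zshot_evaluate import evaluate

# A real run needs a spaCy pipeline with zshot components added;
# spacy.blank("en") stands in here only to show the call shape.
nlp = spacy.blank("en")

# A bare string is accepted and normalized internally to ["ontonotes"].
table = evaluate(nlp, "ontonotes", splits="test", metric="seqeval")
print(table)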
17 changes: 9 additions & 8 deletions zshot/tests/evaluation/test_evaluation.py
@@ -113,7 +113,7 @@ def test_prediction_token_based_evaluation_all_matching(self):
         dataset = get_dataset(gt, sentences)
 
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
-        metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
+        metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1)]), dataset, metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -128,7 +128,7 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):
 
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
         metrics = custom_evaluator.compute(get_linker_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]), dataset,
-                                           "seqeval")
+                                           metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -144,7 +144,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.expand)
         pipe = get_linker_pipe([('New Yo', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -160,7 +160,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_contract(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.contract)
         pipe = get_linker_pipe([('New York i', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -176,7 +176,7 @@ def test_prediction_token_based_evaluation_partial_and_overlapping_spans(self):
         custom_evaluator = ZeroShotTokenClassificationEvaluator("token-classification",
                                                                 alignment_mode=AlignmentMode.contract)
         pipe = get_linker_pipe([('New York i', 'FAC', 1), ('w York', 'LOC', 0.7)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -207,7 +207,8 @@ def test_prediction_token_based_evaluation_all_matching(self):
         dataset = get_dataset(gt, sentences)
 
         custom_evaluator = MentionsExtractorEvaluator("token-classification")
-        metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset, "seqeval")
+        metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1)]), dataset,
+                                           metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -222,7 +223,7 @@ def test_prediction_token_based_evaluation_overlapping_spans(self):
 
         custom_evaluator = MentionsExtractorEvaluator("token-classification")
         metrics = custom_evaluator.compute(get_mentions_extractor_pipe([('New York', 'FAC', 1), ('York', 'LOC', 0.7)]),
-                                           dataset, "seqeval")
+                                           dataset, metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
@@ -238,7 +239,7 @@ def test_prediction_token_based_evaluation_partial_match_spans_expand(self):
         custom_evaluator = MentionsExtractorEvaluator("token-classification",
                                                       alignment_mode=AlignmentMode.expand)
         pipe = get_mentions_extractor_pipe([('New Yo', 'FAC', 1)])
-        metrics = custom_evaluator.compute(pipe, dataset, "seqeval")
+        metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")
 
         assert float(metrics["overall_precision"]) == 1.0
         assert float(metrics["overall_precision"]) == 1.0
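
Every hunk in this file makes the same fix: the metric is passed by keyword instead of positionally. Newer releases of the evaluate library extended the compute signature, so a bare third positional argument no longer lines up with the metric parameter; the keyword form is robust across versions. The shape of the change, with pipe and dataset standing in as placeholders from the surrounding tests:

# Before: positional metric, misbound under newer evaluate releases.
# metrics = custom_evaluator.compute(pipe, dataset, "seqeval")

# After: explicit keyword, unambiguous regardless of signature changes.
metrics = custom_evaluator.compute(pipe, dataset, metric="seqeval")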
