Skip to content

Commit

Permalink
✅ Add GLiNER tests
Browse files Browse the repository at this point in the history
Signed-off-by: Marcos Martínez Galindo <marcosmartinezgalindo@Marcoss-MacBook-Pro.local>
  • Loading branch information
Marcos Martínez Galindo authored and Marcos Martínez Galindo committed Aug 15, 2024
1 parent 8e13681 commit 83553c9
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
59 changes: 59 additions & 0 deletions zshot/tests/linker/test_gliner_linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import gc
import logging

import pytest
import spacy

from zshot import PipelineConfig, Linker
from zshot.linker import LinkerGLINER
from zshot.tests.config import EX_DOCS, EX_ENTITIES

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module", autouse=True)
def teardown():
logger.warning("Starting smxm tests")
yield True
gc.collect()


def test_gliner_download():
linker = LinkerGLINER()
linker.load_models()
assert isinstance(linker, Linker)
del linker.model, linker


def test_smxm_linker():
nlp = spacy.blank("en")
gliner_config = PipelineConfig(
linker=LinkerGLINER(),
entities=EX_ENTITIES
)
nlp.add_pipe("zshot", config=gliner_config, last=True)
assert "zshot" in nlp.pipe_names

doc = nlp(EX_DOCS[1])
assert len(doc.ents) > 0
docs = [doc for doc in nlp.pipe(EX_DOCS)]
assert all(len(doc.ents) > 0 for doc in docs)
del nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
nlp.remove_pipe('zshot')
del doc, nlp, gliner_config


def test_smxm_linker_no_entities():
nlp = spacy.blank("en")
gliner_config = PipelineConfig(
linker=LinkerGLINER(),
entities=[]
)
nlp.add_pipe("zshot", config=gliner_config, last=True)
assert "zshot" in nlp.pipe_names

doc = nlp(EX_DOCS[1])
assert len(doc.ents) == 0
del nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
nlp.remove_pipe('zshot')
del doc, nlp, gliner_config
70 changes: 70 additions & 0 deletions zshot/tests/mentions_extractor/test_gliner_mentions_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import gc
import logging

import pytest
import spacy

from zshot import PipelineConfig, MentionsExtractor
from zshot.mentions_extractor import MentionsExtractorGLINER
from zshot.tests.config import EX_DOCS, EX_ENTITIES

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module", autouse=True)
def teardown():
logger.warning("Starting smxm tests")
yield True
gc.collect()


def test_gliner_download():
mentions_extractor = MentionsExtractorGLINER()
mentions_extractor.load_models()
assert isinstance(mentions_extractor, MentionsExtractor)
del mentions_extractor


def test_gliner_mentions_extractor():
nlp = spacy.blank("en")
gliner_config = PipelineConfig(
mentions_extractor=MentionsExtractorGLINER(),
mentions=EX_ENTITIES
)
nlp.add_pipe("zshot", config=gliner_config, last=True)
assert "zshot" in nlp.pipe_names

doc = nlp(EX_DOCS[1])
assert len(doc._.mentions) > 0
nlp.remove_pipe('zshot')
del doc, nlp


def test_gliner_mentions_extractor_pipeline():
nlp = spacy.blank("en")
gliner_config = PipelineConfig(
mentions_extractor=MentionsExtractorGLINER(),
mentions=EX_ENTITIES
)
nlp.add_pipe("zshot", config=gliner_config, last=True)
assert "zshot" in nlp.pipe_names

docs = [doc for doc in nlp.pipe(EX_DOCS)]
assert all(len(doc._.mentions) > 0 for doc in docs)
nlp.remove_pipe('zshot')
del docs, nlp


def test_gliner_mentions_extractor_no_entities():
nlp = spacy.blank("en")
gliner_config = PipelineConfig(
mentions_extractor=MentionsExtractorGLINER(),
mentions=[]
)
nlp.add_pipe("zshot", config=gliner_config, last=True)
assert "zshot" in nlp.pipe_names

doc = nlp(EX_DOCS[1])
assert len(doc._.mentions) == 0
nlp.remove_pipe('zshot')
del doc, nlp

0 comments on commit 83553c9

Please sign in to comment.