diff --git a/zshot/linker/linker_regen/linker_regen.py b/zshot/linker/linker_regen/linker_regen.py
index d6ef49c..8285896 100644
--- a/zshot/linker/linker_regen/linker_regen.py
+++ b/zshot/linker/linker_regen/linker_regen.py
@@ -16,19 +16,22 @@ class LinkerRegen(Linker):
     """ REGEN linker """
 
-    def __init__(self, max_input_len=384, max_output_len=15, num_beams=10):
+    def __init__(self, max_input_len=384, max_output_len=15, num_beams=10, trie=None):
         """
         :param max_input_len: Max length of input
         :param max_output_len: Max length of output
-        :param num_beams: Number of beans to use
+        :param num_beams: Number of beams to use
+        :param trie: If a trie is given, the linker will use it to restrict the search space.
+            Custom entities won't be used when a trie is given.
         """
         super().__init__()
         self.model = None
         self.tokenizer = None
-        self.trie = None
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.num_beams = num_beams
+        self.skip_set_kg = trie is not None
+        self.trie = trie
 
     def set_kg(self, entities: Iterator[Entity]):
         """ Set new entities
@@ -36,13 +39,14 @@ def set_kg(self, entities: Iterator[Entity]):
         :param entities: New entities to use
         """
         super().set_kg(entities)
-        self.load_tokenizer()
-        self.trie = Trie(
-            [
-                self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
-                for e in entities
-            ]
-        )
+        if not self.skip_set_kg:
+            self.load_tokenizer()
+            self.trie = Trie(
+                [
+                    self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
+                    for e in entities
+                ]
+            )
 
     def load_models(self):
         """ Load Model """
diff --git a/zshot/linker/linker_regen/trie.py b/zshot/linker/linker_regen/trie.py
index acf64f0..2ba2396 100644
--- a/zshot/linker/linker_regen/trie.py
+++ b/zshot/linker/linker_regen/trie.py
@@ -1,20 +1,20 @@
-from typing import List
+from typing import Collection
 
 
 class Trie(object):
-    def __init__(self, sequences: List[List[int]] = []):
+    def __init__(self, sequences: Collection[Collection[int]] = ()):
         self.trie_dict = {}
         for sequence in sequences:
             self.add(sequence)
 
-    def add(self, sequence: List[int]):
+    def add(self, sequence: Collection[int]):
         trie = self.trie_dict
         for idx in sequence:
             if idx not in trie:
                 trie[idx] = {}
             trie = trie[idx]
 
-    def postfix(self, prefix_sequence: List[int]):
+    def postfix(self, prefix_sequence: Collection[int]):
         if len(prefix_sequence) == 1:
             return list(self.trie_dict.keys())
         trie = self.trie_dict
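The `Trie` above is what makes REGEN's generation constrained: `set_kg` stores the token ids of every entity name, and `postfix` returns the tokens that may legally come next given what the decoder has produced so far. Below is a minimal sketch of how such a trie typically plugs into Hugging Face constrained beam search (the GENRE-style trick); the checkpoint name, the input sentence, and the exact wiring are illustrative assumptions, not `LinkerRegen`'s actual internals.

```python
# Illustrative sketch only (not LinkerRegen's real code): use the Trie to
# constrain beam search so the decoder can only emit token sequences that
# spell out a known entity name.
from transformers import BartForConditionalGeneration, BartTokenizer

from zshot.linker.linker_regen.trie import Trie

# "facebook/bart-base" is a stand-in checkpoint for this sketch.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Build the trie over the token ids of each candidate name, mirroring set_kg.
candidates = ["Surfing", "Surf music"]
trie = Trie([tokenizer(name, return_tensors="pt")["input_ids"][0].tolist()
             for name in candidates])

inputs = tokenizer("He rode the wave on his board.", return_tensors="pt")
outputs = model.generate(
    **inputs,
    num_beams=10,   # mirrors LinkerRegen's num_beams default
    max_length=15,  # mirrors max_output_len
    # At each step, allow only tokens that extend a name stored in the trie.
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.postfix(sent.tolist()),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Note the `len(prefix_sequence) == 1` case in `postfix`: the first decoder position holds only the decoder start token, so the trie answers with its root keys; depending on the model's special tokens the wiring may need adapting.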
diff --git a/zshot/linker/linker_regen/utils.py b/zshot/linker/linker_regen/utils.py
index 112424e..958f807 100644
--- a/zshot/linker/linker_regen/utils.py
+++ b/zshot/linker/linker_regen/utils.py
@@ -1,3 +1,18 @@
+import json
+import pickle
+from typing import Dict, List, Optional
+
+from huggingface_hub import hf_hub_download
+
+from zshot.linker.linker_regen.trie import Trie
+from zshot.utils.data_models import Span
+
+REPO_ID = "ibm/regen-disambiguation"
+TRIE_FILE_NAME = "wikipedia_trie.pkl"
+WIKIPEDIA_MAP = "wikipedia_map_id.json"
+
+
 def create_input(sentence, max_length, start_delimiter, end_delimiter):
     sent_list = sentence.split(" ")
     if len(sent_list) < max_length:
@@ -12,3 +27,45 @@ def create_input(sentence, max_length, start_delimiter, end_delimiter):
         left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index)))
-    print(len(sent_list[left_index:right_index]))
     return " ".join(sent_list[left_index:right_index])
+
+
+def load_wikipedia_trie() -> Trie:
+    """
+    Load the Wikipedia trie from the HF hub
+    :return: The Wikipedia trie
+    """
+    wikipedia_trie_file = hf_hub_download(repo_id=REPO_ID,
+                                          repo_type='model',
+                                          filename=TRIE_FILE_NAME)
+    with open(wikipedia_trie_file, "rb") as f:
+        wikipedia_trie = pickle.load(f)
+    return wikipedia_trie
+
+
+def load_wikipedia_mapping() -> Dict[str, str]:
+    """
+    Load the Wikipedia label-to-page-id mapping from the HF hub
+    :return: The Wikipedia mapping
+    """
+    wikipedia_map_file = hf_hub_download(repo_id=REPO_ID,
+                                         repo_type='model',
+                                         filename=WIKIPEDIA_MAP)
+    with open(wikipedia_map_file, "r") as f:
+        wikipedia_map = json.load(f)
+    return wikipedia_map
+
+
+def spans_to_wikipedia(spans: List[Span]) -> List[Optional[str]]:
+    """
+    Generate Wikipedia links for the given spans
+    :param spans: Spans to link
+    :return: The list of generated links, with None for labels without a page id
+    """
+    links = []
+    wikipedia_map = load_wikipedia_mapping()
+    for s in spans:
+        if s.label in wikipedia_map:
+            links.append(f"https://en.wikipedia.org/wiki?curid={wikipedia_map[s.label]}")
+        else:
+            links.append(None)
+    return links
diff --git a/zshot/tests/linker/test_regen_linker.py b/zshot/tests/linker/test_regen_linker.py
index 662751f..5da3135 100644
--- a/zshot/tests/linker/test_regen_linker.py
+++ b/zshot/tests/linker/test_regen_linker.py
@@ -8,8 +8,12 @@
 from zshot import PipelineConfig
 from zshot.linker.linker_regen.linker_regen import LinkerRegen
+from zshot.linker.linker_regen.trie import Trie
+from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
 from zshot.mentions_extractor import MentionsExtractorSpacy
 from zshot.tests.config import EX_DOCS, EX_ENTITIES
+from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
+from zshot.utils.data_models import Span
 
 logger = logging.getLogger(__name__)
@@ -25,9 +29,9 @@ def teardown():
 
 
 def test_regen_linker():
-    nlp = spacy.load("en_core_web_sm")
+    nlp = spacy.blank("en")
     config = PipelineConfig(
-        mentions_extractor=MentionsExtractorSpacy(),
+        mentions_extractor=DummyMentionsExtractor(),
         linker=LinkerRegen(),
         entities=EX_ENTITIES
     )
@@ -60,3 +64,36 @@ def test_regen_linker_pipeline():
         nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
     nlp.remove_pipe('zshot')
     del docs, nlp, config
+
+
+def test_regen_linker_wikification():
+    nlp = spacy.blank("en")
+    trie = Trie()
+    trie.add([794, 536, 1])
+    trie.add([794, 357, 1])
+    config = PipelineConfig(
+        mentions_extractor=DummyMentionsExtractor(),
+        linker=LinkerRegen(trie=trie),
+    )
+    nlp.add_pipe("zshot", config=config, last=True)
+    assert "zshot" in nlp.pipe_names
+
+    doc = nlp(EX_DOCS[1])
+    assert len(doc.ents) > 0
+    del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp
+    del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \
+        nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
+    nlp.remove_pipe('zshot')
+    del doc, nlp, config
+
+
+@pytest.mark.skip(reason="Too expensive to run on every commit")
+def test_load_wikipedia_trie():
+    trie = load_wikipedia_trie()
+    assert len(list(trie.trie_dict.keys())) == 6952
+
+
+def test_span_to_wiki():
+    s = Span(label="Surfing", start=0, end=10)
+    wiki_links = spans_to_wikipedia([s])
+    assert len(wiki_links) > 0
+    assert wiki_links[0].startswith("https://en.wikipedia.org/wiki?curid=")
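End to end, the new pieces compose as in `test_regen_linker_wikification`: load the pretrained Wikipedia trie, hand it to `LinkerRegen`, and map the predicted span labels to Wikipedia URLs. A minimal usage sketch, assuming network access to the HF hub and an installed `en_core_web_sm` model; the example sentence and the choice of `MentionsExtractorSpacy` are illustrative:

```python
# Usage sketch for the wikification path exercised by the tests above.
# Downloads the trie and the page-id map from the HF hub on first use.
import spacy

from zshot import PipelineConfig
from zshot.linker.linker_regen.linker_regen import LinkerRegen
from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.utils.data_models import Span

nlp = spacy.load("en_core_web_sm")
config = PipelineConfig(
    mentions_extractor=MentionsExtractorSpacy(),
    # With a prebuilt trie the linker skips set_kg and searches all of Wikipedia.
    linker=LinkerRegen(trie=load_wikipedia_trie()),
)
nlp.add_pipe("zshot", config=config, last=True)

doc = nlp("I love surfing in Hawaii.")
# Rebuild zshot Spans from the spaCy entities to look up their page ids.
spans = [Span(label=ent.label_, start=ent.start_char, end=ent.end_char)
         for ent in doc.ents]
for ent, link in zip(doc.ents, spans_to_wikipedia(spans)):
    print(ent.text, "->", link)  # None when the label has no page id
```

Because the trie already encodes every linkable Wikipedia title, no custom entities are needed here, which is exactly the case the new `skip_set_kg` flag handles.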