Commit
✨ Add regen wikification (#44)
* ✨ Add regen wikification

* 🐛 Fix Wikification

* 🐛 Reduce tests complexity

* 🐛 Reduce test resources

* 🐛 Fix test

* ➖ Remove test file

* ✏️ Remove too expensive test
GabrielePicco authored Dec 28, 2022
1 parent b2c2a27 commit 07c9893
Showing 4 changed files with 113 additions and 15 deletions.
22 changes: 13 additions & 9 deletions zshot/linker/linker_regen/linker_regen.py
@@ -16,33 +16,37 @@

 class LinkerRegen(Linker):
     """ REGEN linker """
-    def __init__(self, max_input_len=384, max_output_len=15, num_beams=10):
+    def __init__(self, max_input_len=384, max_output_len=15, num_beams=10, trie=None):
         """
         :param max_input_len: Max length of input
         :param max_output_len: Max length of output
         :param num_beams: Number of beams to use
+        :param trie: If the trie is given the linker will use it to restrict the search space.
+            Custom entities won't be used if the trie is given.
         """
         super().__init__()
         self.model = None
         self.tokenizer = None
-        self.trie = None
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.num_beams = num_beams
+        self.skip_set_kg = False if trie is None else True
+        self.trie = trie

     def set_kg(self, entities: Iterator[Entity]):
         """ Set new entities
         :param entities: New entities to use
         """
         super().set_kg(entities)
-        self.load_tokenizer()
-        self.trie = Trie(
-            [
-                self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
-                for e in entities
-            ]
-        )
+        if not self.skip_set_kg:
+            self.load_tokenizer()
+            self.trie = Trie(
+                [
+                    self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
+                    for e in entities
+                ]
+            )

     def load_models(self):
         """ Load Model """
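With this change the linker can be built directly with a pre-computed trie, in which case set_kg is skipped and decoding is constrained to the entries of the trie. A minimal usage sketch (not part of the commit): it assumes the en_core_web_sm spaCy model is installed and that downloading the Wikipedia trie from the Hugging Face Hub is acceptable.

    import spacy

    from zshot import PipelineConfig
    from zshot.linker.linker_regen.linker_regen import LinkerRegen
    from zshot.linker.linker_regen.utils import load_wikipedia_trie
    from zshot.mentions_extractor import MentionsExtractorSpacy

    # Build a pipeline that links mentions against Wikipedia titles (wikification).
    nlp = spacy.load("en_core_web_sm")
    config = PipelineConfig(
        mentions_extractor=MentionsExtractorSpacy(),
        linker=LinkerRegen(trie=load_wikipedia_trie()),  # custom entities are ignored when a trie is given
    )
    nlp.add_pipe("zshot", config=config, last=True)

    doc = nlp("Cristiano Ronaldo grew up on the island of Madeira.")  # example sentence, not from the repo
    print([(ent.text, ent.label_) for ent in doc.ents])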
8 changes: 4 additions & 4 deletions zshot/linker/linker_regen/trie.py
@@ -1,20 +1,20 @@
-from typing import List
+from typing import Collection


 class Trie(object):
-    def __init__(self, sequences: List[List[int]] = []):
+    def __init__(self, sequences: Collection[Collection[int]] = []):
         self.trie_dict = {}
         for sequence in sequences:
             self.add(sequence)

-    def add(self, sequence: List[int]):
+    def add(self, sequence: Collection[int]):
         trie = self.trie_dict
         for idx in sequence:
             if idx not in trie:
                 trie[idx] = {}
             trie = trie[idx]

-    def postfix(self, prefix_sequence: List[int]):
+    def postfix(self, prefix_sequence: Collection[int]):
         if len(prefix_sequence) == 1:
             return list(self.trie_dict.keys())
         trie = self.trie_dict
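The trie stores tokenized entity names as nested dictionaries of token ids, and the linker queries it during beam search to restrict which token can come next. A small illustration (the token ids are made up, not taken from the commit):

    from zshot.linker.linker_regen.trie import Trie

    # Two allowed sequences sharing the prefix 794; ids are placeholders for tokenizer output.
    trie = Trie([[794, 536, 1], [794, 357, 1]])
    trie.add([802, 16, 1])     # sequences can also be added one at a time

    print(trie.trie_dict)      # {794: {536: {1: {}}, 357: {1: {}}}, 802: {16: {1: {}}}}
    print(trie.postfix([0]))   # a length-1 prefix returns the root token ids: [794, 802]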
57 changes: 57 additions & 0 deletions zshot/linker/linker_regen/utils.py
@@ -1,3 +1,18 @@
+import json
+import pickle
+from typing import Dict, List
+
+import pytest
+from huggingface_hub import hf_hub_download
+
+from zshot.linker.linker_regen.trie import Trie
+from zshot.utils.data_models import Span
+
+REPO_ID = "ibm/regen-disambiguation"
+TRIE_FILE_NAME = "wikipedia_trie.pkl"
+WIKIPEDIA_MAP = "wikipedia_map_id.json"
+
+
 def create_input(sentence, max_length, start_delimiter, end_delimiter):
     sent_list = sentence.split(" ")
     if len(sent_list) < max_length:
@@ -12,3 +27,45 @@ def create_input(sentence, max_length, start_delimiter, end_delimiter):
         left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index)))
         print(len(sent_list[left_index:right_index]))
     return " ".join(sent_list[left_index:right_index])
+
+
+def load_wikipedia_trie() -> Trie:
+    """
+    Load the Wikipedia trie from the HF Hub
+    :return: The Wikipedia trie
+    """
+    wikipedia_trie_file = hf_hub_download(repo_id=REPO_ID,
+                                          repo_type='model',
+                                          filename=TRIE_FILE_NAME)
+    with open(wikipedia_trie_file, "rb") as f:
+        wikipedia_trie = pickle.load(f)
+    return wikipedia_trie
+
+
+@pytest.mark.skip(reason="Too expensive to run on every commit")
+def load_wikipedia_mapping() -> Dict[str, str]:
+    """
+    Load the Wikipedia title-to-page-id mapping from the HF Hub
+    :return: The Wikipedia mapping
+    """
+    wikipedia_map = hf_hub_download(repo_id=REPO_ID,
+                                    repo_type='model',
+                                    filename=WIKIPEDIA_MAP)
+    with open(wikipedia_map, "r") as f:
+        wikipedia_map = json.load(f)
+    return wikipedia_map
+
+
+def spans_to_wikipedia(spans: List[Span]) -> List[str]:
+    """
+    Generate Wikipedia links for spans
+    :return: The list of generated links
+    """
+    links = []
+    wikipedia_map = load_wikipedia_mapping()
+    for s in spans:
+        if s.label in wikipedia_map:
+            links.append(f"https://en.wikipedia.org/wiki?curid={wikipedia_map[s.label]}")
+        else:
+            links.append(None)
+    return links
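The new helpers tie the linker output back to Wikipedia: spans_to_wikipedia looks up each span label in the page-id mapping downloaded from the Hub and builds a curid URL, mirroring the test added below. A short hedged sketch (network access to the Hub is assumed on first use):

    from zshot.linker.linker_regen.utils import spans_to_wikipedia
    from zshot.utils.data_models import Span

    # The label is whatever the REGEN linker predicted for the span, e.g. a Wikipedia title.
    span = Span(label="Surfing", start=0, end=10)
    links = spans_to_wikipedia([span])
    print(links[0])  # https://en.wikipedia.org/wiki?curid=<page id>, or None if the label is unknown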
41 changes: 39 additions & 2 deletions zshot/tests/linker/test_regen_linker.py
@@ -8,8 +8,12 @@

 from zshot import PipelineConfig
 from zshot.linker.linker_regen.linker_regen import LinkerRegen
+from zshot.linker.linker_regen.trie import Trie
+from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
 from zshot.mentions_extractor import MentionsExtractorSpacy
 from zshot.tests.config import EX_DOCS, EX_ENTITIES
+from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
+from zshot.utils.data_models import Span

 logger = logging.getLogger(__name__)

@@ -25,9 +29,9 @@ def teardown():


 def test_regen_linker():
-    nlp = spacy.load("en_core_web_sm")
+    nlp = spacy.blank("en")
     config = PipelineConfig(
-        mentions_extractor=MentionsExtractorSpacy(),
+        mentions_extractor=DummyMentionsExtractor(),
         linker=LinkerRegen(),
         entities=EX_ENTITIES
     )
@@ -60,3 +64,36 @@ def test_regen_linker_pipeline():
         nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
     nlp.remove_pipe('zshot')
     del docs, nlp, config
+
+
+def test_regen_linker_wikification():
+    nlp = spacy.blank("en")
+    trie = Trie()
+    trie.add([794, 536, 1])
+    trie.add([794, 357, 1])
+    config = PipelineConfig(
+        mentions_extractor=DummyMentionsExtractor(),
+        linker=LinkerRegen(trie=trie),
+    )
+    nlp.add_pipe("zshot", config=config, last=True)
+    assert "zshot" in nlp.pipe_names
+
+    doc = nlp(EX_DOCS[1])
+    assert len(doc.ents) > 0
+    del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp
+    del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \
+        nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
+    nlp.remove_pipe('zshot')
+    del doc, nlp, config
+
+
+def test_load_wikipedia_trie():
+    trie = load_wikipedia_trie()
+    assert len(list(trie.trie_dict.keys())) == 6952
+
+
+def test_span_to_wiki():
+    s = Span(label="Surfing", start=0, end=10)
+    wiki_links = spans_to_wikipedia([s])
+    assert len(wiki_links) > 0
+    assert wiki_links[0].startswith("https://en.wikipedia.org/wiki?curid=")