
Commit

updated the tokenizers decode and encode functions
chandralegend committed Jul 21, 2024
1 parent 242d1c9 commit a4ef7b0
Showing 5 changed files with 1,308 additions and 95 deletions.
1 change: 0 additions & 1 deletion multi_tokenizer/language_detect.py
@@ -56,7 +56,6 @@ def merge_results(
             return merged_results

         texts = text.split(sep)
-        print(texts)
         results = self.batch_detect(texts)
         merged_results = merge_results(results)
         return merged_results
21 changes: 12 additions & 9 deletions multi_tokenizer/pretrained/__init__.py
@@ -19,8 +19,8 @@ def __init__(
         self,
         tokenizer_path: str,
         language: Language,
-        language_prefix: str,
-        language_suffix: str,
+        language_prefix: tuple[str, int],
+        language_suffix: tuple[str, int],
     ) -> None:
         """Initialize Language Specific Tokenizer."""
         self.language = language
@@ -39,23 +39,26 @@ class PretrainedTokenizers(Enum):
     ENGLISH = LanguageSpecificTokenizer(
         os.path.join(file_dir, "english_tokenizer.json"),
         Language.ENGLISH,
-        "<EN>",
-        "</EN>",
+        ("<EN>", 3),
+        ("</EN>", 4),
     )
     SPANISH = LanguageSpecificTokenizer(
         os.path.join(file_dir, "spanish_tokenizer.json"),
         Language.SPANISH,
-        "<ES>",
-        "</ES>",
+        ("<ES>", 5),
+        ("</ES>", 6),
     )
     CHINESE = LanguageSpecificTokenizer(
         os.path.join(file_dir, "chinese_tokenizer.json"),
         Language.CHINESE,
-        "<ZH>",
-        "</ZH>",
+        ("<ZH>", 7),
+        ("</ZH>", 8),
     )
     HINDI = LanguageSpecificTokenizer(
-        os.path.join(file_dir, "hindi_tokenizer.json"), Language.HINDI, "<HI>", "</HI>"
+        os.path.join(file_dir, "hindi_tokenizer.json"),
+        Language.HINDI,
+        ("<HI>", 9),
+        ("</HI>", 10),
     )


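The language markers are now (string, token id) pairs rather than bare strings. A minimal sketch of inspecting them, assuming the constructor stores the pairs under `language_prefix_token` / `language_suffix_token` (as the tokenizer.py changes below suggest) and that each enum member's `.value` is the `LanguageSpecificTokenizer` instance:

```python
from multi_tokenizer.pretrained import PretrainedTokenizers

# Assumption: the enum member's value is the LanguageSpecificTokenizer itself.
english = PretrainedTokenizers.ENGLISH.value

# Each marker is now a (string, token id) pair instead of a bare string.
prefix_str, prefix_id = english.language_prefix_token  # ("<EN>", 3) per this commit
suffix_str, suffix_id = english.language_suffix_token  # ("</EN>", 4) per this commit
print(prefix_str, prefix_id, suffix_str, suffix_id)
```

Carrying the reserved token ids alongside the marker strings is what lets the new `decode` below switch tokenizers by id.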
79 changes: 65 additions & 14 deletions multi_tokenizer/tokenizer.py
@@ -2,14 +2,10 @@

 import pickle

-from multi_tokenizer.language_detect import LanguageDetector
-from multi_tokenizer.pretrained import (
-    LanguageSpecificTokenizer,
-    PretrainedTokenizers,
-    get_tokenizer_by_language,
-)
+from lingua import Language

-from tokenizers import Encoding
+from multi_tokenizer.language_detect import LanguageDetector
+from multi_tokenizer.pretrained import LanguageSpecificTokenizer, PretrainedTokenizers


 class MultiTokenizer:
@@ -30,6 +26,9 @@ def __init__(
             )
             for tokenizer in tokenizers
         ]
+        self.language_prefix_token_ids = [
+            tokenizer.language_prefix_token[1] for tokenizer in self.tokenizers
+        ]
         self.language_detector = LanguageDetector(
             [tokenizer.language for tokenizer in self.tokenizers]
         )
@@ -46,16 +45,16 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
         )
         for detection in language_detections:
             detected_text = text[detection.start_index : detection.end_index]
-            tokenizer = get_tokenizer_by_language(detection.language)
+            tokenizer = self.get_tokenizer_by_language(detection.language)
             output: list[tuple[str, tuple[int, int]]] = (
                 tokenizer.tokenizer.pre_tokenizer.pre_tokenize_str(detected_text)
             )
             output = (
-                [(tokenizer.language_prefix_token, (-1, 0))]
+                [(tokenizer.language_prefix_token[0], (-1, 0))]
                 + output
                 + [
                     (
-                        tokenizer.language_suffix_token,
+                        tokenizer.language_suffix_token[0],
                         (len(detected_text) - 2, len(detected_text) - 1),
                     )
                 ]
@@ -74,13 +73,65 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
             pre_tokenized_text.extend(output)
         return pre_tokenized_text

-    def encode(self, text: str) -> Encoding:
+    def get_tokenizer_by_language(
+        self, language: Language
+    ) -> LanguageSpecificTokenizer:
+        """Get Tokenizer for Language."""
+        for tokenizer in self.tokenizers:
+            if tokenizer.language == language:
+                return tokenizer
+        raise ValueError(f"Tokenizer for {language} not found.")
+
+    def get_tokenizer_by_prefix(self, prefix: str) -> LanguageSpecificTokenizer:
+        """Get Tokenizer by Prefix."""
+        for tokenizer in self.tokenizers:
+            if tokenizer.language_prefix_token[0] == prefix:
+                return tokenizer
+        raise ValueError(f"Tokenizer for prefix {prefix} not found.")
+
+    def get_tokenizer_by_prefix_id(self, prefix_id: int) -> LanguageSpecificTokenizer:
+        """Get Tokenizer by Prefix ID."""
+        for tokenizer in self.tokenizers:
+            if tokenizer.language_prefix_token[1] == prefix_id:
+                return tokenizer
+        raise ValueError(f"Tokenizer for prefix ID {prefix_id} not found.")
+
+    def encode(self, text: str) -> tuple[list[int], list[str]]:
         """Encode Text."""
-        raise NotImplementedError
+        ids = []
+        tokens = []
+        language_detections = (
+            self.language_detector.detect(text)
+            if not self.split_text
+            else self.language_detector.split_n_detect(text, self.sep)
+        )
+        for detection in language_detections:
+            detected_text = text[detection.start_index : detection.end_index]
+            tokenizer = self.get_tokenizer_by_language(detection.language)
+            detected_text = (
+                tokenizer.language_prefix_token[0]
+                + detected_text
+                + tokenizer.language_suffix_token[0]
+            )
+            encoding = tokenizer.tokenizer.encode(detected_text)
+            ids.extend(encoding.ids)
+            tokens.extend(encoding.tokens)
+        return ids, tokens

-    def decode(self, encoding: Encoding) -> str:
+    def decode(self, token_ids: list[int]) -> str:
         """Decode Encoding."""
-        raise NotImplementedError
+        decoded_str = []
+        cur_tokenizer = None
+        i, j = 0, 0
+        while i < len(token_ids):
+            if token_ids[i] in self.language_prefix_token_ids:
+                cur_tokenizer = self.get_tokenizer_by_prefix_id(token_ids[i])
+                j = i + 1
+                while token_ids[j] != cur_tokenizer.language_suffix_token[1]:
+                    j += 1
+                decoded_str.append(cur_tokenizer.decode(token_ids[i : j + 1]))
+                i = j + 1
+        return " ".join(decoded_str)

     def save(self, path: str) -> None:
         """Save Tokenizer."""
(Diffs for the remaining 2 changed files are not shown.)
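Taken together, a minimal round-trip sketch of the new encode/decode API, assuming `MultiTokenizer` can be constructed directly from `PretrainedTokenizers` members with default values for its remaining parameters (the sample text and results are illustrative):

```python
from multi_tokenizer.tokenizer import MultiTokenizer
from multi_tokenizer.pretrained import PretrainedTokenizers

# Assumption: the constructor accepts PretrainedTokenizers members directly,
# as its import of the enum suggests.
tokenizer = MultiTokenizer([PretrainedTokenizers.ENGLISH, PretrainedTokenizers.SPANISH])

# encode() now returns (ids, tokens); each detected language span is wrapped
# in its prefix/suffix markers (<EN>...</EN>, <ES>...</ES>) before encoding.
ids, tokens = tokenizer.encode("Hello world. Hola mundo.")

# decode() scans for prefix ids, hands each <prefix>...</suffix> span to the
# matching language tokenizer, and joins the per-span decodings with spaces.
# It assumes the id list is a clean sequence of such spans.
text = tokenizer.decode(ids)
print(ids, tokens, text, sep="\n")
```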
