added fallback mechanism for undetected languages
chandralegend committed Jul 22, 2024
1 parent a4ef7b0 commit d6e83b4
Showing 2 changed files with 154 additions and 56 deletions.
46 changes: 42 additions & 4 deletions multi_tokenizer/tokenizer.py
@@ -1,6 +1,7 @@
"""Multi Tokenizer Module."""

import pickle
from typing import Any

from lingua import Language

@@ -14,6 +15,7 @@ class MultiTokenizer:
def __init__(
self,
tokenizers: list[LanguageSpecificTokenizer | PretrainedTokenizers],
fallback_tokenizer: Any = None,
split_text: bool = False,
sep: str = " ",
) -> None:
@@ -26,6 +28,9 @@ def __init__(
)
for tokenizer in tokenizers
]
self.fallback_tokenizer = (
self.tokenizers[0] if fallback_tokenizer is None else fallback_tokenizer
)
self.language_prefix_token_ids = [
tokenizer.language_prefix_token[1] for tokenizer in self.tokenizers
]
@@ -43,8 +48,18 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
if not self.split_text
else self.language_detector.split_n_detect(text, self.sep)
)
last_end_index = 0
for detection in language_detections:
# If there is text between the last detected language and the current one
if detection.start_index != last_end_index:
pre_tokenized_text.append(
(
text[last_end_index : detection.start_index],
(last_end_index, detection.start_index),
)
)
detected_text = text[detection.start_index : detection.end_index]
last_end_index = detection.end_index
tokenizer = self.get_tokenizer_by_language(detection.language)
output: list[tuple[str, tuple[int, int]]] = (
tokenizer.tokenizer.pre_tokenizer.pre_tokenize_str(detected_text)
@@ -71,6 +86,11 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
for token, (start, end) in output
]
pre_tokenized_text.extend(output)
# If there is text after the last detected language
if last_end_index < len(text):
pre_tokenized_text.append(
(text[last_end_index:], (last_end_index, len(text)))
)
return pre_tokenized_text

def get_tokenizer_by_language(
@@ -105,7 +125,14 @@ def encode(self, text: str) -> tuple[list[int], list[str]]:
if not self.split_text
else self.language_detector.split_n_detect(text, self.sep)
)
last_end_index = 0
for detection in language_detections:
if detection.start_index != last_end_index:
encoding = self.fallback_tokenizer.encode(
text[last_end_index : detection.start_index]
)
ids.extend(encoding.ids)
tokens.extend(encoding.tokens)
detected_text = text[detection.start_index : detection.end_index]
tokenizer = self.get_tokenizer_by_language(detection.language)
detected_text = (
@@ -116,6 +143,11 @@ def encode(self, text: str) -> tuple[list[int], list[str]]:
encoding = tokenizer.tokenizer.encode(detected_text)
ids.extend(encoding.ids)
tokens.extend(encoding.tokens)
last_end_index = detection.end_index
if last_end_index < len(text):
encoding = self.fallback_tokenizer.encode(text[last_end_index:])
ids.extend(encoding.ids)
tokens.extend(encoding.tokens)
return ids, tokens

def decode(self, token_ids: list[int]) -> str:
@@ -131,7 +163,15 @@ def decode(self, token_ids: list[int]) -> str:
j += 1
decoded_str.append(cur_tokenizer.decode(token_ids[i : j + 1]))
i = j + 1
return " ".join(decoded_str)
else:
while (
j < len(token_ids)
and token_ids[j] not in self.language_prefix_token_ids
):
j += 1
decoded_str.append(self.fallback_tokenizer.decode(token_ids[i:j]))
i = j
return "".join(decoded_str)

def save(self, path: str) -> None:
"""Save Tokenizer."""
@@ -154,6 +194,4 @@ def get_vocab(self) -> dict[str, dict[str, int]]:
def get_vocab_size(self) -> int:
"""Get Vocabulary Size."""
vocab = self.get_vocab()
return sum(
len(vocab[language]) for language in vocab
) # TODO: This is probably wrong
return max(len(vocab[language]) for language in vocab)
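For context, a minimal usage sketch of the new fallback path (not part of this commit). The import paths and the PretrainedTokenizers member names below are assumptions inferred from the diff; the encode/decode signatures match the ones shown above.

# Sketch only: module paths and the ENGLISH/HINDI member names are assumed, not confirmed by this commit.
from multi_tokenizer.tokenizer import MultiTokenizer
from multi_tokenizer.pretrained import PretrainedTokenizers  # hypothetical module path

tokenizer = MultiTokenizer(
    tokenizers=[PretrainedTokenizers.ENGLISH, PretrainedTokenizers.HINDI],
    fallback_tokenizer=None,  # None falls back to the first tokenizer in the list
    split_text=True,
)

sentence = "Translate this hindi sentence to english - बिल्ली बहुत प्यारी है."
ids, tokens = tokenizer.encode(sentence)
# Text outside every detected language span (e.g. the trailing "है." in the
# notebook run below) is now encoded with the fallback tokenizer instead of
# being dropped.
print(tokenizer.decode(ids))
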
164 changes: 112 additions & 52 deletions support/try_multitokenizer.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -39,6 +39,7 @@
" ('Ġto', (30, 33)),\n",
" ('Ġenglish', (33, 41)),\n",
" ('</EN>', (39, 40)),\n",
" (' ', (40, 41)),\n",
" ('<HI>', (41, 42)),\n",
" ('-', (42, 43)),\n",
" ('Ġब', (43, 45)),\n",
@@ -57,10 +58,11 @@
" ('र', (60, 61)),\n",
" ('à¥Ģ', (61, 62)),\n",
" ('Ġह', (62, 64)),\n",
" ('</HI>', (62, 63))]"
" ('</HI>', (62, 63)),\n",
" ('ै.', (63, 65))]"
]
},
"execution_count": 34,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -74,16 +76,36 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"48840"
"65"
]
},
"execution_count": 36,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(sentence)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"25000"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -94,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -103,67 +125,55 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<EN>', 'Tr', 'ans', 'l', 'ate', 'Ġthis', 'Ġhind', 'i', 'Ġsentence', 'Ġto', 'Ġeng', 'lish', '</EN>', 'Ġ', '<HI>', '-', 'Ġब', 'ि', 'ल', 'à¥į', 'ल', 'à¥Ģ', 'Ġबह', 'à¥ģ', 'त', 'Ġप', 'à¥į', 'य', 'ा', 'र', 'à¥Ģ', 'Ġह', '</HI>', 'à', '¥', 'Ī', '.']\n",
"[3, 7235, 6614, 86, 755, 775, 10763, 83, 19412, 276, 3602, 9113, 4, 231, 9, 23, 290, 277, 285, 282, 285, 273, 342, 286, 283, 294, 282, 292, 270, 272, 273, 287, 10, 167, 109, 241, 24]\n"
]
}
],
"source": [
"print(tokens)\n",
"print(ids)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['<EN>',\n",
" 'Tr',\n",
" 'ans',\n",
" 'l',\n",
" 'ate',\n",
" 'Ġthis',\n",
" 'Ġhind',\n",
" 'i',\n",
" 'Ġsentence',\n",
" 'Ġto',\n",
" 'Ġeng',\n",
" 'lish',\n",
" '</EN>',\n",
" '<HI>',\n",
" '-',\n",
" 'Ġब',\n",
" 'ि',\n",
" 'ल',\n",
" 'à¥į',\n",
" 'ल',\n",
" 'à¥Ģ',\n",
" 'Ġबह',\n",
" 'à¥ģ',\n",
" 'त',\n",
" 'Ġप',\n",
" 'à¥į',\n",
" 'य',\n",
" 'ा',\n",
" 'र',\n",
" 'à¥Ģ',\n",
" 'Ġह',\n",
" '</HI>']"
"37"
]
},
"execution_count": 38,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens"
"len(tokens)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Translate this hindi sentence to english - बिल्ली बहुत प्यारी '"
"'Translate this hindi sentence to english - बिल्ली बहुत प्यारी है.'"
]
},
"execution_count": 39,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -174,7 +184,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -183,7 +193,7 @@
"'Translate this hindi sentence to english - बिल्ली बहुत प्यारी है.'"
]
},
"execution_count": 40,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -194,10 +204,60 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"aya_tokenizer = AutoTokenizer.from_pretrained(\"CohereForAI/aya-23-8B\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"24 ['Translate', 'Ġthis', 'Ġhindi', 'Ġsentence', 'Ġto', 'Ġenglish', 'Ġ-', 'Ġब', 'ि', 'ल', 'à¥į', 'ल', 'à¥Ģ', 'Ġबह', 'à¥ģ', 'त', 'Ġप', 'à¥į', 'य', 'ा', 'र', 'à¥Ģ', 'Ġह', 'à¥Ī.']\n"
]
}
],
"source": [
"tokens = aya_tokenizer.tokenize(sentence)\n",
"print(len(tokens), tokens)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"255029"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(aya_tokenizer.get_vocab())"
]
},
{
"cell_type": "code",
