diff --git a/multi_tokenizer/tokenizer.py b/multi_tokenizer/tokenizer.py
index 1860651..7181a71 100644
--- a/multi_tokenizer/tokenizer.py
+++ b/multi_tokenizer/tokenizer.py
@@ -1,6 +1,7 @@
"""Multi Tokenizer Module."""
import pickle
+from typing import Any
from lingua import Language
@@ -14,6 +15,7 @@ class MultiTokenizer:
def __init__(
self,
tokenizers: list[LanguageSpecificTokenizer | PretrainedTokenizers],
+ fallback_tokenizer: Any = None,
split_text: bool = False,
sep: str = " ",
) -> None:
@@ -26,6 +28,9 @@ def __init__(
)
for tokenizer in tokenizers
]
+ self.fallback_tokenizer = (
+ self.tokenizers[0] if fallback_tokenizer is None else fallback_tokenizer
+ )
self.language_prefix_token_ids = [
tokenizer.language_prefix_token[1] for tokenizer in self.tokenizers
]
@@ -43,8 +48,18 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
if not self.split_text
else self.language_detector.split_n_detect(text, self.sep)
)
+ last_end_index = 0
for detection in language_detections:
+ # If there is text between the last detected language and the current one
+ if detection.start_index != last_end_index:
+ pre_tokenized_text.append(
+ (
+ text[last_end_index : detection.start_index],
+ (last_end_index, detection.start_index),
+ )
+ )
detected_text = text[detection.start_index : detection.end_index]
+ last_end_index = detection.end_index
tokenizer = self.get_tokenizer_by_language(detection.language)
output: list[tuple[str, tuple[int, int]]] = (
tokenizer.tokenizer.pre_tokenizer.pre_tokenize_str(detected_text)
@@ -71,6 +86,11 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
for token, (start, end) in output
]
pre_tokenized_text.extend(output)
+ # If there is text after the last detected language
+ if last_end_index < len(text):
+ pre_tokenized_text.append(
+ (text[last_end_index:], (last_end_index, len(text)))
+ )
return pre_tokenized_text
def get_tokenizer_by_language(
@@ -105,7 +125,14 @@ def encode(self, text: str) -> tuple[list[int], list[str]]:
if not self.split_text
else self.language_detector.split_n_detect(text, self.sep)
)
+ last_end_index = 0
for detection in language_detections:
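+ # If there is text between the last detected language and the current one, encode it with the fallback tokenizer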
+ if detection.start_index != last_end_index:
+ encoding = self.fallback_tokenizer.encode(
+ text[last_end_index : detection.start_index]
+ )
+ ids.extend(encoding.ids)
+ tokens.extend(encoding.tokens)
detected_text = text[detection.start_index : detection.end_index]
tokenizer = self.get_tokenizer_by_language(detection.language)
detected_text = (
@@ -116,6 +143,11 @@ def encode(self, text: str) -> tuple[list[int], list[str]]:
encoding = tokenizer.tokenizer.encode(detected_text)
ids.extend(encoding.ids)
tokens.extend(encoding.tokens)
+ last_end_index = detection.end_index
+ if last_end_index < len(text):
+ encoding = self.fallback_tokenizer.encode(text[last_end_index:])
+ ids.extend(encoding.ids)
+ tokens.extend(encoding.tokens)
return ids, tokens
def decode(self, token_ids: list[int]) -> str:
@@ -131,7 +163,15 @@ def decode(self, token_ids: list[int]) -> str:
j += 1
decoded_str.append(cur_tokenizer.decode(token_ids[i : j + 1]))
i = j + 1
- return " ".join(decoded_str)
+ else:
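+ # Tokens outside a language-prefixed span are decoded with the fallback tokenizer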
+ while (
+ j < len(token_ids)
+ and token_ids[j] not in self.language_prefix_token_ids
+ ):
+ j += 1
+ decoded_str.append(self.fallback_tokenizer.decode(token_ids[i:j]))
+ i = j
+ return "".join(decoded_str)
def save(self, path: str) -> None:
"""Save Tokenizer."""
@@ -154,6 +194,4 @@ def get_vocab(self) -> dict[str, dict[str, int]]:
def get_vocab_size(self) -> int:
"""Get Vocabulary Size."""
vocab = self.get_vocab()
- return sum(
- len(vocab[language]) for language in vocab
- ) # TODO: This is probably wrong
+ return max(len(vocab[language]) for language in vocab)
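
Both pre_tokenize and encode now follow the same gap-handling pattern: walk the
detections in order, route each detected span to its language-specific
tokenizer, and hand any text that falls between or after the detected spans to
the fallback tokenizer. A minimal standalone sketch of that pattern, assuming a
hypothetical Detection tuple in place of lingua's detection result (which
exposes start_index and end_index):

    # Sketch only -- not part of the patch.
    from typing import NamedTuple

    class Detection(NamedTuple):  # hypothetical stand-in for a lingua detection
        start_index: int
        end_index: int

    def split_spans(text: str, detections: list[Detection]) -> list[tuple[str, bool]]:
        """Return (segment, was_detected) pairs covering the whole input."""
        segments: list[tuple[str, bool]] = []
        last_end = 0
        for d in detections:
            if d.start_index != last_end:  # gap before this detection -> fallback
                segments.append((text[last_end : d.start_index], False))
            segments.append((text[d.start_index : d.end_index], True))
            last_end = d.end_index
        if last_end < len(text):  # trailing text after the last detection -> fallback
            segments.append((text[last_end:], False))
        return segments

    # Example: the middle segment would be handled by the fallback tokenizer.
    # split_spans("hola - नमस्ते", [Detection(0, 4), Detection(7, 13)])
    # -> [('hola', True), (' - ', False), ('नमस्ते', True)]
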
diff --git a/support/try_multitokenizer.ipynb b/support/try_multitokenizer.ipynb
index 03b4a2e..5ccfe06 100644
--- a/support/try_multitokenizer.ipynb
+++ b/support/try_multitokenizer.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 14,
"metadata": {},
"outputs": [
{
@@ -39,6 +39,7 @@
" ('Ġto', (30, 33)),\n",
" ('Ġenglish', (33, 41)),\n",
" ('', (39, 40)),\n",
+ " (' ', (40, 41)),\n",
" ('', (41, 42)),\n",
" ('-', (42, 43)),\n",
" ('Ġब', (43, 45)),\n",
@@ -57,10 +58,11 @@
" ('र', (60, 61)),\n",
" ('à¥Ģ', (61, 62)),\n",
" ('Ġह', (62, 64)),\n",
- " ('', (62, 63))]"
+ " ('', (62, 63)),\n",
+ " ('ै.', (63, 65))]"
]
},
- "execution_count": 34,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -74,16 +76,36 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "48840"
+ "65"
]
},
- "execution_count": 36,
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(sentence)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "25000"
+ ]
+ },
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -94,7 +116,7 @@
},
{
"cell_type": "code",
- "execution_count": 37,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -103,67 +125,55 @@
},
{
"cell_type": "code",
- "execution_count": 38,
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['', 'Tr', 'ans', 'l', 'ate', 'Ġthis', 'Ġhind', 'i', 'Ġsentence', 'Ġto', 'Ġeng', 'lish', '', 'Ġ', '', '-', 'Ġब', 'ि', 'ल', 'à¥į', 'ल', 'à¥Ģ', 'Ġबह', 'à¥ģ', 'त', 'Ġप', 'à¥į', 'य', 'ा', 'र', 'à¥Ģ', 'Ġह', '', 'à', '¥', 'Ī', '.']\n",
+ "[3, 7235, 6614, 86, 755, 775, 10763, 83, 19412, 276, 3602, 9113, 4, 231, 9, 23, 290, 277, 285, 282, 285, 273, 342, 286, 283, 294, 282, 292, 270, 272, 273, 287, 10, 167, 109, 241, 24]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tokens)\n",
+ "print(ids)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "['',\n",
- " 'Tr',\n",
- " 'ans',\n",
- " 'l',\n",
- " 'ate',\n",
- " 'Ġthis',\n",
- " 'Ġhind',\n",
- " 'i',\n",
- " 'Ġsentence',\n",
- " 'Ġto',\n",
- " 'Ġeng',\n",
- " 'lish',\n",
- " '',\n",
- " '',\n",
- " '-',\n",
- " 'Ġब',\n",
- " 'ि',\n",
- " 'ल',\n",
- " 'à¥į',\n",
- " 'ल',\n",
- " 'à¥Ģ',\n",
- " 'Ġबह',\n",
- " 'à¥ģ',\n",
- " 'त',\n",
- " 'Ġप',\n",
- " 'à¥į',\n",
- " 'य',\n",
- " 'ा',\n",
- " 'र',\n",
- " 'à¥Ģ',\n",
- " 'Ġह',\n",
- " '']"
+ "37"
]
},
- "execution_count": 38,
+ "execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "tokens"
+ "len(tokens)"
]
},
{
"cell_type": "code",
- "execution_count": 39,
+ "execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "'Translate this hindi sentence to english - बिल्ली बहुत प्यारी ह'"
+ "'Translate this hindi sentence to english - बिल्ली बहुत प्यारी है.'"
]
},
- "execution_count": 39,
+ "execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -174,7 +184,7 @@
},
{
"cell_type": "code",
- "execution_count": 40,
+ "execution_count": 21,
"metadata": {},
"outputs": [
{
@@ -183,7 +193,7 @@
"'Translate this hindi sentence to english - बिल्ली बहुत प्यारी है.'"
]
},
- "execution_count": 40,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@@ -194,10 +204,60 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
"metadata": {},
- "outputs": [],
- "source": []
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "aya_tokenizer = AutoTokenizer.from_pretrained(\"CohereForAI/aya-23-8B\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "24 ['Translate', 'Ġthis', 'Ġhindi', 'Ġsentence', 'Ġto', 'Ġenglish', 'Ġ-', 'Ġब', 'ि', 'ल', 'à¥į', 'ल', 'à¥Ģ', 'Ġबह', 'à¥ģ', 'त', 'Ġप', 'à¥į', 'य', 'ा', 'र', 'à¥Ģ', 'Ġह', 'à¥Ī.']\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokens = aya_tokenizer.tokenize(sentence)\n",
+ "print(len(tokens), tokens)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "255029"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(aya_tokenizer.get_vocab())"
+ ]
},
{
"cell_type": "code",