diff --git a/multi_tokenizer/tokenizer.py b/multi_tokenizer/tokenizer.py
index 1860651..7181a71 100644
--- a/multi_tokenizer/tokenizer.py
+++ b/multi_tokenizer/tokenizer.py
@@ -1,6 +1,7 @@
 """Multi Tokenizer Module."""
 
 import pickle
+from typing import Any
 
 from lingua import Language
 
@@ -14,6 +15,7 @@ class MultiTokenizer:
     def __init__(
         self,
         tokenizers: list[LanguageSpecificTokenizer | PretrainedTokenizers],
+        fallback_tokenizer: Any = None,
         split_text: bool = False,
         sep: str = " ",
     ) -> None:
@@ -26,6 +28,9 @@ def __init__(
             )
             for tokenizer in tokenizers
         ]
+        self.fallback_tokenizer = (
+            self.tokenizers[0] if fallback_tokenizer is None else fallback_tokenizer
+        )
         self.language_prefix_token_ids = [
             tokenizer.language_prefix_token[1] for tokenizer in self.tokenizers
         ]
@@ -43,8 +48,18 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
             if not self.split_text
             else self.language_detector.split_n_detect(text, self.sep)
         )
+        last_end_index = 0
         for detection in language_detections:
+            # If there is text between the last detected language and the current one
+            if detection.start_index != last_end_index:
+                pre_tokenized_text.append(
+                    (
+                        text[last_end_index : detection.start_index],
+                        (last_end_index, detection.start_index),
+                    )
+                )
             detected_text = text[detection.start_index : detection.end_index]
+            last_end_index = detection.end_index
             tokenizer = self.get_tokenizer_by_language(detection.language)
             output: list[tuple[str, tuple[int, int]]] = (
                 tokenizer.tokenizer.pre_tokenizer.pre_tokenize_str(detected_text)
@@ -71,6 +86,11 @@ def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
                 for token, (start, end) in output
             ]
             pre_tokenized_text.extend(output)
+        # If there is text after the last detected language
+        if last_end_index < len(text):
+            pre_tokenized_text.append(
+                (text[last_end_index:], (last_end_index, len(text)))
+            )
         return pre_tokenized_text
 
     def get_tokenizer_by_language(
@@ -105,7 +125,14 @@ def encode(self, text: str) -> tuple[list[int], list[str]]:
             if not self.split_text
             else self.language_detector.split_n_detect(text, self.sep)
         )
+        last_end_index = 0
         for detection in language_detections:
+            if detection.start_index != last_end_index:
+                encoding = self.fallback_tokenizer.encode(
+                    text[last_end_index : detection.start_index]
+                )
+                ids.extend(encoding.ids)
+                tokens.extend(encoding.tokens)
             detected_text = text[detection.start_index : detection.end_index]
             tokenizer = self.get_tokenizer_by_language(detection.language)
             detected_text = (
@@ -116,6 +143,11 @@ def encode(self, text: str) -> tuple[list[int], list[str]]:
             encoding = tokenizer.tokenizer.encode(detected_text)
             ids.extend(encoding.ids)
             tokens.extend(encoding.tokens)
+            last_end_index = detection.end_index
+        if last_end_index < len(text):
+            encoding = self.fallback_tokenizer.encode(text[last_end_index:])
+            ids.extend(encoding.ids)
+            tokens.extend(encoding.tokens)
         return ids, tokens
 
     def decode(self, token_ids: list[int]) -> str:
@@ -131,7 +163,15 @@ def decode(self, token_ids: list[int]) -> str:
                     j += 1
                 decoded_str.append(cur_tokenizer.decode(token_ids[i : j + 1]))
                 i = j + 1
-        return " ".join(decoded_str)
+            else:
+                while (
+                    j < len(token_ids)
+                    and token_ids[j] not in self.language_prefix_token_ids
+                ):
+                    j += 1
+                decoded_str.append(self.fallback_tokenizer.decode(token_ids[i:j]))
+                i = j
+        return "".join(decoded_str)
 
     def save(self, path: str) -> None:
         """Save Tokenizer."""
@@ -154,6 +194,4 @@ def get_vocab(self) -> dict[str, dict[str, int]]:
     def get_vocab_size(self) -> int:
         """Get Vocabulary Size."""
         vocab = self.get_vocab()
-        return sum(
-            len(vocab[language]) for language in vocab
-        )  # TODO: This is probably wrong
+        return max(len(vocab[language]) for language in vocab)
diff --git a/support/try_multitokenizer.ipynb b/support/try_multitokenizer.ipynb
index 03b4a2e..5ccfe06 100644
--- a/support/try_multitokenizer.ipynb
+++ b/support/try_multitokenizer.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -11,7 +11,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -25,7 +25,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
@@ -39,6 +39,7 @@
        " ('Ġto', (30, 33)),\n",
        " ('Ġenglish', (33, 41)),\n",
        " ('', (39, 40)),\n",
+       " (' ', (40, 41)),\n",
        " ('', (41, 42)),\n",
        " ('-', (42, 43)),\n",
        " ('Ġब', (43, 45)),\n",
@@ -57,10 +58,11 @@
        " ('र', (60, 61)),\n",
        " ('à¥Ģ', (61, 62)),\n",
        " ('Ġह', (62, 64)),\n",
-       " ('', (62, 63))]"
+       " ('', (62, 63)),\n",
+       " ('ै.', (63, 65))]"
       ]
      },
-     "execution_count": 34,
+     "execution_count": 14,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -74,16 +76,36 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "48840"
+       "65"
       ]
      },
-     "execution_count": 36,
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(sentence)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "25000"
+      ]
+     },
+     "execution_count": 16,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
@@ -94,7 +116,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -103,67 +125,55 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['', 'Tr', 'ans', 'l', 'ate', 'Ġthis', 'Ġhind', 'i', 'Ġsentence', 'Ġto', 'Ġeng', 'lish', '', 'Ġ', '', '-', 'Ġब', 'ि', 'ल', 'à¥į', 'ल', 'à¥Ģ', 'Ġबह', 'à¥ģ', 'त', 'Ġप', 'à¥į', 'य', 'ा', 'र', 'à¥Ģ', 'Ġह', '', 'à', '¥', 'Ī', '.']\n",
+      "[3, 7235, 6614, 86, 755, 775, 10763, 83, 19412, 276, 3602, 9113, 4, 231, 9, 23, 290, 277, 285, 282, 285, 273, 342, 286, 283, 294, 282, 292, 270, 272, 273, 287, 10, 167, 109, 241, 24]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(tokens)\n",
+    "print(ids)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "['',\n",
-       " 'Tr',\n",
-       " 'ans',\n",
-       " 'l',\n",
-       " 'ate',\n",
-       " 'Ġthis',\n",
-       " 'Ġhind',\n",
-       " 'i',\n",
-       " 'Ġsentence',\n",
-       " 'Ġto',\n",
-       " 'Ġeng',\n",
-       " 'lish',\n",
-       " '',\n",
-       " '',\n",
-       " '-',\n",
-       " 'Ġब',\n",
-       " 'ि',\n",
-       " 'ल',\n",
-       " 'à¥į',\n",
-       " 'ल',\n",
-       " 'à¥Ģ',\n",
-       " 'Ġबह',\n",
-       " 'à¥ģ',\n",
-       " 'त',\n",
-       " 'Ġप',\n",
-       " 'à¥į',\n",
-       " 'य',\n",
-       " 'ा',\n",
-       " 'र',\n",
-       " 'à¥Ģ',\n",
-       " 'Ġह',\n",
-       " '']"
+       "37"
       ]
      },
-     "execution_count": 38,
+     "execution_count": 19,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "tokens"
+    "len(tokens)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": 20,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "'Translate this hindi sentence to english - बिल्ली बहुत प्यारी ह'"
+       "'Translate this hindi sentence to english - बिल्ली बहुत प्यारी है.'"
       ]
      },
-     "execution_count": 39,
+     "execution_count": 20,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
@@ -174,7 +184,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 40,
+   "execution_count": 21,
    "metadata": {},
    "outputs": [
     {
@@ -183,7 +193,7 @@
       "'Translate this hindi sentence to english - बिल्ली बहुत प्यारी है.'"
      ]
     },
-     "execution_count": 40,
+     "execution_count": 21,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
@@ -194,10 +204,60 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 22,
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "aya_tokenizer = AutoTokenizer.from_pretrained(\"CohereForAI/aya-23-8B\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "24 ['Translate', 'Ġthis', 'Ġhindi', 'Ġsentence', 'Ġto', 'Ġenglish', 'Ġ-', 'Ġब', 'ि', 'ल', 'à¥į', 'ल', 'à¥Ģ', 'Ġबह', 'à¥ģ', 'त', 'Ġप', 'à¥į', 'य', 'ा', 'र', 'à¥Ģ', 'Ġह', 'à¥Ī.']\n"
+     ]
+    }
+   ],
+   "source": [
+    "tokens = aya_tokenizer.tokenize(sentence)\n",
+    "print(len(tokens), tokens)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "255029"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(aya_tokenizer.get_vocab())"
+   ]
   },
   {
    "cell_type": "code",
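Below is a minimal usage sketch of the behaviour this diff targets, for reviewers who want to try it locally. It assumes the package exposes `MultiTokenizer` and a `PretrainedTokenizers` enum with English and Hindi members at the top-level import path; those names are assumptions, not taken from this diff. The sentence mirrors the one exercised in `support/try_multitokenizer.ipynb`.

```python
# Minimal sketch only: the import path and enum member names (ENGLISH, HINDI)
# are assumed, not confirmed by this diff; adjust to the package's actual API.
from multi_tokenizer import MultiTokenizer, PretrainedTokenizers

tokenizer = MultiTokenizer(
    [PretrainedTokenizers.ENGLISH, PretrainedTokenizers.HINDI],
    # New parameter introduced by this diff: text that falls outside every
    # detected language span (e.g. the " - " separator) is encoded with this
    # tokenizer. Passing None keeps the default, the first tokenizer in the list.
    fallback_tokenizer=None,
)

sentence = "Translate this hindi sentence to english - बिल्ली बहुत प्यारी है."
ids, tokens = tokenizer.encode(sentence)

# decode() now joins per-segment decodes with "" instead of " ", so the round
# trip is expected to reproduce the input, including the trailing "है.".
print(tokenizer.decode(ids) == sentence)
```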