From 0096cf3c909f2147c314e5ad263d9007da84424f Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sat, 7 Sep 2024 10:27:40 -0400 Subject: [PATCH] patch weird hanging utterance bug --- batchalign/models/utterance/infer.py | 34 +++++++++++++++++---------- batchalign/pipelines/analysis/eval.py | 12 +++++++++- batchalign/version | 6 ++--- scratchpad.py | 9 ++++++- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/batchalign/models/utterance/infer.py b/batchalign/models/utterance/infer.py index e5382ea..be8ed1a 100644 --- a/batchalign/models/utterance/infer.py +++ b/batchalign/models/utterance/infer.py @@ -35,6 +35,7 @@ def __init__(self, model): self.model.eval() def __call__(self, passage): + print(passage) # input passage words removed of all preexisting punctuation passage = passage.lower() passage = passage.replace('.','') @@ -67,7 +68,8 @@ def __call__(self, passage): prev_word_idx = None # for each word, perform the action - for indx, elem in enumerate(tokd.word_ids(0)): + wids = tokd.word_ids(0) + for indx, elem in enumerate(wids): # if its none, append nothing or if we have # seen it before, do nothing if elem is None or elem == prev_word_idx: @@ -81,23 +83,31 @@ def __call__(self, passage): # set the working variable w = input_tokenized[elem] - # perform the edit actions - if action == 1: - w = w[0].upper() + w[1:] - elif action == 2: - w = w+'.' - elif action == 3: - w = w+'?' - elif action == 4: - w = w+'!' - elif action == 5: - w = w+',' + # fix one word hanging issue + will_action = False + if indx < len(wids)-2 and classified_targets[0][indx+1] > 0: + will_action = True + + if not will_action: + # perform the edit actions + if action == 1: + w = w[0].upper() + w[1:] + elif action == 2: + w = w+'.' + elif action == 3: + w = w+'?' + elif action == 4: + w = w+'!' 
+ elif action == 5: + w = w+',' + # append res_toks.append(w) # compose final passage final_passage = self.tokenizer.convert_tokens_to_string(res_toks) + print(final_passage) try: split_passage = sent_tokenize(final_passage) except LookupError: diff --git a/batchalign/pipelines/analysis/eval.py b/batchalign/pipelines/analysis/eval.py index 04574a4..9271256 100644 --- a/batchalign/pipelines/analysis/eval.py +++ b/batchalign/pipelines/analysis/eval.py @@ -8,7 +8,7 @@ from batchalign.pipelines.asr.utils import * from batchalign.utils.config import config_read -from batchalign.utils.dp import align, ExtraType, Extra +from batchalign.utils.dp import align, ExtraType, Extra, Match import logging L = logging.getLogger("batchalign") @@ -38,8 +38,16 @@ def __compute_wer(doc, gold): # ie: if we have +> substitution # but if we have this is 2 insertions + cleaned_alignment = [] + for i in alignment: + if isinstance(i, Extra): + if len(cleaned_alignment) > 0 and i.extra_type == ExtraType.REFERENCE and "name" in i.key and i.key[:4] != "name": + cleaned_alignment.pop(-1) + cleaned_alignment.append(Match(i.key, None, None)) + continue + if prev_error != None and prev_error != i.extra_type: # this is a substitution: we have different "extra"s in # reference vs. 
playload @@ -64,6 +72,8 @@ def __compute_wer(doc, gold): else: prev_error = None + cleaned_alignment.append(i) + diff = [] for i in alignment: if isinstance(i, Extra): diff --git a/batchalign/version b/batchalign/version index 7db6340..5577c70 100644 --- a/batchalign/version +++ b/batchalign/version @@ -1,3 +1,3 @@ -0.7.5-alpha.6 -September 3nd, 2024 -fix benchmark command, part 2 +0.7.5-alpha.7 +September 7th, 2024 +batch hanging utterance bug diff --git a/scratchpad.py b/scratchpad.py index f120f95..b3be2a0 100644 --- a/scratchpad.py +++ b/scratchpad.py @@ -15,7 +15,14 @@ ######## -# from batchalign import * +# from batchalign.models.utterance import infer + +# engine = infer.BertUtteranceModel("talkbank/CHATUtterance-zh_CN") +# engine("我 现在 想 听 你说 一些 你 自己 经 历 过 的 故 事 好不好 然后 呢 我们 会 一起 讨 论 有 六 种 不同 的 情 景 然后 在 每 一个 情 景 中 都 需要 你 去 讲 一个 关 于 你 自己 的 一个 故 事 小 故 事") + +# doc = Document.new(media_path="/Users/houjun/Downloads/trial.mp3", lang="zho") +# pipe = BatchalignPipeline.new("asr", lang="zho", num_speakers=2, engine="rev") +# res = pipe(doc) # # with open("schema.json", 'w') as df: # # json.dump(Document.model_json_schema(), df, indent=4)