Skip to content

Commit

Permalink
Fix/constrained regen (#76)
Browse files: browse the repository at this point in the history
* 🐛 Fixed bug in Regen constrained generation

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>

* 🐛 Fix minor bug in create_input for regen

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>

* ✅ Improve coverage

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>

---------

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>
  • Loading branch information
marmg authored Feb 21, 2024
1 parent d16783e commit 4611509
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions zshot/linker/linker_regen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def create_input(sentence, max_length, start_delimiter, end_delimiter):
left_index = max(0, start_delimiter_index - half_context)
right_index = min(len(sent_list),
end_delimiter_index + half_context + (half_context - (start_delimiter_index - left_index)))
if right_index == end_delimiter_index:
right_index += 1

left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index)))
return " ".join(sent_list[left_index:right_index])

Expand Down
4 changes: 4 additions & 0 deletions zshot/tests/linker/test_regen_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ def test_create_input():
sentence = "[START]" + " test" * times_rep + " [END]"
input_sentence = create_input(sentence, max_length, start_delimiter, end_delimiter)
assert input_sentence == " ".join(["test" for i in range(9)])

text = f"IBM headquarters are located in {start_delimiter} New York {end_delimiter} ."
input_ = create_input(text, max_length=4, start_delimiter=start_delimiter, end_delimiter=end_delimiter)
assert start_delimiter in input_ and end_delimiter in input_

0 comments on commit 4611509

Please sign in to comment.