Skip to content

Commit

Permalink
fix(VisualReplayStrategy): avoid re-using failing segmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
abrichr authored Jul 25, 2024
1 parent a4470c3 commit 0045ebb
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
3 changes: 3 additions & 0 deletions openadapt/drivers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def get_response(
if "error" in result:
error = result["error"]
message = error["message"]
logger.warning(f"{message=}")
if "retry" in message:
return get_response(payload)
raise Exception(message)
return result

Expand Down
17 changes: 9 additions & 8 deletions openadapt/strategies/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,14 +391,15 @@ def get_window_segmentation(
if DEBUG:
original_image.show()

similar_segmentation, similar_segmentation_diff = find_similar_image_segmentation(
original_image,
)
if similar_segmentation:
# TODO XXX: create copy of similar_segmentation, but overwrite with segments of
# regions of new image where segments of similar_segmentation overlap non-zero
# regions of similar_segmentation_diff
return similar_segmentation
if not exceptions:
similar_segmentation, similar_segmentation_diff = (
find_similar_image_segmentation(original_image)
)
if similar_segmentation:
# TODO XXX: create copy of similar_segmentation, but overwrite with segments
# of regions of new image where segments of similar_segmentation overlap
# non-zero regions of similar_segmentation_diff
return similar_segmentation

segmentation_adapter = adapters.get_default_segmentation_adapter()
segmented_image = segmentation_adapter.fetch_segmented_image(original_image)
Expand Down
7 changes: 5 additions & 2 deletions openadapt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,10 @@ def parse_code_snippet(snippet: str) -> dict:
"""
code_block = extract_code_block(snippet)
# remove backtick lines
code_content = "\n".join(code_block.splitlines()[1:-1])
if "```" in code_block:
code_content = "\n".join(code_block.splitlines()[1:-1])
else:
code_content = code_block
# convert literals from Javascript to Python
to_by_from = {
"true": "True",
Expand Down Expand Up @@ -637,7 +640,7 @@ def extract_code_block(text: str) -> str:
raise ValueError("Uneven number of backtick lines")

if len(backtick_idxs) < 2:
return "" # No enclosing backticks found, return empty string
return text

# Extract only the lines between the first and last backtick line,
# including the backticks
Expand Down

0 comments on commit 0045ebb

Please sign in to comment.