Skip to content

Commit

Permalink
Update extract patch logic
Browse files: browse the repository at this point in the history
  • Loading branch information
john-b-yang committed Jul 8, 2024
1 parent 5ba6fd6 commit d200be5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 34 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'python-dotenv',
'requests',
'rich',
'unidiff',
'tqdm',
],
include_package_data=True,
Expand Down
6 changes: 5 additions & 1 deletion swebench/collect/build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import os
from typing import Optional

from swebench.collect.utils import Repo, extract_patches, extract_problem_statement_and_hints
from swebench.collect.utils import (
extract_patches,
extract_problem_statement_and_hints,
Repo,
)

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
Expand Down
45 changes: 12 additions & 33 deletions swebench/collect/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ghapi.core import GhApi
from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError
from typing import Callable, Iterator, Optional
from unidiff import PatchSet

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
Expand Down Expand Up @@ -311,40 +312,18 @@ def extract_patches(pull: dict, repo: Repo) -> tuple[str, str]:
patch_change_str (str): gold patch
patch_test_str (str): test patch
"""
# Convert diff to patch format with "index" lines removed
patch = requests.get(pull["diff_url"]).text
if patch.endswith("\n"):
patch = patch[:-1]
# Create change patch and test patch
patch_change, patch_test = [], []

# Flag to determine if current diff block is a test or general change
# Values: 'test', 'diff', None
flag = None

for line in patch.split("\n"):
# Exclude commit specific metadata
if line.startswith("index "):
continue
# Determine if current diff block is a test or general change
if line.startswith("diff --git a/"):
words = set(re.split(r" |_|\/|\.", line.lower()))
flag = (
"test"
if ("test" in words or "tests" in words or "testing" in words)
else "diff"
)
if flag != "test" and not line.strip().endswith(".py"):
flag = None
# Append line to separate patch depending on flag status
if flag == "test":
patch_test.append(line)
elif flag == "diff":
patch_change.append(line)

patch_change_str = "\n".join(patch_change) + "\n" if len(patch_change) > 0 else ""
patch_test_str = "\n".join(patch_test) + "\n" if len(patch_test) > 0 else ""
return patch_change_str, patch_test_str
patch_test = ""
patch_fix = ""
for hunk in PatchSet(patch):
if any(
test_word in hunk.path for test_word in
['test', 'tests', 'e2e', 'testing']
):
patch_test += str(hunk)
else:
patch_fix += str(hunk)
return patch_fix, patch_test


### MARK: Repo Specific Parsing Functions ###
Expand Down

0 comments on commit d200be5

Please sign in to comment.