Skip to content

Commit

Permalink
Merge pull request #25 from commit-0/restrict-git
Browse files Browse the repository at this point in the history
Restrict git
  • Loading branch information
wenting-zhao authored Sep 14, 2024
2 parents f9727de + 877e1e8 commit 2ba0357
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 29 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/system.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Install the project
run: uv sync
- name: Clone
run: uv run commit0 clone simpy
- name: Setup
- name: Set up commit0
run: uv run commit0 setup simpy
- name: Build docker images
run: uv run commit0 build simpy
- name: Get tests
run: uv run commit0 get-tests simpy
Expand Down
4 changes: 2 additions & 2 deletions commit0/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main() -> None:
# after hydra gets all configs, put command-line arguments back
sys.argv = sys_argv
# repo_split: split from command line has a higher priority than split in hydra
if command in ["clone", "build", "evaluate", "evaluate-reference", "save"]:
if command in ["setup", "build", "evaluate", "evaluate-reference", "save"]:
if len(sys.argv) >= 3:
if sys.argv[2] not in SPLIT:
raise ValueError(
Expand All @@ -39,7 +39,7 @@ def main() -> None:
config.repo_split = sys.argv[2]
config.base_dir = os.path.abspath(config.base_dir)

if command == "clone":
if command == "setup":
commit0.harness.setup.main(
config.dataset_name,
config.dataset_split,
Expand Down
2 changes: 1 addition & 1 deletion commit0/harness/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class RepoInstance(TypedDict):

# available commands
COMMANDS = [
"clone",
"setup",
"build",
"test",
"test-reference",
Expand Down
16 changes: 10 additions & 6 deletions commit0/harness/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import time
import traceback
import pwd
from pathlib import Path
from io import BytesIO
from typing import Optional, List, Union
Expand Down Expand Up @@ -158,23 +159,26 @@ def copy_ssh_pubkey_from_container(container: Container) -> None:
if exit_code != 0:
raise Exception(f"Error reading file: {output.decode('utf-8').strip()}")
public_key = output.decode("utf-8").strip()
public_key = f"no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty {public_key}"

local_authorized_keys_path = os.path.expanduser("~/.ssh/authorized_keys")
os.makedirs(os.path.dirname(local_authorized_keys_path), exist_ok=True)
if not os.path.exists(local_authorized_keys_path):
user_info = pwd.getpwnam("git")
home_directory = user_info.pw_dir
authorized_keys_path = os.path.join(home_directory, ".ssh", "authorized_keys")
os.makedirs(os.path.dirname(authorized_keys_path), exist_ok=True)
if not os.path.exists(authorized_keys_path):
# Since the file does not exist, create it
open(local_authorized_keys_path, "a").close()
open(authorized_keys_path, "a").close()
write = True
else:
with open(local_authorized_keys_path, "r") as authorized_keys_file:
with open(authorized_keys_path, "r") as authorized_keys_file:
content = authorized_keys_file.read()
if public_key not in content:
write = True
else:
write = False

if write:
with open(local_authorized_keys_path, "a") as authorized_keys_file:
with open(authorized_keys_path, "a") as authorized_keys_file:
authorized_keys_file.write(public_key + "\n")

except docker.errors.APIError as e:
Expand Down
5 changes: 3 additions & 2 deletions commit0/harness/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,19 @@ def main(
for name in tqdm(log_dirs):
report_file = os.path.join(name, "report.json")
name = name.split("/")[2]
test_ids = get_tests(name, stdout=False)
if not os.path.exists(report_file):
out.append(
{
"name": name,
"sum": 0,
"passed": 0,
"num_passed": 0,
"num_tests": len(test_ids),
}
)
continue
report = load_dataset("json", data_files=report_file, split="train") # type: ignore
test_ids = get_tests(name, stdout=False)
tests = {x["nodeid"]: x["call"] for x in report["tests"][0]} # type: ignore
status = []
runtimes = []
Expand All @@ -110,7 +111,7 @@ def main(
"sum": total,
"passed": passed,
"num_passed": status["passed"] + status["xfail"],
"num_tests": sum(status.values()),
"num_tests": len(test_ids),
}
)
print("repo,runtime,num_passed/num_tests")
Expand Down
19 changes: 10 additions & 9 deletions commit0/harness/run_pytest_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
extract_test_output,
get_hash_string,
get_ip,
get_user,
)
from commit0.harness.execution_context import (
Docker,
Expand Down Expand Up @@ -74,7 +73,6 @@ def main(
commit_id=commit_id,
test_ids=test_ids,
ip=get_ip(backend),
user=get_user(),
)
eval_file = Path(log_dir / "eval.sh")
eval_file.write_text(eval_script)
Expand All @@ -96,18 +94,21 @@ def main(
output, "--json-report --json-report-file=report.json"
)
context.write_test_output(test_output, timed_out)
print(test_output)
except EvaluationError as e:
error_msg = traceback.format_exc()
logger.info(error_msg)
print(e)
error_msg = (
f"Error in running pytest for {repo}: {e}\n"
f"{traceback.format_exc()}\n"
f"Check ({log_file}) for more information."
)
raise EvaluationError(repo, error_msg, logger)
except Exception as e:
error_msg = (
f"Error in running pytest for {spec.repo}: {e}\n"
f"General error: {e}\n"
f"{traceback.format_exc()}\n"
# f"Check ({logger.log_file}) for more information."
f"Check ({log_file}) for more information."
)
logger.error(error_msg)

raise RuntimeError(error_msg)
return str(log_dir)


Expand Down
9 changes: 8 additions & 1 deletion commit0/harness/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from datasets import load_dataset

from typing import Iterator
from commit0.harness.utils import clone_repo, create_branch
from commit0.harness.utils import (
clone_repo,
create_branch,
setup_git,
add_safe_directory,
)
from commit0.harness.constants import RepoInstance, SPLIT


Expand All @@ -18,6 +23,7 @@ def main(
dataset_name: str, dataset_split: str, repo_split: str, base_dir: str, branch: str
) -> None:
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
setup_git(logger)
for example in dataset:
repo_name = example["repo"].split("/")[-1]
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
Expand All @@ -26,6 +32,7 @@ def main(
clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
local_repo = clone_repo(clone_url, clone_dir, example["base_commit"], logger)
create_branch(local_repo, branch, logger)
add_safe_directory(clone_dir, logger)


__all__ = []
2 changes: 1 addition & 1 deletion commit0/harness/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def make_eval_script_list(instance: RepoInstance, repo_directory: str) -> list[s
"ssh-keyscan {ip} >> ~/.ssh/known_hosts",
f"cd {repo_directory}",
"source .venv/bin/activate",
f"git remote add {origin_name} ssh://{{user}}@{{ip}}:{{local_repo}}",
f"git remote add {origin_name} ssh://git@{{ip}}:{{local_repo}}",
f"git fetch {origin_name}",
"git checkout {commit_id}",
"git status",
Expand Down
97 changes: 93 additions & 4 deletions commit0/harness/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import getpass
import git
import git.exc
import hashlib
Expand All @@ -7,7 +6,8 @@
import os
import time
import requests
from typing import Optional
import subprocess
from typing import Optional, Tuple

from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
from ghapi.core import GhApi
Expand Down Expand Up @@ -58,8 +58,97 @@ def get_ip(backend: str) -> str:
return ip


def get_user() -> str:
return getpass.getuser()
def run_command(command: str) -> Tuple[str, str, int]:
    """Run a shell command and capture its result.

    The command is executed through the shell (``shell=True``), so callers
    must only pass trusted, correctly quoted command strings.

    Args:
        command: Shell command line to execute.

    Returns:
        A ``(stdout, stderr, exit_code)`` tuple. A non-zero exit code is
        reported through the tuple; it is never raised as an exception.
    """
    # check=False: a failing command is an expected outcome here, so it is
    # reported via the return code instead of raising and re-catching.
    result = subprocess.run(
        command,
        shell=True,
        check=False,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # errors="replace" keeps non-UTF-8 tool output from raising
    # UnicodeDecodeError and hiding the command's real result.
    return (
        result.stdout.decode("utf-8", errors="replace"),
        result.stderr.decode("utf-8", errors="replace"),
        result.returncode,
    )


def handle_command(command: str, description: str, logger: logging.Logger) -> None:
    """Execute ``command`` through ``run_command`` and log how it went.

    A zero exit code is logged at info level. Any other exit code is logged
    at error level together with the captured stderr. Nothing is raised and
    nothing is returned, so a failing command does not stop the caller.

    Args:
        command: Shell command line to run.
        description: Short phrase describing what the command does; it is
            embedded in the log message.
        logger: Logger that receives the outcome.
    """
    _, err_text, status = run_command(command)
    if status == 0:
        logger.info(f"Succeeded in running '{command}' which {description}")
        return
    logger.error(f"Error running '{command}' which {description}:\n{err_text}")


def setup_git(logger: logging.Logger) -> None:
    """Set up the ``git`` user: home permissions, ``.ssh`` directory, and git-shell login shell.

    Every step is run through ``sudo`` with ``handle_command``, which logs
    failures instead of raising, so individual steps are best-effort (for
    example, ``adduser`` fails harmlessly when the user already exists).

    Args:
        logger: Logger that receives the outcome of each step.

    Raises:
        RuntimeError: If the home directory of the ``git`` user cannot be
            determined.
    """
    handle_command(
        'sudo adduser --disabled-password --gecos "" git', "adds git user", logger
    )

    # Get git user's home directory dynamically
    git_home_command = "getent passwd git | cut -d: -f6"
    stdout, stderr, exit_code = run_command(git_home_command)
    if exit_code != 0:
        raise RuntimeError(f"Error getting git user's home directory: {stderr}")
    git_home = stdout.strip()  # Extract and trim the home directory
    if not git_home:
        # The pipeline's exit status is cut's, so a missing git user still
        # exits 0 with empty output. Continuing would run the commands below
        # against "/.ssh" and a bare "chmod 755".
        raise RuntimeError(
            "Error getting git user's home directory: "
            "no passwd entry (or empty home) for user 'git'"
        )

    # NOTE(review): authorized_keys is made world-writable (666), presumably
    # so the invoking user can append keys to it; sshd's StrictModes may
    # reject such a file -- confirm against the deployment's sshd_config.
    commands = [
        (f"sudo chmod 755 {git_home}", "make home of git viewable by others"),
        (
            f"sudo sh -c 'mkdir -p {git_home}/.ssh"
            f" && chmod 755 {git_home}/.ssh"
            f" && touch {git_home}/.ssh/authorized_keys"
            f" && chmod 666 {git_home}/.ssh/authorized_keys'",
            "sets up .ssh directory for git",
        ),
        ("sudo touch /etc/shells", "creates /etc/shells if it doesn't exist yet"),
        ("cat /etc/shells", "views available shells"),
        (
            # Append only when the exact line is not already present, so that
            # repeated setup runs do not pile up duplicate entries.
            "sudo sh -c 'grep -qxF \"$(which git-shell)\" /etc/shells"
            " || which git-shell >> /etc/shells'",
            "adds git-shell to /etc/shells",
        ),
        (
            "sudo chsh git -s $(which git-shell)",
            "changes shell for git user to git-shell",
        ),
    ]

    # Execute each command in order; failures are logged by handle_command.
    for command, description in commands:
        handle_command(command, description, logger)


def is_safe_directory_added(safe_directory: str) -> bool:
    """Report whether ``safe_directory`` is in git's system-level ``safe.directory`` list.

    The list is read with ``sudo git config --system --get-all safe.directory``
    and ``safe_directory`` must match one output line exactly. A failing
    command is treated as "not added".
    """
    listing, _, status = run_command(
        "sudo git config --system --get-all safe.directory"
    )
    if status != 0:
        return False
    return safe_directory in listing.splitlines()


def add_safe_directory(safe_directory: str, logger: logging.Logger) -> None:
    """Register ``<safe_directory>/.git`` in git's system-level ``safe.directory`` list.

    The entry is added with ``sudo git config --system --add`` unless
    ``is_safe_directory_added`` reports that it is already present. Failures
    are logged, not raised.

    Args:
        safe_directory: Path of the repository working tree; ``.git`` is
            appended before registering.
        logger: Logger that receives the outcome.
    """
    import shlex  # local import so this block does not depend on file-level imports

    safe_directory = os.path.join(safe_directory, ".git")
    if is_safe_directory_added(safe_directory):
        logger.info(f"Directory '{safe_directory}' is already in the list.")
        return

    # The command runs through a shell, so quote the path: directories with
    # spaces or shell metacharacters would otherwise be split or interpreted.
    command = (
        "sudo git config --system --add safe.directory "
        f"{shlex.quote(safe_directory)}"
    )
    stdout, stderr, exit_code = run_command(command)

    if exit_code == 0:
        logger.info(f"Directory '{safe_directory}' added to safe.directory.")
    else:
        logger.error(f"Error adding directory: {stderr}")


def get_hash_string(input_string: str) -> str:
Expand Down

0 comments on commit 2ba0357

Please sign in to comment.