Skip to content

PatchWork AutoFix #1640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions patchwork/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def list_option_callback(ctx: click.Context, param: click.Parameter, value: str


def find_patchflow(possible_module_paths: Iterable[str], patchflow: str) -> Any | None:
# Define a whitelist of allowed module paths
allowed_modules = {'allowed_module_1', 'allowed_module_2'}

for module_path in possible_module_paths:
try:
spec = importlib.util.spec_from_file_location("custom_module", module_path)
Expand All @@ -71,14 +74,18 @@ def find_patchflow(possible_module_paths: Iterable[str], patchflow: str) -> Any
except Exception:
logger.debug(f"Patchflow {patchflow} not found as a file/directory in {module_path}")

try:
module = importlib.import_module(module_path)
logger.info(f"Patchflow {patchflow} loaded from {module_path}")
return getattr(module, patchflow)
except ModuleNotFoundError:
logger.debug(f"Patchflow {patchflow} not found as a module in {module_path}")
except AttributeError:
logger.debug(f"Patchflow {patchflow} not found in {module_path}")
# Check if the module is in the whitelist before importing
if module_path in allowed_modules:
try:
module = importlib.import_module(module_path)
logger.info(f"Patchflow {patchflow} loaded from {module_path}")
return getattr(module, patchflow)
except ModuleNotFoundError:
logger.debug(f"Patchflow {patchflow} not found as a module in {module_path}")
except AttributeError:
logger.debug(f"Patchflow {patchflow} not found in {module_path}")
else:
logger.warning(f"Module path {module_path} is not in the whitelist.")

return None

Expand Down
3 changes: 2 additions & 1 deletion patchwork/common/tools/bash_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ def execute(

try:
result = subprocess.run(
command, shell=True, cwd=self.path, capture_output=True, text=True, timeout=60 # Add timeout for safety
command.split(), shell=False, cwd=self.path, capture_output=True, text=True, timeout=60
)
return result.stdout if result.returncode == 0 else f"Error: {result.stderr}"
except subprocess.TimeoutExpired:
return "Error: Command timed out after 60 seconds"
except Exception as e:
return f"Error: {str(e)}"

9 changes: 5 additions & 4 deletions patchwork/common/tools/csvkit_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ def execute(self, files: list[str], query: str) -> str:
files_to_insert = []
if db_path.is_file():
with sqlite3.connect(str(db_path)) as conn:
cursor = conn.cursor()
for file in files:
res = conn.execute(
f"SELECT 1 from {file.removesuffix('.csv')}",
)
if res.fetchone() is None:
table_name = file.removesuffix('.csv')
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
res = cursor.fetchone()
if res is None:
files_to_insert.append(file)
else:
files_to_insert = files
Expand Down
49 changes: 49 additions & 0 deletions patchwork/common/tools/git_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import os
import subprocess

from patchwork.common.tools.tool import Tool


class GitTool(Tool, tool_name="git_tool", abc_register=False):
def __init__(self, path: str):
super().__init__()
self.path = path

@property
def json_schema(self) -> dict:
return {
"name": "git_tool",
"description": """\
Access to the Git CLI, the command is also `git` all args provided are used as is.
""",
"input_schema": {
"type": "object",
"properties": {
"args": {
"type": "array",
"items": {"type": "string"},
"description": """
The args to run `git` command with.
E.g.
[\"commit\", \"-m\", \"A commit message\"] to commit changes with a commit message.
[\"add\", \".\"] to stage all changed files.
""",
}
},
"required": ["args"],
},
}

def execute(self, args: list[str]) -> str:
env = os.environ.copy()
p = subprocess.run(
["git", *args],
env=env,
cwd=self.path,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
return p.stdout
2 changes: 1 addition & 1 deletion patchwork/common/tools/github_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from patchwork.common.tools.tool import Tool


class GitHubTool(Tool, tool_name="github_tool"):
class GitHubTool(Tool, tool_name="github_tool", abc_register=False):
def __init__(self, path: str, gh_token: str):
super().__init__()
self.path = path
Expand Down
5 changes: 4 additions & 1 deletion patchwork/common/utils/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
"notification": ["slack_sdk"],
}

__ALLOWED_MODULES = {module for modules in __DEPENDENCY_GROUPS.values() for module in modules}

@lru_cache(maxsize=None)
def import_with_dependency_group(name):
if name not in __ALLOWED_MODULES:
raise ImportError(f"Import of untrusted module '{name}' is not allowed.")

try:
return importlib.import_module(name)
except ImportError:
Expand All @@ -20,6 +24,5 @@ def import_with_dependency_group(name):
error_msg = f"Please `pip install patchwork-cli[{dependency_group}]` to use this step"
raise ImportError(error_msg)


def slack_sdk():
return import_with_dependency_group("slack_sdk")
8 changes: 7 additions & 1 deletion patchwork/common/utils/step_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,15 @@ def validate_step_type_config_with_inputs(


def validate_step_with_inputs(input_keys: Set[str], step: Type[Step]) -> Tuple[Set[str], Dict[str, str]]:
allowed_modules = {"module_1.typed", "module_2.typed"} # Example whitelist
module_path, _, _ = step.__module__.rpartition(".")
step_name = step.__name__
type_module = importlib.import_module(f"{module_path}.typed")
type_module_path = f"{module_path}.typed"

if type_module_path not in allowed_modules:
raise ValueError(f"Importing from {type_module_path} is not allowed")

type_module = importlib.import_module(type_module_path)
step_input_model = getattr(type_module, f"{step_name}Inputs", __NOT_GIVEN)
step_output_model = getattr(type_module, f"{step_name}Outputs", __NOT_GIVEN)
if step_input_model is __NOT_GIVEN:
Expand Down
3 changes: 2 additions & 1 deletion patchwork/steps/CallShell/CallShell.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def __parse_env_text(env_text: str) -> dict[str, str]:
return env

def run(self) -> dict:
p = subprocess.run(self.script, shell=True, capture_output=True, text=True, cwd=self.working_dir, env=self.env)
command_args = shlex.split(self.script)
p = subprocess.run(command_args, shell=False, capture_output=True, text=True, cwd=self.working_dir, env=self.env)
try:
p.check_returncode()
except subprocess.CalledProcessError as e:
Expand Down
8 changes: 6 additions & 2 deletions patchwork/steps/GitHubAgent/GitHubAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
AgentConfig,
AgenticStrategyV2,
)
from patchwork.common.tools.git_tool import GitTool
from patchwork.common.tools.github_tool import GitHubTool
from patchwork.common.utils.utils import mustache_render
from patchwork.step import Step
Expand Down Expand Up @@ -34,10 +35,13 @@ def __init__(self, inputs):
AgentConfig(
name="Assistant",
model="gemini-2.0-flash",
tool_set=dict(github_tool=GitHubTool(base_path, inputs["github_api_key"])),
tool_set=dict(
github_tool=GitHubTool(base_path, inputs["github_api_key"]),
git_tool=GitTool(base_path),
),
system_prompt="""\
You are a senior software developer helping the program manager to obtain some data from GitHub.
You can access github through the `gh` CLI app.
You can access github through the `gh` CLI app through the `github_tool`, and `git` through the `git_tool`.
Your `gh` app has already been authenticated.
""",
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "patchwork-cli"
version = "0.0.123"
version = "0.0.124"
description = ""
authors = ["patched.codes"]
license = "AGPL"
Expand Down