Skip to content

PatchWork GenerateDocstring #1633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions patchwork/common/tools/git_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import os
import subprocess

from patchwork.common.tools.tool import Tool


class GitTool(Tool, tool_name="git_tool", abc_register=False):
def __init__(self, path: str):
super().__init__()
self.path = path

@property
def json_schema(self) -> dict:
return {
"name": "git_tool",
"description": """\
Access to the Git CLI, the command is also `git` all args provided are used as is.
""",
"input_schema": {
"type": "object",
"properties": {
"args": {
"type": "array",
"items": {"type": "string"},
"description": """
The args to run `git` command with.
E.g.
[\"commit\", \"-m\", \"A commit message\"] to commit changes with a commit message.
[\"add\", \".\"] to stage all changed files.
""",
}
},
"required": ["args"],
},
}

def execute(self, args: list[str]) -> str:
env = os.environ.copy()
p = subprocess.run(
["git", *args],
env=env,
cwd=self.path,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
return p.stdout
2 changes: 1 addition & 1 deletion patchwork/common/tools/github_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from patchwork.common.tools.tool import Tool


class GitHubTool(Tool, tool_name="github_tool"):
class GitHubTool(Tool, tool_name="github_tool", abc_register=False):
def __init__(self, path: str, gh_token: str):
super().__init__()
self.path = path
Expand Down
8 changes: 6 additions & 2 deletions patchwork/steps/GitHubAgent/GitHubAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
AgentConfig,
AgenticStrategyV2,
)
from patchwork.common.tools.git_tool import GitTool
from patchwork.common.tools.github_tool import GitHubTool
from patchwork.common.utils.utils import mustache_render
from patchwork.step import Step
Expand Down Expand Up @@ -34,10 +35,13 @@ def __init__(self, inputs):
AgentConfig(
name="Assistant",
model="gemini-2.0-flash",
tool_set=dict(github_tool=GitHubTool(base_path, inputs["github_api_key"])),
tool_set=dict(
github_tool=GitHubTool(base_path, inputs["github_api_key"]),
git_tool=GitTool(base_path),
),
system_prompt="""\
You are a senior software developer helping the program manager to obtain some data from GitHub.
You can access github through the `gh` CLI app.
You can access github through the `gh` CLI app through the `github_tool`, and `git` through the `git_tool`.
Your `gh` app has already been authenticated.
""",
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "patchwork-cli"
version = "0.0.123"
version = "0.0.124"
description = ""
authors = ["patched.codes"]
license = "AGPL"
Expand Down
32 changes: 32 additions & 0 deletions tests/cicd/generate_docstring/cpp_test_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,28 @@


template<typename T>
/**
* Adds two values of the same type.
*
* @param a The first value to be added.
* @param b The second value to be added.
* @return The result of adding a and b.
*/
T a_plus_b(T a, T b) {
return a + b;
}


/**
* Executes a SQL query on the given SQLite database and returns the results.
*
* @param db A pointer to the SQLite database object.
* @param query A string containing the SQL query to be executed.
* @return A vector of vectors of strings where each sub-vector represents a row from the query result,
* with each string in the sub-vector corresponding to a column value. Returns an empty vector
* if the query fails to prepare.
*/

std::vector<std::vector<std::string>> sqlite(sqlite3* db, const std::string& query) {
std::vector<std::vector<std::string>> results;
sqlite3_stmt* stmt;
Expand Down Expand Up @@ -38,6 +55,15 @@ std::vector<std::vector<std::string>> sqlite(sqlite3* db, const std::string& que


template<typename T, typename F>
/**
* Compares two items based on a key mapping function and returns an integer indicating their order.
*
* @param key_map A function that extracts a comparison key from an item.
* @param item1 The first item to be compared.
* @param item2 The second item to be compared.
* @return -1 if the first item is less than the second, 1 if the first item is greater than the second,
* and 0 if they are equal based on the mapping function.
*/
int compare(F key_map, const T& item1, const T& item2) {
auto val1 = key_map(item1);
auto val2 = key_map(item2);
Expand All @@ -48,6 +74,12 @@ int compare(F key_map, const T& item1, const T& item2) {
}


/**
* Generates a random string composed of lowercase and uppercase alphabets.
*
* @param length The length of the random string to be generated.
* @return A random string containing only alphabetic characters with the specified length.
*/
std::string random_alphabets(int length) {
static const std::string chars =
"abcdefghijklmnopqrstuvwxyz"
Expand Down
17 changes: 17 additions & 0 deletions tests/cicd/generate_docstring/java_test_file.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
class Test {
/**
* Calculates the sum of two integers.
*
* @param a The first integer to be added.
* @param b The second integer to be added.
* @return The sum of the two integers.
*/
public static int a_plus_b(Integer a, Integer b) {
return a + b;
}

/**
* Compares two objects based on their keys mapped by a specified key mapping function.
*
* @param keymap A function that maps an object to a comparable value.
* @param a The first object to be compared.
* @param b The second object to be compared.
* @return An integer representing the comparison result: -1 if the key of 'a' is less than the key of 'b',
* 1 if the key of 'a' is greater than the key of 'b', and 0 if the keys are equal.
*/

public static int a_plus_b(Function<Object, Comparable> keymap, object a, Object b) {
if (keymap(a) < keymap(b)) {
return -1;
Expand Down
21 changes: 21 additions & 0 deletions tests/cicd/generate_docstring/js_test_file.py.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@

/**
* Computes the sum of two numbers.
* @param {number} a - The first number.
* @param {number} b - The second number.
* @returns {number} The sum of the two numbers.
*/
function a_plus_b(a, b) {
return a + b;
}

/**
* Compares two objects based on the value associated with a given key.
* @param {String} keymap - The key name to be used for comparison.
* @param {Object} a - The first object to compare.
* @param {Object} b - The second object to compare.
* @returns {Number} - Returns -1 if the value of 'a' is less than the value of 'b',
* 1 if greater, or 0 if they are equal.
*/
const compare = function (keymap, a, b) {
if (a[keymap] < b[keymap]) {
return -1;
Expand All @@ -13,6 +27,13 @@ const compare = function (keymap, a, b) {
}
}

/**
* Executes a query on a given SQLite database and applies a callback function to each result row.
* @param {Object} db - The SQLite database object to be queried.
* @param {string} query - The SQL query string to be executed on the database.
* @param {Function} callback - A function that will be called with each row of the result set.
* @returns {void}
*/
const sqlite = (db, query, callback) => {
db.serialize(function () {
db.each(query, callback);
Expand Down
34 changes: 34 additions & 0 deletions tests/cicd/generate_docstring/kotlin_test_file.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,27 @@ import java.sql.ResultSet
import kotlin.random.Random


/**
* Computes the sum of two numeric values by converting them to Double.
*
* This function accepts any type that extends Number, converts the values to Double,
* and returns the sum as a Double.
*
* @param a The first numeric value of type T, where T extends Number.
* @param b The second numeric value of type T, where T extends Number.
* @return The sum of a and b as a Double.
*/
fun <T : Number> aPlusB(a: T, b: T): Double = a.toDouble() + b.toDouble()


/**
* Executes a SQL query on a given database connection and returns the results as a list of lists.
* Each inner list represents a row from the result set, with each element corresponding to a column value.
*
* @param db The database connection to use for executing the query.
* @param query The SQL query to be executed on the database.
* @return A list of rows, where each row is represented as a list of objects. Each object corresponds to a column value in the result set. Returns an empty list if no results are found.
*/
fun sqlite(db: Connection, query: String): List<List<Any?>> {
db.createStatement().use { statement ->
statement.executeQuery(query).use { resultSet ->
Expand All @@ -27,6 +45,15 @@ fun sqlite(db: Connection, query: String): List<List<Any?>> {
}


/**
* Compares two items using a provided key mapping function, which extracts a comparable value from each item.
* Returns -1 if the first item is less than the second, 1 if it is greater, and 0 if they are equal, based on the comparable value.
*
* @param keyMap A function that maps an item of type T to a comparable value of type R.
* @param item1 The first item to be compared.
* @param item2 The second item to be compared.
* @return An integer result of the comparison: -1, 0, or 1.
*/
fun <T, R : Comparable<R>> compare(keyMap: (T) -> R, item1: T, item2: T): Int {
return when {
keyMap(item1) < keyMap(item2) -> -1
Expand All @@ -36,6 +63,13 @@ fun <T, R : Comparable<R>> compare(keyMap: (T) -> R, item1: T, item2: T): Int {
}


/**
* Generates a random string of alphabets with the specified length.
* The string includes both lowercase and uppercase English letters.
*
* @param length The desired length of the randomly generated string.
* @return A string consisting of random uppercase and lowercase alphabets.
*/
fun randomAlphabets(length: Int): String {
val charPool = ('a'..'z') + ('A'..'Z')
return (1..length)
Expand Down
37 changes: 37 additions & 0 deletions tests/cicd/generate_docstring/python_test_file.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,44 @@
# fmt: off
def a_plus_b(a, b):
"""Adds two numbers together.

Args:
a (int or float): The first number to be added.
b (int or float): The second number to be added.

Returns:
int or float: The sum of the two numbers.
"""
return a + b


def sqlite(db, query):
"""Executes a given SQL query on a SQLite database and returns the results.

Args:
db (sqlite3.Connection): A SQLite database connection object.
query (str): The SQL query to be executed on the database.

Returns:
list: A list of tuples containing the results of the query.
"""

cursor = db.cursor()
cursor.execute(query)
return cursor.fetchall()


def compare(key_map, item1, item2):
"""Compares two items based on a key mapping function and determines their order.

Args:
key_map (function): A function that extracts a comparison key from each item.
item1 (any): The first item to compare.
item2 (any): The second item to compare.

Returns:
int: -1 if item1 is less than item2, 1 if item1 is greater than item2, and 0 if they are equal.
"""
if key_map(item1) < key_map(item2):
return -1
elif key_map(item1) > key_map(item2):
Expand All @@ -21,4 +50,12 @@ def compare(key_map, item1, item2):
def random_alphabets(
length: int
):
"""Generates a random string of alphabets.

Args:
length (int): The desired length of the output string.

Returns:
str: A randomly generated string consisting of ASCII alphabets (both lower and uppercase) of the specified length.
"""
return ''.join(random.choices(string.ascii_letters, k=length))