Token-level CodeSearchNet preprocessing option #13


Open · wants to merge 7 commits into base: master
1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@ env.sh
.mypy_cache
notebooks/output
notebooks/repos
+.vscode/
159 changes: 138 additions & 21 deletions notebooks/codesearchnet-opennmt.py
@@ -1,36 +1,44 @@
"""
CLI tool for converting CodeSearchNet dataset to OpenNMT format for
function name suggestion task.

Usage example:
wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
unzip java.zip
python notebooks/codesearchnet-opennmt.py \
--data-dir='java/final/jsonl/valid' \
--newline='\\n'
"""
from argparse import ArgumentParser, Namespace
import logging
from pathlib import Path
from time import time
from typing import List, Tuple

import pandas as pd
from torch.utils.data import Dataset


logging.basicConfig(level=logging.INFO)

+# catch SIGPIPE to make it nix CLI friendly e.g. | head
+from signal import signal, SIGPIPE, SIG_DFL
+
+signal(SIGPIPE, SIG_DFL)
+
-class CodeSearchNetRAM(Dataset):
-    """Stores one split of CodeSearchNet data in memory
-
-    Usage example:
-        wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
-        unzip java.zip
-        python notebooks/codesearchnet-opennmt.py \
-            --data_dir='java/final/jsonl/valid' \
-            --newline='\\n'
-    """
+class CodeSearchNetRAM:
+    """Stores one split of CodeSearchNet data in memory"""

def __init__(self, split_path: Path, newline_repl: str):
super().__init__()
self.pd = pd
self.newline_repl = newline_repl

files = sorted(split_path.glob("**/*.gz"))
logging.info(f"Total number of files: {len(files):,}")
assert files, "could not find files under %s" % split_path

columns_list = ["code", "func_name"]
columns_list = ["code", "func_name", "code_tokens"]

start = time()
self.pd = self._jsonl_list_to_dataframe(files, columns_list)
@@ -61,39 +69,144 @@ def __getitem__(self, idx: int) -> Tuple[str, str]:

# drop fn signature
code = row["code"]
fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
fn_body = fn_body.replace("\n", "\\n")
fn_body = code[
code.find("{", code.find(fn_name) + len(fn_name)) + 1 : code.rfind("}")
]
fn_body = fn_body.strip()
fn_body = fn_body.replace("\n", self.newline_repl)
# fn_body_enc = self.enc.encode(fn_body)
return (fn_name, fn_body)

tokens = row["code_tokens"]
body_tokens = tokens[tokens.index(fn_name) + 2 :]
try:
fn_body_tokens = body_tokens[
body_tokens.index("{") + 1 : len(body_tokens) - 1
]
except ValueError as ve: # '{' might be missing
logging.error("'%s' fn body extraction failed: %s", body_tokens, ve)
fn_body_tokens = None

return (fn_name, fn_body, fn_body_tokens)
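
    # Illustrative toy example (not from the dataset): given a Java method with
    #   code_tokens = ["public", "int", "add", "(", "int", "a", ")",
    #                  "{", "return", "a", ";", "}"]
    # and fn_name == "add", `tokens.index(fn_name) + 2` skips the name and the
    # "(" right after it, so body_tokens starts at "int"; fn_body_tokens then
    # comes out as ["return", "a", ";"], i.e. the signature and the outermost
    # braces are stripped.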

def __len__(self) -> int:
return len(self.pd)


# id splitting from
# https://github.com/microsoft/dpu-utils/blob/dfc44e354b57a4e2617828bdf4d76c1c4d81c021/python/dpu_utils/codeutils/identifiersplitting.py
from functools import lru_cache
from typing import List


def split_camelcase(camel_case_identifier: str) -> List[str]:
"""
Split camelCase identifiers.
"""
if not len(camel_case_identifier):
return []

# split into words based on adjacent cases being the same
result = []
current = str(camel_case_identifier[0])
prev_upper = camel_case_identifier[0].isupper()
prev_digit = camel_case_identifier[0].isdigit()
prev_special = not camel_case_identifier[0].isalnum()
for c in camel_case_identifier[1:]:
upper = c.isupper()
digit = c.isdigit()
special = not c.isalnum()
new_upper_word = upper and not prev_upper
new_digit_word = digit and not prev_digit
new_special_word = special and not prev_special
if new_digit_word or new_upper_word or new_special_word:
result.append(current)
current = c
elif not upper and prev_upper and len(current) > 1:
result.append(current[:-1])
current = current[-1] + c
elif not digit and prev_digit:
result.append(current)
current = c
elif not special and prev_special:
result.append(current)
current = c
else:
current += c
prev_digit = digit
prev_upper = upper
prev_special = special
result.append(current)
return result


@lru_cache(maxsize=5000)
def split_identifier_into_parts(identifier: str) -> List[str]:
"""
Split a single identifier into parts on snake_case and camelCase
"""
snake_case = identifier.split("_")

identifier_parts = [] # type: List[str]
for i in range(len(snake_case)):
part = snake_case[i]
if len(part) > 0:
identifier_parts.extend(s.lower() for s in split_camelcase(part))
if len(identifier_parts) == 0:
return [identifier]
return identifier_parts
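
# A quick sanity check of the splitters above (illustrative REPL session):
#   >>> split_camelcase("parseHTTPResponse")
#   ['parse', 'HTTP', 'Response']
#   >>> split_identifier_into_parts("parseHTTPResponse_v2")
#   ['parse', 'http', 'response', 'v', '2']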


def main(args: Namespace) -> None:
dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline)
split_name = Path(args.data_dir).name
with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
args.tgt_file % split_name, mode="w", encoding="utf8"
) as t:
-        for fn_name, fn_body in dataset:
+        for fn_name, fn_body, fn_body_tokens in dataset:
if not fn_name or not fn_body:
continue
-            print(fn_body, file=s)
-            print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)

if args.token_level_sources:
if not fn_body_tokens:
continue
src = " ".join(fn_body_tokens).replace("\n", args.newline)
else:
src = fn_body

if args.word_level_targets:
tgt = fn_name
elif args.token_level_targets:
tgt = " ".join(split_identifier_into_parts(fn_name))
else:
tgt = " ".join(fn_name)
if args.print:
print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
else:
print(src, file=s)
print(tgt, file=t)
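
# For example, a function named "readFile" yields the target line
#   "readFile"         with word-level targets,
#   "read file"        with --token-level-targets,
#   "r e a d F i l e"  with the default character-level targets.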


if __name__ == "__main__":
parser = ArgumentParser(add_help=False)
    parser.add_argument(
-        "--data_dir",
+        "--data-dir",
        type=str,
        default="java/final/jsonl/test",
        help="Path to the unzipped input data (CodeSearchNet)",
    )

parser.add_argument("--newline", default="\\n", help="Replace newline with this")

    parser.add_argument(
-        "--newline", type=str, default="\\n", help="Replace newline with this"
+        "--token-level-sources",
+        action="store_true",
+        help="Use language-specific token sources instead of word level ones",
    )

parser.add_argument(
"--token-level-targets",
action="store_true",
help="Use camlCase and snake_case split token sources instead of word or char level ones",
)

parser.add_argument(
@@ -103,11 +216,15 @@ def main(args: Namespace) -> None:
)

parser.add_argument(
"--src_file", type=str, default="src-%s.txt", help="File with function bodies",
"--src-file", default="src-%s.txt", help="File with function bodies",
)

parser.add_argument(
"--tgt-file", default="tgt-%s.txt", help="File with function texts"
)

parser.add_argument(
"--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
"--print", action="store_true", help="Print data preview to the STDOUT"
)

args = parser.parse_args()
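
For reference, a minimal invocation exercising both new options on the validation split (assuming the java dataset was unzipped as in the module docstring; piping into `head` works thanks to the SIGPIPE handler):

    python notebooks/codesearchnet-opennmt.py \
        --data-dir='java/final/jsonl/valid' \
        --token-level-sources \
        --token-level-targets \
        --print | head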