diff --git a/.gitignore b/.gitignore
index a7e73f1..c5ef002 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ env.sh
 .mypy_cache
 notebooks/output
 notebooks/repos
+.vscode/
diff --git a/notebooks/codesearchnet-opennmt.py b/notebooks/codesearchnet-opennmt.py
index 321f301..75e4556 100644
--- a/notebooks/codesearchnet-opennmt.py
+++ b/notebooks/codesearchnet-opennmt.py
@@ -1,3 +1,14 @@
+"""
+CLI tool for converting the CodeSearchNet dataset to OpenNMT format for
+the function name suggestion task.
+
+Usage example:
+    wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
+    unzip java.zip
+    python notebooks/codesearchnet-opennmt.py \
+        --data-dir='java/final/jsonl/valid' \
+        --newline='\\n'
+"""
 from argparse import ArgumentParser, Namespace
 import logging
 from pathlib import Path
@@ -5,32 +16,29 @@ from typing import List, Tuple
 
 import pandas as pd
-from torch.utils.data import Dataset
 
 logging.basicConfig(level=logging.INFO)
+# catch SIGPIPE to make the CLI *nix friendly, e.g. when piping into head
+from signal import signal, SIGPIPE, SIG_DFL
 
-class CodeSearchNetRAM(Dataset):
-    """Stores one split of CodeSearchNet data in memory
+signal(SIGPIPE, SIG_DFL)
 
-    Usage example:
-        wget 'https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2/java.zip'
-        unzip java.zip
-        python notebooks/codesearchnet-opennmt.py \
-            --data_dir='java/final/jsonl/valid' \
-            --newline='\\n'
-    """
+
+class CodeSearchNetRAM:
+    """Stores one split of CodeSearchNet data in memory"""
 
     def __init__(self, split_path: Path, newline_repl: str):
         super().__init__()
         self.pd = pd
+        self.newline_repl = newline_repl
 
         files = sorted(split_path.glob("**/*.gz"))
         logging.info(f"Total number of files: {len(files):,}")
         assert files, "could not find files under %s" % split_path
 
-        columns_list = ["code", "func_name"]
+        columns_list = ["code", "func_name", "code_tokens"]
 
         start = time()
         self.pd = self._jsonl_list_to_dataframe(files, columns_list)
@@ -61,39 +69,144 @@ def __getitem__(self, idx: int) -> Tuple[str, str]:
         # drop fn signature
         code = row["code"]
-        fn_body = code[code.find("{") + 1 : code.rfind("}")].lstrip().rstrip()
-        fn_body = fn_body.replace("\n", "\\n")
+        fn_body = code[
+            code.find("{", code.find(fn_name) + len(fn_name)) + 1 : code.rfind("}")
+        ]
+        fn_body = fn_body.strip()
+        fn_body = fn_body.replace("\n", self.newline_repl)
         # fn_body_enc = self.enc.encode(fn_body)
-        return (fn_name, fn_body)
+
+        tokens = row["code_tokens"]
+        body_tokens = tokens[tokens.index(fn_name) + 2 :]
+        try:
+            fn_body_tokens = body_tokens[
+                body_tokens.index("{") + 1 : len(body_tokens) - 1
+            ]
+        except ValueError as ve:  # '{' might be missing
+            logging.error("'%s' fn body extraction failed: %s", body_tokens, ve)
+            fn_body_tokens = None
+
+        return (fn_name, fn_body, fn_body_tokens)
 
     def __len__(self) -> int:
         return len(self.pd)
 
 
+# id splitting from
+# https://github.com/microsoft/dpu-utils/blob/dfc44e354b57a4e2617828bdf4d76c1c4d81c021/python/dpu_utils/codeutils/identifiersplitting.py
+from functools import lru_cache
+from typing import List
+
+
+def split_camelcase(camel_case_identifier: str) -> List[str]:
+    """
+    Split camelCase identifiers.
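+
+    Illustrative examples (hand-traced against the logic below):
+
+    >>> split_camelcase("camelCaseURL")
+    ['camel', 'Case', 'URL']
+    >>> split_camelcase("HTTPResponse2")
+    ['HTTP', 'Response', '2']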
+ """ + if not len(camel_case_identifier): + return [] + + # split into words based on adjacent cases being the same + result = [] + current = str(camel_case_identifier[0]) + prev_upper = camel_case_identifier[0].isupper() + prev_digit = camel_case_identifier[0].isdigit() + prev_special = not camel_case_identifier[0].isalnum() + for c in camel_case_identifier[1:]: + upper = c.isupper() + digit = c.isdigit() + special = not c.isalnum() + new_upper_word = upper and not prev_upper + new_digit_word = digit and not prev_digit + new_special_word = special and not prev_special + if new_digit_word or new_upper_word or new_special_word: + result.append(current) + current = c + elif not upper and prev_upper and len(current) > 1: + result.append(current[:-1]) + current = current[-1] + c + elif not digit and prev_digit: + result.append(current) + current = c + elif not special and prev_special: + result.append(current) + current = c + else: + current += c + prev_digit = digit + prev_upper = upper + prev_special = special + result.append(current) + return result + + +@lru_cache(maxsize=5000) +def split_identifier_into_parts(identifier: str) -> List[str]: + """ + Split a single identifier into parts on snake_case and camelCase + """ + snake_case = identifier.split("_") + + identifier_parts = [] # type: List[str] + for i in range(len(snake_case)): + part = snake_case[i] + if len(part) > 0: + identifier_parts.extend(s.lower() for s in split_camelcase(part)) + if len(identifier_parts) == 0: + return [identifier] + return identifier_parts + + def main(args: Namespace) -> None: dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline) split_name = Path(args.data_dir).name with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open( args.tgt_file % split_name, mode="w", encoding="utf8" ) as t: - for fn_name, fn_body in dataset: + for fn_name, fn_body, fn_body_tokens in dataset: if not fn_name or not fn_body: continue - print(fn_body, file=s) - print(fn_name if args.word_level_targets else " ".join(fn_name), file=t) + + if args.token_level_sources: + if not fn_body_tokens: + continue + src = " ".join(fn_body_tokens).replace("\n", args.newline) + else: + src = fn_body + + if args.word_level_targets: + tgt = fn_name + elif args.token_level_targets: + tgt = " ".join(split_identifier_into_parts(fn_name)) + else: + tgt = " ".join(fn_name) + if args.print: + print(f"'{tgt[:40]:40}' - '{src[:40]:40}'") + else: + print(src, file=s) + print(tgt, file=t) if __name__ == "__main__": parser = ArgumentParser(add_help=False) parser.add_argument( - "--data_dir", + "--data-dir", type=str, default="java/final/jsonl/test", help="Path to the unziped input data (CodeSearchNet)", ) + parser.add_argument("--newline", default="\\n", help="Replace newline with this") + parser.add_argument( - "--newline", type=str, default="\\n", help="Replace newline with this" + "--token-level-sources", + action="store_true", + help="Use language-specific token sources instead of word level ones", + ) + + parser.add_argument( + "--token-level-targets", + action="store_true", + help="Use camlCase and snake_case split token sources instead of word or char level ones", ) parser.add_argument( @@ -103,11 +216,15 @@ def main(args: Namespace) -> None: ) parser.add_argument( - "--src_file", type=str, default="src-%s.txt", help="File with function bodies", + "--src-file", default="src-%s.txt", help="File with function bodies", + ) + + parser.add_argument( + "--tgt-file", default="tgt-%s.txt", help="File with function texts" ) 
+
+
 def main(args: Namespace) -> None:
     dataset = CodeSearchNetRAM(Path(args.data_dir), args.newline)
     split_name = Path(args.data_dir).name
     with open(args.src_file % split_name, mode="w", encoding="utf8") as s, open(
         args.tgt_file % split_name, mode="w", encoding="utf8"
     ) as t:
-        for fn_name, fn_body in dataset:
+        for fn_name, fn_body, fn_body_tokens in dataset:
             if not fn_name or not fn_body:
                 continue
-            print(fn_body, file=s)
-            print(fn_name if args.word_level_targets else " ".join(fn_name), file=t)
+
+            if args.token_level_sources:
+                if not fn_body_tokens:
+                    continue
+                src = " ".join(fn_body_tokens).replace("\n", args.newline)
+            else:
+                src = fn_body
+
+            if args.word_level_targets:
+                tgt = fn_name
+            elif args.token_level_targets:
+                tgt = " ".join(split_identifier_into_parts(fn_name))
+            else:
+                tgt = " ".join(fn_name)
+            if args.print:
+                print(f"'{tgt[:40]:40}' - '{src[:40]:40}'")
+            else:
+                print(src, file=s)
+                print(tgt, file=t)
 
 
 if __name__ == "__main__":
     parser = ArgumentParser(add_help=False)
     parser.add_argument(
-        "--data_dir",
+        "--data-dir",
         type=str,
         default="java/final/jsonl/test",
         help="Path to the unzipped input data (CodeSearchNet)",
     )
+    parser.add_argument("--newline", default="\\n", help="Replace newline with this")
+
     parser.add_argument(
-        "--newline", type=str, default="\\n", help="Replace newline with this"
+        "--token-level-sources",
+        action="store_true",
+        help="Use language-specific tokens as sources instead of the raw function body",
+    )
+
+    parser.add_argument(
+        "--token-level-targets",
+        action="store_true",
+        help="Use camelCase- and snake_case-split tokens as targets instead of word- or char-level ones",
     )
 
     parser.add_argument(
@@ -103,11 +216,15 @@ def main(args: Namespace) -> None:
     )
 
     parser.add_argument(
-        "--src_file", type=str, default="src-%s.txt", help="File with function bodies",
+        "--src-file", default="src-%s.txt", help="File with function bodies",
+    )
+
+    parser.add_argument(
+        "--tgt-file", default="tgt-%s.txt", help="File with function texts"
     )
 
     parser.add_argument(
-        "--tgt_file", type=str, default="tgt-%s.txt", help="File with function texts"
+        "--print", action="store_true", help="Print a data preview to stdout"
     )
 
     args = parser.parse_args()
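
As a quick sanity check, the conversion can be previewed using only the flags
defined in this patch; piping into head is exactly what the SIGPIPE handler
above makes safe:

    python notebooks/codesearchnet-opennmt.py \
        --data-dir='java/final/jsonl/valid' \
        --token-level-targets \
        --print | head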