diff --git a/tritonparse/common.py b/tritonparse/common.py index 1020354..cbf6494 100644 --- a/tritonparse/common.py +++ b/tritonparse/common.py @@ -169,14 +169,33 @@ def copy_local_to_tmpdir(local_path: str, verbose: bool = False) -> str: Copy local log files to a temporary directory. Args: - local_path: Path to local directory containing logs + local_path: Path to local directory or single file containing logs verbose: Whether to print verbose information Returns: Path to temporary directory containing copied logs + + Raises: + RuntimeError: If the local_path does not exist """ + if not os.path.exists(local_path): + raise RuntimeError(f"Path does not exist: {local_path}") + temp_dir = tempfile.mkdtemp() + # Handle single file case + if os.path.isfile(local_path): + if os.path.basename(local_path).startswith(LOG_PREFIX): + if verbose: + logger.info(f"Copying single file {local_path} to {temp_dir}") + shutil.copy2(local_path, temp_dir) + return temp_dir + + # Handle directory case + if not os.path.isdir(local_path): + raise RuntimeError( + f"Path is neither a file nor a directory: {local_path}") + for item in os.listdir(local_path): item_path = os.path.join(local_path, item) if os.path.isfile(item_path) and os.path.basename(item_path).startswith( @@ -319,13 +338,6 @@ def save_logs(out_dir: Path, parsed_logs: str, overwrite: bool, verbose: bool) - if not out_dir.is_absolute(): out_dir = out_dir.resolve() - if out_dir.exists(): - if not overwrite: - raise RuntimeError( - f"{out_dir} already exists, pass --overwrite to overwrite" - ) - shutil.rmtree(out_dir) - os.makedirs(out_dir, exist_ok=True) logger.info(f"Copying parsed logs from {parsed_logs} to {out_dir}") diff --git a/tritonparse/utils.py b/tritonparse/utils.py index 81ddc57..c67e07a 100644 --- a/tritonparse/utils.py +++ b/tritonparse/utils.py @@ -66,6 +66,17 @@ def oss_parse(args): source = Source(args.source, verbose) rank_config = RankConfig.from_cli_args(args.rank, args.all_ranks, source.type) + # Check output directory early if specified + if args.out is not None: + out_dir = Path(args.out) + if out_dir.exists(): + if not args.overwrite: + raise RuntimeError( + f"{out_dir} already exists, pass --overwrite to overwrite" + ) + shutil.rmtree(out_dir) + os.makedirs(out_dir, exist_ok=True) + # For signpost logging (not implemented in Python version) if source.type == SourceType.LOCAL: @@ -75,18 +86,8 @@ def oss_parse(args): elif source.type == SourceType.LOCAL_FILE: local_path = source.value - - if args.out is not None: - out_dir = Path(args.out) - if out_dir.exists(): - if not args.overwrite: - raise RuntimeError( - f"{out_dir} already exists, pass --overwrite to overwrite" - ) - shutil.rmtree(out_dir) - - os.makedirs(out_dir, exist_ok=True) - return + # Copy the single file to a temp directory, then parse it + logs = copy_local_to_tmpdir(local_path, verbose) parsed_log_dir, _ = parse_logs(logs, rank_config, verbose) if args.out is not None: