From 290d947981d32c800349216b204435a119ddd2bd Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Wed, 5 Feb 2025 13:50:23 -0800 Subject: [PATCH 1/2] Tmp commit --- metaflow/cli.py | 66 +++- metaflow/cli_components/run_cmds.py | 142 ++++++- metaflow/cli_components/step_cmd.py | 108 ++++++ metaflow/datastore/spin_datastore/__init__.py | 0 .../spin_datastore/inputs_datastore.py | 104 ++++++ .../spin_datastore/step_datastore.py | 108 ++++++ metaflow/flowspec.py | 16 +- metaflow/metaflow_config.py | 14 +- metaflow/metaflow_current.py | 2 + metaflow/runner/metaflow_runner.py | 194 ++++++++-- metaflow/runtime.py | 349 ++++++++++++++++-- metaflow/task.py | 214 +++++++++++ metaflow/util.py | 44 +++ 13 files changed, 1278 insertions(+), 83 deletions(-) create mode 100644 metaflow/datastore/spin_datastore/__init__.py create mode 100644 metaflow/datastore/spin_datastore/inputs_datastore.py create mode 100644 metaflow/datastore/spin_datastore/step_datastore.py diff --git a/metaflow/cli.py b/metaflow/cli.py index 3a8dc4ecaa9..eac63d952e8 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -134,6 +134,8 @@ def config_merge_cb(ctx, param, value): "step": "metaflow.cli_components.step_cmd.step", "run": "metaflow.cli_components.run_cmds.run", "resume": "metaflow.cli_components.run_cmds.resume", + "spin": "metaflow.cli_components.run_cmds.spin", + "spin-internal": "metaflow.cli_components.step_cmd.spin_internal", }, ) def cli(ctx): @@ -256,6 +258,12 @@ def version(obj): type=click.Choice([m.TYPE for m in METADATA_PROVIDERS]), help="Metadata service type", ) +# @click.option( +# "--spin-metadata", +# default="local", +# show_default=True, +# help="Spin Metadata service type", +# ) @click.option( "--environment", default=DEFAULT_ENVIRONMENT, @@ -272,6 +280,14 @@ def version(obj): help="Data backend type", is_eager=True, ) +# @click.option( +# "--spin-datastore", +# default=DEFAULT_DATASTORE, +# show_default=True, +# type=click.Choice([d.TYPE for d in DATASTORES]), +# help="Data backend type", +# is_eager=True, +# ) @click.option("--datastore-root", help="Root path for datastore") @click.option( "--package-suffixes", @@ -384,7 +400,6 @@ def start( # second one processed will return the actual options. The order of processing # depends on what (and in what order) the user specifies on the command line. 
config_options = config_file or config_value - if ( hasattr(ctx, "saved_args") and ctx.saved_args @@ -462,14 +477,10 @@ def start( ctx.obj.event_logger = LOGGING_SIDECARS[event_logger]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.event_logger.start() - _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) ctx.obj.monitor = MONITOR_SIDECARS[monitor]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.monitor.start() - _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor @@ -484,6 +495,39 @@ def start( ) ctx.obj.config_options = config_options + ctx.obj.is_spin = False + + # Add new top-level options for spin and spin-internal commands + # print(f"ctx.saved_args is {ctx.saved_args}") + if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: + # For spin, we will only use the local metadata provider, datastore, environment + # and null event logger and monitor + ctx.obj.is_spin = True + ctx.obj.spin_metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + if datastore_root is None: + datastore_root = ctx.obj.spin_datastore_impl.get_datastore_root_from_config( + ctx.obj.echo + ) + ctx.obj.spin_datastore_impl.datastore_root = datastore_root + ctx.obj.spin_flow_datastore = FlowDataStore( + ctx.obj.flow.name, + ctx.obj.environment, + ctx.obj.spin_metadata, # local metadata provider + ctx.obj.event_logger, + ctx.obj.monitor, + storage_impl=ctx.obj.spin_datastore_impl, + ) + # print(f"ctx.obj.spin_flow_datastore: {ctx.obj.spin_flow_datastore}") + + # Start event logger and monitor + ctx.obj.event_logger.start() + _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) + + ctx.obj.monitor.start() + _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) decorators._init(ctx.obj.flow) @@ -493,14 +537,14 @@ def start( ctx.obj.flow, ctx.obj.graph, ctx.obj.environment, - ctx.obj.flow_datastore, - ctx.obj.metadata, + ctx.obj.flow_datastore if not ctx.obj.is_spin else ctx.obj.spin_flow_datastore, + ctx.obj.metadata if not ctx.obj.is_spin else ctx.obj.spin_metadata, ctx.obj.logger, echo, deco_options, ) - # In the case of run/resume, we will want to apply the TL decospecs + # In the case of run/resume/spin, we will want to apply the TL decospecs # *after* the run decospecs so that they don't take precedence. In other # words, for the same decorator, we want `myflow.py run --with foo` to # take precedence over any other `foo` decospec @@ -516,7 +560,7 @@ def start( parameters.set_parameter_context( ctx.obj.flow.name, ctx.obj.echo, - ctx.obj.flow_datastore, + ctx.obj.flow_datastore if not ctx.obj.is_spin else ctx.obj.spin_flow_datastore, { k: ConfigValue(v) for k, v in ctx.obj.flow.__class__._flow_state.get( @@ -528,9 +572,9 @@ def start( if ( hasattr(ctx, "saved_args") and ctx.saved_args - and ctx.saved_args[0] not in ("run", "resume") + and ctx.saved_args[0] not in ("run", "resume", "spin") ): - # run/resume are special cases because they can add more decorators with --with, + # run/resume/spin are special cases because they can add more decorators with --with, # so they have to take care of themselves. 
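+        # (For every other subcommand, the top-level decospecs are attached right
+        # below; run/resume/spin re-attach them inside before_run() instead.)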
all_decospecs = ctx.obj.tl_decospecs + list( ctx.obj.environment.decospecs() or [] diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index bf77d16ad1f..2e714040365 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -9,17 +9,17 @@ from ..graph import FlowGraph from ..metaflow_current import current from ..package import MetaflowPackage -from ..runtime import NativeRuntime +from ..runtime import NativeRuntime, SpinRuntime from ..system import _system_logger from ..tagging_util import validate_tags -from ..util import get_latest_run_id, write_latest_run_id +from ..util import get_latest_run_id, write_latest_run_id, get_latest_task_pathspec def before_run(obj, tags, decospecs): validate_tags(tags) - # There's a --with option both at the top-level and for the run + # There's a --with option both at the top-level and for the run/resume/spin # subcommand. Why? # # "run --with shoes" looks so much better than "--with shoes run". @@ -39,7 +39,7 @@ def before_run(obj, tags, decospecs): + list(obj.environment.decospecs() or []) ) if all_decospecs: - # These decospecs are the ones from run/resume PLUS the ones from the + # These decospecs are the ones from run/resume/spin PLUS the ones from the # environment (for example the @conda) decorators._attach_decorators(obj.flow, all_decospecs) decorators._init(obj.flow) @@ -51,7 +51,11 @@ def before_run(obj, tags, decospecs): # obj.environment.init_environment(obj.logger) decorators._init_step_decorators( - obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger + obj.flow, + obj.graph, + obj.environment, + obj.flow_datastore if not obj.is_spin else obj.spin_flow_datastore, + obj.logger, ) obj.metadata.add_sticky_tags(tags=tags) @@ -70,6 +74,28 @@ def write_file(file_path, content): f.write(str(content)) +def common_runner_options(func): + @click.option( + "--run-id-file", + default=None, + show_default=True, + type=str, + help="Write the ID of this run to the file specified.", + ) + @click.option( + "--runner-attribute-file", + default=None, + show_default=True, + type=str, + help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.", + ) + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def common_run_options(func): @click.option( "--tag", @@ -110,20 +136,6 @@ def common_run_options(func): "option multiple times to attach multiple decorators " "in steps.", ) - @click.option( - "--run-id-file", - default=None, - show_default=True, - type=str, - help="Write the ID of this run to the file specified.", - ) - @click.option( - "--runner-attribute-file", - default=None, - show_default=True, - type=str, - help="Write the metadata and pathspec of this run to the file specified. 
Used internally for Metaflow's Runner API.", - ) @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -167,6 +179,7 @@ def wrapper(*args, **kwargs): @click.argument("step-to-rerun", required=False) @click.command(help="Resume execution of a previous run of this flow.") @common_run_options +@common_runner_options @click.pass_obj def resume( obj, @@ -285,6 +298,7 @@ def resume( @click.command(help="Run the workflow locally.") @tracing.cli_entrypoint("cli/run") @common_run_options +@common_runner_options @click.option( "--namespace", "user_namespace", @@ -360,3 +374,93 @@ def run( f, ) runtime.execute() + + +@click.command(help="Spins up a step locally") +@click.argument( + "step-name", + required=True, + type=str, +) +@click.option( + "--task-pathspec", + default=None, + show_default=True, + help="Task ID to use when spinning up the step. The spinned up step will use the artifacts" + "corresponding to this task ID. If not provided, an arbitrary task ID from the latest run will be used.", +) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step.", +) +@common_runner_options +@click.pass_obj +def spin( + obj, + step_name, + task_pathspec=None, + skip_decorators=False, + run_id_file=None, + runner_attribute_file=None, + **kwargs +): + before_run(obj, [], []) + if task_pathspec is None: + task_pathspec = get_latest_task_pathspec(obj.flow.name, step_name) + + obj.echo( + f"Spinning up step *{step_name}* locally with task pathspec *{task_pathspec}*" + ) + obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + step_func = getattr(obj.flow, step_name) + + spin_runtime = SpinRuntime( + obj.flow, + obj.graph, + obj.flow_datastore, + obj.metadata, + obj.environment, + obj.package, + obj.logger, + obj.entrypoint, + obj.event_logger, + obj.monitor, + obj.spin_metadata, + obj.spin_flow_datastore, + step_func, + task_pathspec, + skip_decorators, + ) + _system_logger.log_event( + level="info", + module="metaflow.task", + name="spin", + payload={ + "msg": str( + { + "step_name": step_name, + "task_pathspec": task_pathspec, + } + ) + }, + ) + + write_latest_run_id(obj, spin_runtime.run_id) + write_file(run_id_file, spin_runtime.run_id) + spin_runtime.execute() + + if runner_attribute_file: + with open(runner_attribute_file, "w") as f: + json.dump( + { + "task_id": spin_runtime.task.task_id, + "step_name": step_name, + "run_id": spin_runtime.run_id, + "flow_name": obj.flow.name, + "metadata": f"{obj.spin_metadata.__class__.TYPE}@{obj.spin_metadata.__class__.INFO}", + }, + f, + ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 4b40c9e5e54..0497129427b 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -174,3 +174,111 @@ def step( ) echo("Success", fg="green", bold=True, indent=True) + + +@click.command(help="Internal command to spin a single task.", hidden=True) +@click.argument("step-name") +@click.option( + "--run-id", + default=None, + required=True, + help="Run ID for the step that's about to be spun", +) +@click.option( + "--task-id", + default=None, + required=True, + help="Task ID for the step that's about to be spun", +) +@click.option( + "--task-pathspec", + default=None, + show_default=True, + help="Task Pathspec to be used in the spun step.", +) +@click.option( + "--input-paths", + help="A comma-separated list of pathspecs specifying inputs for this step.", +) +@click.option( + 
"--split-index", + type=int, + default=None, + show_default=True, + help="Index of this foreach split.", +) +@click.option( + "--retry-count", + default=0, + help="How many times we have attempted to run this task.", +) +@click.option( + "--max-user-code-retries", + default=0, + help="How many times we should attempt running the user code.", +) +@click.option( + "--namespace", + "namespace", + default=None, + help="Change namespace from the default (your username) to the specified tag.", +) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step.", +) +@click.pass_context +def spin_internal( + ctx, + step_name, + run_id=None, + task_id=None, + task_pathspec=None, + input_paths=None, + split_index=None, + retry_count=None, + max_user_code_retries=None, + namespace=None, + skip_decorators=False, +): + import time + + start = time.time() + import sys + + if ctx.obj.is_quiet: + echo = echo_dev_null + else: + echo = echo_always + echo("Spinning a task, *%s*" % step_name, fg="magenta", bold=False) + + input_paths = decompress_list(input_paths) if input_paths else [] + task = MetaflowTask( + ctx.obj.flow, + ctx.obj.flow_datastore, # local datastore + ctx.obj.metadata, # local metadata provider + ctx.obj.environment, # local environment + ctx.obj.echo, + ctx.obj.event_logger, # null logger + ctx.obj.monitor, # null monitor + None, # no unbounded foreach context + ctx.obj.spin_flow_datastore, # spin flow datastore + ctx.obj.spin_metadata, # spin metadata provider + ) + + task.run_spin_step( + step_name, + task_pathspec, + run_id, + task_id, + input_paths, + split_index, + retry_count, + max_user_code_retries, + namespace, + skip_decorators, + ) + print("Time taken for the whole thing: ", time.time() - start) diff --git a/metaflow/datastore/spin_datastore/__init__.py b/metaflow/datastore/spin_datastore/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/metaflow/datastore/spin_datastore/inputs_datastore.py b/metaflow/datastore/spin_datastore/inputs_datastore.py new file mode 100644 index 00000000000..23e5b4006e9 --- /dev/null +++ b/metaflow/datastore/spin_datastore/inputs_datastore.py @@ -0,0 +1,104 @@ +from itertools import chain + + +class LinearStepDatastore(object): + def __init__(self, datastore, artifacts={}): + self._datastore = datastore + self._artifacts = artifacts + self._data = {} + + # Set them to empty dictionaries in order to persist artifacts + # See `persist` method in `TaskDatastore` for more details + self._objects = {} + self._info = {} + + def __contains__(self, name): + try: + _ = self.__getattr__(name) + except AttributeError: + return False + return True + + def __getitem__(self, name): + return self.__getattr__(name) + + def __setitem__(self, name, value): + self._data[name] = value + + def __getattr__(self, name): + # Check internal data first + if name in self._data: + return self._data[name] + + # We always look for any artifacts provided by the user first + if name in self._artifacts: + return self._artifacts[name] + + # If the linear step is part of a foreach step, we need to set the input attribute + # and the index attribute + if name == "input": + if not self._task.index: + raise AttributeError( + f"Attribute '{name}' does not exist for step `{self._task.parent.id}` as it is not part of " + f"a foreach step." 
+ ) + # input only exists for steps immediately after a foreach split + # we check for that by comparing the length of the foreach-step-names + # attribute of the task and its immediate ancestors + foreach_step_names = self._task.metadata_dict.get("foreach-step-names") + prev_task_foreach_step_names = self.previous_task.metadata_dict.get( + "foreach-step-names" + ) + if len(foreach_step_names) <= len(prev_task_foreach_step_names): + return None # input does not exist, so we return None + + foreach_stack = self._task["_foreach_stack"].data + foreach_index = foreach_stack[-1].index + foreach_var = foreach_stack[-1].var + + # Fetch the artifact corresponding to the foreach var and index from the previous task + input_val = self.previous_task[foreach_var].data[foreach_index] + setattr(self, name, input_val) + return input_val + + # If the linear step is part of a foreach step, we need to set the index attribute + if name == "index": + if not self._task.index: + raise AttributeError( + f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a " + f"foreach step." + ) + foreach_stack = self._task["_foreach_stack"].data + foreach_index = foreach_stack[-1].index + setattr(self, name, foreach_index) + return foreach_index + + # If the user has not provided the artifact, we look for it in the + # task using the client API + try: + return getattr(self._datastore, name) + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." + ) + + @property + def previous_task(self): + if self._previous_task: + return self._previous_task + + # This is a linear step, so we only have one immediate ancestor + from metaflow import Task + + prev_task_pathspec = list( + chain.from_iterable(self._immediate_ancestors.values()) + )[0] + self._previous_task = Task(prev_task_pathspec, _namespace_check=False) + return self._previous_task + + def get(self, key, default=None): + try: + return self.__getattr__(key) + except AttributeError: + return default diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py new file mode 100644 index 00000000000..86e37e79908 --- /dev/null +++ b/metaflow/datastore/spin_datastore/step_datastore.py @@ -0,0 +1,108 @@ +from itertools import chain + + +class LinearStepDatastore(object): + def __init__(self, inp_datastore, foreach_stack, task_index=None, artifacts={}): + self._inp_datastore = inp_datastore + self._input_paths = input_paths + self._foreach_stack = foreach_stack + self._task_index = task_index + self._artifacts = artifacts + self._step_name = self._inp_datastore.step_name + + # Set them to empty dictionaries in order to persist artifacts + # See `persist` method in `TaskDatastore` for more details + self._data = {} + self._objects = {} + self._info = {} + + def __contains__(self, name): + try: + _ = self.__getattr__(name) + except AttributeError: + return False + return True + + def __getitem__(self, name): + return self.__getattr__(name) + + def __setitem__(self, name, value): + self._data[name] = value + + def __getattr__(self, name): + # Check internal data first + if name in self._data: + return self._data[name] + + # We always look for any artifacts provided by the user first + if name in self._artifacts: + return self._artifacts[name] + + # If the linear step is part of a foreach step, we need to set the input attribute + # and the index attribute + if name == "input": + if not self._task.index: 
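+                # NOTE: as in inputs_datastore.py, `self._task` is never assigned
+                # anywhere in this class; the follow-up commit replaces this check
+                # with `len(self._foreach_stack) == 0`.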
+ raise AttributeError( + f"Attribute '{name}' does not exist for step `{self._step_name}` as it is not part of " + f"a foreach step." + ) + # input only exists for steps immediately after a foreach split + # we check for that by comparing the length of the foreach-step-names + # attribute of the task and its immediate ancestors + foreach_step_names = self._task.metadata_dict.get("foreach-step-names") + prev_task_foreach_step_names = self.previous_task.metadata_dict.get( + "foreach-step-names" + ) + if len(foreach_step_names) <= len(prev_task_foreach_step_names): + return None # input does not exist, so we return None + + foreach_stack = self._task["_foreach_stack"].data + foreach_index = foreach_stack[-1].index + foreach_var = foreach_stack[-1].var + + # Fetch the artifact corresponding to the foreach var and index from the previous task + input_val = self.previous_task[foreach_var].data[foreach_index] + setattr(self, name, input_val) + return input_val + + # If the linear step is part of a foreach step, we need to set the index attribute + if name == "index": + if not self._task.index: + raise AttributeError( + f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a " + f"foreach step." + ) + foreach_stack = self._task["_foreach_stack"].data + foreach_index = foreach_stack[-1].index + setattr(self, name, foreach_index) + return foreach_index + + # If the user has not provided the artifact, we look for it in the + # task using the client API + try: + return getattr(self.previous_task.artifacts, name).data + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." + ) + + @property + def previous_task(self): + if self._previous_task: + return self._previous_task + + # This is a linear step, so we only have one immediate ancestor + from metaflow import Task + + prev_task_pathspec = list( + chain.from_iterable(self._immediate_ancestors.values()) + )[0] + self._previous_task = Task(prev_task_pathspec, _namespace_check=False) + return self._previous_task + + def get(self, key, default=None): + try: + return self.__getattr__(key) + except AttributeError: + return default diff --git a/metaflow/flowspec.py b/metaflow/flowspec.py index b36d9c3af75..69b82c133a3 100644 --- a/metaflow/flowspec.py +++ b/metaflow/flowspec.py @@ -184,6 +184,12 @@ def __init__(self, use_cli=True): self._transition = None self._cached_input = {} + # Spin state + self._spin = False + self._spin_index = None + self._spin_input = None + self._spin_foreach_stack = None + if use_cli: with parameters.flow_context(self.__class__) as _: from . import cli @@ -226,9 +232,7 @@ def _check_parameters(cls, config_parameters=False): def _process_config_decorators(cls, config_options, ignore_errors=False): # Fast path for no user configurations - if not cls._flow_state.get(_FlowState.CONFIG_DECORATORS) and all( - len(step.config_decorators) == 0 for step in cls._steps - ): + if not cls._flow_state.get(_FlowState.CONFIG_DECORATORS): # Process parameters to allow them to also use config values easily for var, param in cls._get_parameters(): if param.IS_CONFIG_PARAMETER: @@ -461,6 +465,8 @@ def index(self) -> Optional[int]: int, optional Index of the task in a foreach step. """ + if self._spin: + return self._spin_index if self._foreach_stack: return self._foreach_stack[-1].index @@ -481,6 +487,8 @@ def input(self) -> Optional[Any]: object, optional Input passed to the foreach task. 
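+
+        Under `spin`, the value is served from the spin task datastore instead
+        of being resolved from the live foreach frame.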
""" + if self._spin: + return self._datastore["input"] return self._find_input() def foreach_stack(self) -> Optional[List[Tuple[int, int, Any]]]: @@ -530,6 +538,8 @@ def nest_2(self): List[Tuple[int, int, Any]] An array describing the current stack of foreach steps. """ + if self._spin: + return self._spin_foreach_stack return [ (frame.index, frame.num_splits, self._find_input(stack_index=i)) for i, frame in enumerate(self._foreach_stack) diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index e99d4f9a001..56be74c9592 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -47,6 +47,14 @@ "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" ) +### +# Spin configuration +### +SPIN_ALLOWED_DECORATORS = from_conf( + "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"] +) + + ### # User configuration ### @@ -248,8 +256,7 @@ # Default container registry DEFAULT_CONTAINER_REGISTRY = from_conf("DEFAULT_CONTAINER_REGISTRY") # Controls whether to include foreach stack information in metadata. -# TODO(Darin, 05/01/24): Remove this flag once we are confident with this feature. -INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", False) +INCLUDE_FOREACH_STACK = from_conf("INCLUDE_FOREACH_STACK", True) # Maximum length of the foreach value string to be stored in each ForeachFrame. MAXIMUM_FOREACH_VALUE_CHARS = from_conf("MAXIMUM_FOREACH_VALUE_CHARS", 30) # The default runtime limit (In seconds) of jobs launched by any compute provider. Default of 5 days. @@ -433,9 +440,6 @@ # should result in an appreciable speedup in flow environment initialization. CONDA_DEPENDENCY_RESOLVER = from_conf("CONDA_DEPENDENCY_RESOLVER", "conda") -# Default to not using fast init binary. -CONDA_USE_FAST_INIT = from_conf("CONDA_USE_FAST_INIT", False) - ### # Escape hatch configuration ### diff --git a/metaflow/metaflow_current.py b/metaflow/metaflow_current.py index 8443c1d75ab..73c9f8359d8 100644 --- a/metaflow/metaflow_current.py +++ b/metaflow/metaflow_current.py @@ -45,6 +45,7 @@ def _set_env( username=None, metadata_str=None, is_running=True, + is_spin=False, tags=None, ): if flow is not None: @@ -60,6 +61,7 @@ def _set_env( self._username = username self._metadata_str = metadata_str self._is_running = is_running + self._is_spin = is_spin self._tags = tags def _update_env(self, env): diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 4759a13b191..38a0525b0fb 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -4,7 +4,6 @@ import json from typing import Dict, Iterator, Optional, Tuple - from metaflow import Run from metaflow.plugins import get_runner_cli @@ -17,27 +16,33 @@ from .subprocess_manager import CommandManager, SubprocessManager -class ExecutingRun(object): +class ExecutingProcess(object): """ - This class contains a reference to a `metaflow.Run` object representing - the currently executing or finished run, as well as metadata related - to the process. + This is a base class for `ExecutingRun` and `ExecutingTask` classes. + The `ExecutingRun` and `ExecutingTask` classes are returned by methods + in `Runner` and `NBRunner`, and they are subclasses of this class. - `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not - meant to be instantiated directly. + The `ExecutingRun` class for instance contains a reference to a `metaflow.Run` + object representing the currently executing or finished run, as well as the metadata + related to the process. 
+ + Similarly, the `ExecutingTask` class contains a reference to a `metaflow.Task` + object representing the currently executing or finished task, as well as the metadata + related to the process. + + This class or its subclasses are not meant to be instantiated directly. The class + works as a context manager, allowing you to use a pattern like: - This class works as a context manager, allowing you to use a pattern like ```python with Runner(...).run() as running: ... ``` - Note that you should use either this object as the context manager or - `Runner`, not both in a nested manner. + + Note that you should use either this object as the context manager or `Runner`, not both + in a nested manner. """ - def __init__( - self, runner: "Runner", command_obj: CommandManager, run_obj: Run - ) -> None: + def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but instead user Runner.run() @@ -48,12 +53,9 @@ def __init__( Parent runner for this run. command_obj : CommandManager CommandManager containing the subprocess executing this run. - run_obj : Run - Run object corresponding to this run. """ self.runner = runner self.command_obj = command_obj - self.run = run_obj def __enter__(self) -> "ExecutingRun": return self @@ -72,11 +74,10 @@ async def wait( Parameters ---------- - timeout : float, optional, default None - The maximum time, in seconds, to wait for the run to finish. - If the timeout is reached, the run is terminated. If not specified, wait - forever. - stream : str, optional, default None + timeout : Optional[float], default None + The maximum time to wait for the run to finish. + If the timeout is reached, the run is terminated + stream : Optional[str], default None If specified, the specified stream is printed to stdout. `stream` can be one of `stdout` or `stderr`. @@ -173,7 +174,7 @@ async def stream_log( ---------- stream : str The stream to stream logs from. Can be one of `stdout` or `stderr`. - position : int, optional, default None + position : Optional[int], default None The position in the log file to start streaming from. If None, it starts from the beginning of the log file. This allows resuming streaming from a previously known position @@ -189,6 +190,83 @@ async def stream_log( yield position, line +class ExecutingTask(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Task` object representing + the currently executing or finished task, as well as metadata related + to the process. + + `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).spin() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task" + ) -> None: + """ + Create a new ExecutingTask -- this should not be done by the user directly but + instead user Runner.spin() + + Parameters + ---------- + runner : Runner + Parent runner for this task. + command_obj : CommandManager + CommandManager containing the subprocess executing this task. + task_obj : Task + Task object corresponding to this task. 
+ """ + super().__init__(runner, command_obj) + self.task = task_obj + + +class ExecutingRun(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Run` object representing + the currently executing or finished run, as well as metadata related + to the process. + + `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).run() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, run_obj: Run + ) -> None: + """ + Create a new ExecutingRun -- this should not be done by the user directly but + instead user Runner.run() + + Parameters + ---------- + runner : Runner + Parent runner for this run. + command_obj : CommandManager + CommandManager containing the subprocess executing this run. + run_obj : Run + Run object corresponding to this run. + """ + super().__init__(runner, command_obj) + self.run = run_obj + + class RunnerMeta(type): def __new__(mcs, name, bases, dct): cls = super().__new__(mcs, name, bases, dct) @@ -322,6 +400,23 @@ def __get_executing_run(self, attribute_file_fd, command_obj): ) return ExecutingRun(self, command_obj, run_object) + def __get_executing_task(self, attribute_file_fd, command_obj): + content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout) + + command_obj.sync_wait() + + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + async def __async_get_executing_run(self, attribute_file_fd, command_obj): content = await async_handle_timeout( attribute_file_fd, command_obj, self.file_read_timeout @@ -337,6 +432,23 @@ async def __async_get_executing_run(self, attribute_file_fd, command_obj): ) return ExecutingRun(self, command_obj, run_object) + async def __async_get_executing_task(self, attribute_file_fd, command_obj): + content = await async_handle_timeout( + attribute_file_fd, command_obj, self.file_read_timeout + ) + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + def run(self, **kwargs) -> ExecutingRun: """ Blocking execution of the run. This method will wait until @@ -399,6 +511,44 @@ def resume(self, **kwargs) -> ExecutingRun: return self.__get_executing_run(attribute_file_fd, command_obj) + def spin(self, step_name, task_pathspec, **kwargs): + """ + Blocking spin execution of the run. + This method will wait until the spun run has completed execution. + + Parameters + ---------- + step_name : str + The name of the step to spin. 
+ task_pathspec : str, optional, default None + The task pathspec to be used in the spun task. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + + Returns + ------- + ExecutingTask + ExecutingTask containing the results of the spun task. + """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + task_pathspec=task_pathspec, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = self.spm.run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + show_output=self.show_output, + ) + command_obj = self.spm.get(pid) + + return self.__get_executing_task(attribute_file_fd, command_obj) + async def async_run(self, **kwargs) -> ExecutingRun: """ Non-blocking execution of the run. This method will return as soon as the diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 7e9269841fb..86d84eb015e 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -18,13 +18,14 @@ from io import BytesIO from functools import partial from concurrent import futures - -from metaflow.datastore.exceptions import DataException +from itertools import chain from contextlib import contextmanager from . import get_namespace +from metaflow.datastore.exceptions import DataException from .metadata_provider import MetaDatum from .metaflow_config import MAX_ATTEMPTS, UI_URL +from .metaflow_config import SPIN_ALLOWED_DECORATORS from .exception import ( MetaflowException, MetaflowInternalError, @@ -73,6 +74,251 @@ # TODO option: output dot graph periodically about execution +class SpinRuntime(object): + def __init__( + self, + flow, + graph, + flow_datastore, + metadata, + environment, + package, + logger, + entrypoint, + event_logger, + monitor, + spin_metadata, + spin_flow_datastore, + step_func, + task_pathspec, + skip_decorators=False, + max_log_size=MAX_LOG_SIZE, + ): + from metaflow import Task + + self._flow = flow + self._graph = graph + self._flow_datastore = flow_datastore + self._metadata = metadata + self._environment = environment + self._package = package + self._logger = logger + self._entrypoint = entrypoint + self._event_logger = event_logger + self._monitor = monitor + + # Spin specific metadata, event_logger, monitor, and flow_datastore + self._spin_metadata = spin_metadata + self._spin_flow_datastore = spin_flow_datastore + print(f"Spin Flow Datastore in SpinRuntime: {self._spin_flow_datastore}") + + self._step_func = step_func + self._task_pathspec = task_pathspec + self._task = Task(self._task_pathspec, _namespace_check=False) + self._input_paths = None + self._split_index = None + self._whitelist_decorators = None + self._config_file_name = None + self._skip_decorators = skip_decorators + self._max_log_size = max_log_size + self._encoding = sys.stdout.encoding or "UTF-8" + + # Create a new run_id for the spin task + self.run_id = self._spin_metadata.new_run_id() + for deco in self.whitelist_decorators: + print("-" * 100) + deco.runtime_init(flow, graph, package, self.run_id) + + @property + def split_index(self): + if self._split_index: + return self._split_index + + self._split_index = self._task.index + return self._split_index + + @property + def input_paths(self): + if self._input_paths: + return self._input_paths + + # We use the _input_paths artifact to get the input paths for the task + self._input_paths = self._task["_input_paths"].data + return self._input_paths + + @property + def 
whitelist_decorators(self): + if self._skip_decorators: + return [] + if self._whitelist_decorators: + return self._whitelist_decorators + self._whitelist_decorators = [ + deco + for deco in self._step_func.decorators + if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS) + ] + return self._whitelist_decorators + + def _new_task(self, step, input_paths=None, **kwargs): + return Task( + self._flow_datastore, + self._flow, + step, + self.run_id, + self._metadata, + self._environment, + self._entrypoint, + self._event_logger, + self._monitor, + input_paths=self.input_paths, + decos=self.whitelist_decorators, + logger=self._logger, + split_index=self.split_index, + spin_flow_datastore=self._spin_flow_datastore, + spin_metadata=self._spin_metadata, + is_spin=True, + **kwargs, + ) + + def execute(self): + exception = None + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: + config_value = dump_config_values(self._flow) + if config_value: + json.dump(config_value, config_file) + config_file.flush() + self._config_file_name = config_file.name + else: + self._config_file_name = None + + self.task = self._new_task(self._step_func.name, {}) + _ds = self._spin_flow_datastore.get_task_datastore( + self.run_id, + self._step_func.name, + self.task.task_id, + attempt=0, + mode="w", + ) + + try: + self._launch_and_monitor_task() + except Exception as ex: + self._logger("Task failed.", system_msg=True, bad=True) + exception = ex + raise + finally: + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) + + def _launch_and_monitor_task(self): + args = CLIArgs(self.task, spin=True, prev_task_pathspec=self._task_pathspec) + env = dict(os.environ) + + for deco in self.task.decos: + deco.runtime_step_cli( + args, + self.task.retries, + self.task.user_code_retries, + self.task.ubf_context, + ) + + # Add user configurations using a file to avoid using up too much space on the + # command line + if self._config_file_name: + args.top_level_options["local-config-file"] = self._config_file_name + + # Add the skip-decorators flag to the command options + args.command_options.update({"skip-decorators": self._skip_decorators}) + + env.update(args.get_env()) + env["PYTHONUNBUFFERED"] = "x" + + stdout_buffer = TruncatedBuffer("stdout", self._max_log_size) + stderr_buffer = TruncatedBuffer("stderr", self._max_log_size) + + cmdline = args.get_args() + self._logger(f"Launching command: {' '.join(cmdline)}", system_msg=True) + + try: + process = subprocess.Popen( + cmdline, + env=env, + bufsize=1, # Line buffering + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + ) + except Exception as e: + raise TaskFailed(self.task, f"Failed to launch task: {str(e)}") + + poll = procpoll.make_poll() + poll.add(process.stdout.fileno()) + poll.add(process.stderr.fileno()) + + fd_map = { + process.stdout.fileno(): (process.stdout, stdout_buffer, False), + process.stderr.fileno(): (process.stderr, stderr_buffer, True), + } + + while True: + events = poll.poll(POLL_TIMEOUT) + + if not events: + if process.poll() is not None: + break + continue + + for event in events: + if event.can_read: + stream, buffer, is_stderr = fd_map[event.fd] + line = stream.readline() + if line: + self._process_output(line, buffer, is_stderr) + + if event.is_terminated: + poll.remove(event.fd) + + if process.poll() is not None: + break + + # Process remaining output + for stream, buffer, is_stderr in fd_map.values(): + for line in stream: + 
self._process_output(line, buffer, is_stderr) + + returncode = process.wait() + + self.task.save_metadata( + "runtime", + { + "return_code": returncode, + "success": returncode == 0, + }, + ) + + if returncode != 0: + raise TaskFailed(self.task, f"Task failed with return code {returncode}") + else: + self._logger("Task finished successfully.", system_msg=True) + + self.task.save_logs( + { + "stdout": stdout_buffer.get_buffer(), + "stderr": stderr_buffer.get_buffer(), + } + ) + + def _process_output(self, line, buffer, is_stderr=False): + buffer.write(line.encode(self._encoding)) + text = line.strip() + self.task.log( + text, + system_msg=False, + timestamp=datetime.now(), + ) + + class NativeRuntime(object): def __init__( self, @@ -440,6 +686,8 @@ def clone_original_run(self, generate_task_obj=False, verbose=True): ) ) + print(f"Inputs: {inputs}") + with futures.ThreadPoolExecutor(max_workers=self._max_workers) as executor: all_tasks = [ executor.submit( @@ -1030,6 +1278,9 @@ def __init__( task_id=None, resume_identifier=None, pathspec_index=None, + spin_flow_datastore=None, + spin_metadata=None, + is_spin=False, ): self.step = step self.flow = flow @@ -1068,6 +1319,11 @@ def __init__( self.datastore_sysroot = flow_datastore.datastore_root self._results_ds = None + # Spin metadata and flow datastore + self.spin_flow_datastore = spin_flow_datastore + self.spin_metadata = spin_metadata + self.is_spin = is_spin + # Only used in clone-only resume. self._is_resume_leader = None self._resume_done = None @@ -1258,6 +1514,7 @@ def __init__( system_msg=True, ) else: + # For spin tasks we will always short-circuit to here self._is_cloned = False if clone_only: # We are done -- we don't proceed to create new task-ids @@ -1267,10 +1524,11 @@ def __init__( # Open the output datastore only if the task is not being cloned. if not self._is_cloned: self.new_attempt() + print(f"I am already running runtime_task_created decorator hook here") for deco in decos: deco.runtime_task_created( self._ds, - task_id, + self.task_id, split_index, input_paths, self._is_cloned, @@ -1302,9 +1560,15 @@ def __init__( self.error_retries = 0 def new_attempt(self): - self._ds = self._flow_datastore.get_task_datastore( - self.run_id, self.step, self.task_id, attempt=self.retries, mode="w" - ) + print(f"I am in new_attempt and self.task_id: {self.task_id}") + if self.is_spin: + self._ds = self.spin_flow_datastore.get_task_datastore( + self.run_id, self.step, self.task_id, attempt=self.retries, mode="w" + ) + else: + self._ds = self._flow_datastore.get_task_datastore( + self.run_id, self.step, self.task_id, attempt=self.retries, mode="w" + ) self._ds.init_task() def log(self, msg, system_msg=False, pid=None, timestamp=True): @@ -1350,17 +1614,30 @@ def _get_task_id(self, task_id): if task_id is None: task_id = str( self.metadata.new_task_id(self.run_id, self.step, sys_tags=tags) + if not self.is_spin + else self.spin_metadata.new_task_id( + self.run_id, self.step, sys_tags=tags + ) ) already_existed = False else: # task_id is preset only by persist_constants(). 
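+            # (spin) Mirror the new_task_id branch above: when is_spin is set,
+            # registration goes through the spin metadata provider so the spun
+            # task is recorded against the local spin run.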
- already_existed = not self.metadata.register_task_id( - self.run_id, - self.step, - task_id, - 0, - sys_tags=tags, - ) + if self.is_spin: + already_existed = not self.spin_metadata.register_task_id( + self.run_id, + self.step, + task_id, + 0, + sys_tags=tags, + ) + else: + already_existed = not self.metadata.register_task_id( + self.run_id, + self.step, + task_id, + 0, + sys_tags=tags, + ) self.task_id = task_id self._path = "%s/%s/%s" % (self.run_id, self.step, self.task_id) @@ -1508,11 +1785,13 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). """ - def __init__(self, task): + def __init__(self, task, spin=False, prev_task_pathspec=None): self.task = task + self.spin = spin + self.prev_task_pathspec = prev_task_pathspec self.entrypoint = list(task.entrypoint) self.top_level_options = { - "quiet": True, + "quiet": True if spin else False, "metadata": self.task.metadata_type, "environment": self.task.environment_type, "datastore": self.task.datastore_type, @@ -1542,18 +1821,42 @@ def __init__(self, task): (k, ConfigInput.make_key_name(k)) for k in configs ] + # print(f"top_level_options: {self.top_level_options}") + + if spin: + self.spin_args() + else: + self.default_args() + + def default_args(self): self.commands = ["step"] self.command_args = [self.task.step] self.command_options = { - "run-id": task.run_id, - "task-id": task.task_id, - "input-paths": compress_list(task.input_paths), - "split-index": task.split_index, - "retry-count": task.retries, - "max-user-code-retries": task.user_code_retries, - "tag": task.tags, + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "tag": self.task.tags, + "namespace": get_namespace() or "", + "ubf-context": self.task.ubf_context, + } + self.env = {} + + def spin_args(self): + self.commands = ["spin-internal"] + self.command_args = [self.task.step] + + self.command_options = { + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "task-pathspec": self.prev_task_pathspec, + "input-paths": self.task.input_paths, + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, "namespace": get_namespace() or "", - "ubf-context": task.ubf_context, } self.env = {} diff --git a/metaflow/task.py b/metaflow/task.py index 6b73302652b..47239a7c49a 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -6,6 +6,7 @@ import time import traceback +from itertools import chain from types import MethodType, FunctionType from metaflow.sidecar import Message, MessageTypes @@ -47,6 +48,8 @@ def __init__( event_logger, monitor, ubf_context, + spin_flow_datastore=None, + spin_metadata=None, ): self.flow = flow self.flow_datastore = flow_datastore @@ -56,6 +59,8 @@ def __init__( self.event_logger = event_logger self.monitor = monitor self.ubf_context = ubf_context + self.spin_flow_datastore = spin_flow_datastore + self.spin_metadata = spin_metadata def _exec_step_function(self, step_function, input_obj=None): if input_obj is None: @@ -372,6 +377,214 @@ def _finalize_control_task(self): ) ) + def run_spin_step( + self, + step_name, + task_pathspec, + new_run_id, + new_task_id, + input_paths, + split_index, + retry_count, + max_user_code_retries, + namespace, + skip_decorators, + ): + t1 = time.time() + node = self.flow._graph[step_name] + join_type = None + if 
node.type == "join": + join_type = self.flow._graph[node.split_parents[-1]].type + print(f"node.type: {node.type}") + parent_type = self.flow._graph[node.split_parents[-1]].type + print(f"parent_type: {parent_type}") + t2 = time.time() + print(f"t2 - t1: {t2 - t1}") + + step_func = getattr(self.flow, step_name) + whitelisted_decorators = ( + [] + if skip_decorators + else [ + deco + for deco in step_func.decorators + if any( + deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS + ) + ] + ) + t3 = time.time() + # initialize output datastore + output = self.spin_flow_datastore.get_task_datastore( + new_run_id, step_name, new_task_id, 0, mode="w" + ) + + output.init_task() + t4 = time.time() + print(f"t4 - t3: {t4 - t3}") + # + # How we access the input and index attributes depends on the execution context. + # If spin is set to True, we short-circuit attribute access to getattr directly + # Also set the other attributes that are needed for the task to execute + # from metaflow import Task + self.flow._spin = True + flow_id, prev_run_id, prev_step_name, prev_task_id = task_pathspec.split("/") + # print(f"input_paths: {input_paths}") + # print(f"prev_run_id: {prev_run_id}") + t5 = time.time() + inputs = self._init_data( + prev_run_id, + join_type, + input_paths, + ) + t6 = time.time() + print(f"t6 - t5: {t6 - t5}") + t7 = time.time() + prev_task_datastore = self.flow_datastore.get_task_datastore( + prev_run_id, + prev_step_name, + prev_task_id, + ) + self.flow._spin_index = split_index + self.flow._current_step = step_name + self.flow._success = False + self.flow._task_ok = None + self.flow._exception = None + self.flow._spin_foreach_stack = prev_task_datastore["_foreach_stack"] + t8 = time.time() + print(f"t8 - t7: {t8 - t7}") + # + # # Set inputs + # if node.type == "join": + # if join_type == "foreach": + # pass + # else: + # pass + # else: + # inp_datastore = SpinInputsDatastore( + # inputs[0], + # input_paths, + # foreach_stack=self.flow._spin_foreach_stack, + # task_index=split_index, + # artifacts={} + # ) + # inp_datastore = None + # is_join = is_join_step(immediate_ancestors) + # if is_join: + # # Join step + # if len(self.task.metadata_dict.get("previous-steps")) > 1: + # # Static join step + # inp_datastore = StaticSpinInputsDatastore( + # self.task, immediate_ancestors, artifacts={} + # ) + # else: + # # Foreach join step + # inp_datastore = SpinInputsDatastore( + # self.task, immediate_ancestors, artifacts={} + # ) + # self.flow._set_datastore(output) + # else: + # # Linear step + # self.flow._set_datastore( + # LinearStepDatastore(self.task, immediate_ancestors, artifacts={}) + # ) + # + # current._set_env( + # flow=self.flow, + # run_id=new_run_id, + # step_name=step_name, + # task_id=new_task_id, + # retry_count=retry_count, + # namespace=resolve_identity(), + # username=get_username(), + # metadata_str="%s@%s" + # % (self.metadata.__class__.TYPE, self.metadata.__class__.INFO), + # is_running=True, + # is_spin=True, + # ) + + # task_pre_step decorator hooks + # for deco in whitelisted_decorators: + # deco.task_pre_step( + # step_name=step_name, + # task_datastore=output, + # metadata=self.metadata, + # run_id=new_run_id, + # task_id=new_task_id, + # flow=self.flow, + # graph=self.flow._graph, + # retry_count=retry_count, + # max_user_code_retries=max_user_code_retries, + # ubf_context=self.ubf_context, + # inputs=inp_datastore, + # ) + # + # # task_decorate decorator hooks + # for deco in whitelisted_decorators: + # step_func = deco.task_decorate( + # 
step_func=step_func, + # flow=self.flow, + # graph=self.flow._graph, + # retry_count=retry_count, + # max_user_code_retries=max_user_code_retries, + # ubf_context=self.ubf_context, + # ) + # + # # Execute the step function + # try: + # if is_join: + # # Join step + # self._exec_step_function(step_func, input_obj=inp_datastore) + # else: + # self._exec_step_function(step_func) + # + # # task_post_step decorator hooks + # for deco in whitelisted_decorators: + # deco.task_post_step( + # step_name, + # self.flow, + # self.flow._graph, + # retry_count, + # max_user_code_retries, + # ) + # + # self.flow._task_ok = True + # self.flow._success = True + # except Exception as ex: + # exception_handled = False + # for deco in whitelisted_decorators: + # res = deco.task_exception( + # ex, + # step_name, + # self.flow, + # self.flow._graph, + # retry_count, + # max_user_code_retries, + # ) + # exception_handled = bool(res) or exception_handled + # + # if exception_handled: + # self.flow._task_ok = True + # else: + # self.flow._task_ok = False + # self.flow._exception = MetaflowExceptionWrapper(ex) + # print("%s failed:" % self.flow, file=sys.stderr) + # raise + # finally: + # output.persist(self.flow) + # output.done() + # + # # task_finish decorator hooks + # for deco in whitelisted_decorators: + # deco.task_finished( + # step_name, + # self.flow, + # self.flow._graph, + # self.flow._task_ok, + # retry_count, + # max_user_code_retries, + # ) + def run_step( self, step_name, @@ -559,6 +772,7 @@ def run_step( self.flow._success = False self.flow._task_ok = None self.flow._exception = None + self.flow._input_paths = input_paths # Note: All internal flow attributes (ie: non-user artifacts) # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of diff --git a/metaflow/util.py b/metaflow/util.py index cd3447d0e48..e9355e31ae1 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -9,6 +9,7 @@ from itertools import takewhile import re +from typing import Callable from metaflow.exception import MetaflowUnknownUser, MetaflowInternalError try: @@ -193,6 +194,49 @@ def get_latest_run_id(echo, flow_name): return None +def get_latest_task_pathspec(flow_name: str, step_name: str) -> str: + """ + Returns a task pathspec from the latest run of the flow for the queried step. + If the queried step has several tasks, the task pathspec of the first task is returned. + + Parameters + ---------- + flow_name : str + The name of the flow. + step_name : str + The name of the step. + + Returns + ------- + str + The task pathspec of the first task of the queried step. + + Raises + ------ + MetaflowNotFound + If no task or run is found for the queried step. 
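+
+    Examples
+    --------
+    >>> get_latest_task_pathspec("MyFlow", "start")  # run/task ids illustrative
+    'MyFlow/123/start/456'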
+    """
+    from metaflow import Flow, Step
+    from metaflow.exception import MetaflowNotFound
+
+    run = Flow(flow_name, _namespace_check=False).latest_run
+
+    if run is None:
+        raise MetaflowNotFound(f"No run found for the flow {flow_name}")
+
+    try:
+        step = Step(f"{flow_name}/{run.id}/{step_name}", _namespace_check=False)
+    except Exception:
+        raise MetaflowNotFound(
+            f"No step *{step_name}* found in run *{run.id}* for flow *{flow_name}*"
+        )
+
+    task = next(iter(step.tasks()), None)
+    if task:
+        return f"{flow_name}/{run.id}/{step_name}/{task.id}"
+    raise MetaflowNotFound(f"No task found for the queried step {step_name}")
+
+
 def write_latest_run_id(obj, run_id):
     from metaflow.plugins.datastores.local_storage import LocalStorage

From dc357aa346cbef23c7f1985d27a087f2d3527bbb Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Tue, 11 Feb 2025 23:11:27 -0800
Subject: [PATCH 2/2] Dummy commit

---
 metaflow/datastore/__init__.py                |   1 +
 .../spin_datastore/step_datastore.py          |  49 ++---
 metaflow/task.py                              | 199 +++++++++---------
 3 files changed, 120 insertions(+), 129 deletions(-)

diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py
index 793251b0cff..236ed2151ca 100644
--- a/metaflow/datastore/__init__.py
+++ b/metaflow/datastore/__init__.py
@@ -2,3 +2,4 @@
 from .flow_datastore import FlowDataStore
 from .datastore_set import TaskDataStoreSet
 from .task_datastore import TaskDataStore
+from .spin_datastore.step_datastore import SpinStepDatastore
diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py
index 86e37e79908..85b9d868176 100644
--- a/metaflow/datastore/spin_datastore/step_datastore.py
+++ b/metaflow/datastore/spin_datastore/step_datastore.py
@@ -1,14 +1,13 @@
 from itertools import chain


-class LinearStepDatastore(object):
-    def __init__(self, inp_datastore, foreach_stack, task_index=None, artifacts={}):
-        self._inp_datastore = inp_datastore
-        self._input_paths = input_paths
-        self._foreach_stack = foreach_stack
-        self._task_index = task_index
+class SpinStepDatastore(object):
+    def __init__(self, inp_datastores, _foreach_stack, artifacts={}):
+        # This is a linear step, so we only have one input datastore
+        self._inp_datastores = inp_datastores[0]
+        self._foreach_stack = _foreach_stack
         self._artifacts = artifacts
-        self._step_name = self._inp_datastore.step_name
+        self._step_name = self._inp_datastores.step_name

         # Set them to empty dictionaries in order to persist artifacts
         # See `persist` method in `TaskDatastore` for more details
@@ -41,7 +40,7 @@ def __getattr__(self, name):
         # If the linear step is part of a foreach step, we need to set the input attribute
         # and the index attribute
         if name == "input":
-            if not self._task.index:
+            if len(self._foreach_stack) == 0:
                 raise AttributeError(
                     f"Attribute '{name}' does not exist for step `{self._step_name}` as it is not part of "
                     f"a foreach step."
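The lookup order `__getattr__` implements above: values set during execution
(`_data`), then user-supplied `artifacts`, then the spin-specific `input`/`index`
attributes derived from the captured foreach stack, and finally the previous
task's persisted artifacts. A hedged usage sketch -- `prev_ds` stands in for a
read-mode TaskDataStore and is not a name from this patch:

    from metaflow.datastore import SpinStepDatastore

    ds = SpinStepDatastore([prev_ds], prev_ds["_foreach_stack"], artifacts={"alpha": 0.1})
    ds["alpha"]      # 0.1, served from the artifacts override
    ds.get("index")  # foreach index, or None when outside a foreach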
@@ -49,31 +48,27 @@ def __getattr__(self, name): # input only exists for steps immediately after a foreach split # we check for that by comparing the length of the foreach-step-names # attribute of the task and its immediate ancestors - foreach_step_names = self._task.metadata_dict.get("foreach-step-names") - prev_task_foreach_step_names = self.previous_task.metadata_dict.get( - "foreach-step-names" - ) - if len(foreach_step_names) <= len(prev_task_foreach_step_names): + foreach_step = self._foreach_stack[-1].step + if self._step_name != foreach_step: return None # input does not exist, so we return None - foreach_stack = self._task["_foreach_stack"].data - foreach_index = foreach_stack[-1].index - foreach_var = foreach_stack[-1].var + foreach_index = self._foreach_stack[-1].index + foreach_var = self._foreach_stack[-1].var # Fetch the artifact corresponding to the foreach var and index from the previous task - input_val = self.previous_task[foreach_var].data[foreach_index] + input_val = self._inp_datastores[foreach_var].data[foreach_index] setattr(self, name, input_val) return input_val # If the linear step is part of a foreach step, we need to set the index attribute if name == "index": - if not self._task.index: + if len(self._foreach_stack) == 0: raise AttributeError( f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a " f"foreach step." ) - foreach_stack = self._task["_foreach_stack"].data - foreach_index = foreach_stack[-1].index + + foreach_index = self._foreach_stack[-1].index setattr(self, name, foreach_index) return foreach_index @@ -87,20 +82,6 @@ def __getattr__(self, name): f"`{self.step_name}`." ) - @property - def previous_task(self): - if self._previous_task: - return self._previous_task - - # This is a linear step, so we only have one immediate ancestor - from metaflow import Task - - prev_task_pathspec = list( - chain.from_iterable(self._immediate_ancestors.values()) - )[0] - self._previous_task = Task(prev_task_pathspec, _namespace_check=False) - return self._previous_task - def get(self, key, default=None): try: return self.__getattr__(key) diff --git a/metaflow/task.py b/metaflow/task.py index 47239a7c49a..03e5f1d55e0 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -15,7 +15,7 @@ from .metaflow_config import MAX_ATTEMPTS from .metadata_provider import MetaDatum from .mflog import TASK_LOG_SOURCE -from .datastore import Inputs, TaskDataStoreSet +from .datastore import Inputs, TaskDataStoreSet, SpinStepDatastore from .exception import ( MetaflowInternalError, MetaflowDataMissing, @@ -453,8 +453,15 @@ def run_spin_step( self.flow._spin_foreach_stack = prev_task_datastore["_foreach_stack"] t8 = time.time() print(f"t8 - t7: {t8 - t7}") + + inp_datastore = SpinStepDatastore( + inputs, + self.flow._spin_foreach_stack, + artifacts={}, + ) + # - # # Set inputs + # Set inputs # if node.type == "join": # if join_type == "foreach": # pass @@ -468,6 +475,7 @@ def run_spin_step( # task_index=split_index, # artifacts={} # ) + # print(f"inp_datastore: {inp_datastore}") # inp_datastore = None # is_join = is_join_step(immediate_ancestors) # if is_join: @@ -489,101 +497,102 @@ def run_spin_step( # LinearStepDatastore(self.task, immediate_ancestors, artifacts={}) # ) # - # current._set_env( - # flow=self.flow, - # run_id=new_run_id, - # step_name=step_name, - # task_id=new_task_id, - # retry_count=retry_count, - # namespace=resolve_identity(), - # username=get_username(), - # metadata_str="%s@%s" - # % (self.metadata.__class__.TYPE, 
self.metadata.__class__.INFO), - # is_running=True, - # is_spin=True, - # ) + current._set_env( + flow=self.flow, + run_id=new_run_id, + step_name=step_name, + task_id=new_task_id, + retry_count=retry_count, + namespace=resolve_identity(), + username=get_username(), + metadata_str="%s@%s" + % (self.metadata.__class__.TYPE, self.metadata.__class__.INFO), + is_running=True, + is_spin=True, + ) # task_pre_step decorator hooks - # for deco in whitelisted_decorators: - # deco.task_pre_step( - # step_name=step_name, - # task_datastore=output, - # metadata=self.metadata, - # run_id=new_run_id, - # task_id=new_task_id, - # flow=self.flow, - # graph=self.flow._graph, - # retry_count=retry_count, - # max_user_code_retries=max_user_code_retries, - # ubf_context=self.ubf_context, - # inputs=inp_datastore, - # ) - # - # # task_decorate decorator hooks - # for deco in whitelisted_decorators: - # step_func = deco.task_decorate( - # step_func=step_func, - # flow=self.flow, - # graph=self.flow._graph, - # retry_count=retry_count, - # max_user_code_retries=max_user_code_retries, - # ubf_context=self.ubf_context, - # ) - # - # # Execute the step function - # try: - # if is_join: - # # Join step - # self._exec_step_function(step_func, input_obj=inp_datastore) - # else: - # self._exec_step_function(step_func) - # - # # task_post_step decorator hooks - # for deco in whitelisted_decorators: - # deco.task_post_step( - # step_name, - # self.flow, - # self.flow._graph, - # retry_count, - # max_user_code_retries, - # ) - # - # self.flow._task_ok = True - # self.flow._success = True - # except Exception as ex: - # exception_handled = False - # for deco in whitelisted_decorators: - # res = deco.task_exception( - # ex, - # step_name, - # self.flow, - # self.flow._graph, - # retry_count, - # max_user_code_retries, - # ) - # exception_handled = bool(res) or exception_handled - # - # if exception_handled: - # self.flow._task_ok = True - # else: - # self.flow._task_ok = False - # self.flow._exception = MetaflowExceptionWrapper(ex) - # print("%s failed:" % self.flow, file=sys.stderr) - # raise - # finally: - # output.persist(self.flow) - # output.done() - # - # # task_finish decorator hooks - # for deco in whitelisted_decorators: - # deco.task_finished( - # step_name, - # self.flow, - # self.flow._graph, - # self.flow._task_ok, - # retry_count, - # max_user_code_retries, - # ) + for deco in whitelisted_decorators: + deco.task_pre_step( + step_name=step_name, + task_datastore=output, + metadata=self.metadata, + run_id=new_run_id, + task_id=new_task_id, + flow=self.flow, + graph=self.flow._graph, + retry_count=retry_count, + max_user_code_retries=max_user_code_retries, + ubf_context=self.ubf_context, + inputs=inp_datastore, + ) + + # task_decorate decorator hooks + for deco in whitelisted_decorators: + step_func = deco.task_decorate( + step_func=step_func, + flow=self.flow, + graph=self.flow._graph, + retry_count=retry_count, + max_user_code_retries=max_user_code_retries, + ubf_context=self.ubf_context, + ) + + # Execute the step function + try: + # if is_join: + # # Join step + # self._exec_step_function(step_func, input_obj=inp_datastore) + # else: + # self._exec_step_function(step_func) + self._exec_step_function(step_func) + + # task_post_step decorator hooks + for deco in whitelisted_decorators: + deco.task_post_step( + step_name, + self.flow, + self.flow._graph, + retry_count, + max_user_code_retries, + ) + + self.flow._task_ok = True + self.flow._success = True + except Exception as ex: + exception_handled = False + 
for deco in whitelisted_decorators: + res = deco.task_exception( + ex, + step_name, + self.flow, + self.flow._graph, + retry_count, + max_user_code_retries, + ) + exception_handled = bool(res) or exception_handled + + if exception_handled: + self.flow._task_ok = True + else: + self.flow._task_ok = False + self.flow._exception = MetaflowExceptionWrapper(ex) + print("%s failed:" % self.flow, file=sys.stderr) + raise + finally: + output.persist(self.flow) + output.done() + + # task_finish decorator hooks + for deco in whitelisted_decorators: + deco.task_finished( + step_name, + self.flow, + self.flow._graph, + self.flow._task_ok, + retry_count, + max_user_code_retries, + ) def run_step( self,
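A hedged end-to-end sketch of the user-facing API this patch series adds, assuming
a flow file `myflow.py` with a `train` step and at least one completed prior run
(the pathspec is illustrative):

    from metaflow import Runner

    # Spin up a single step in isolation, reusing the artifacts of a prior task.
    with Runner("myflow.py").spin(
        "train", task_pathspec="MyFlow/123/train/456"
    ) as executing:
        # spin() returns an ExecutingTask wrapping the spun metaflow.Task
        print(executing.task.pathspec)

Equivalently, from the command line, letting the task pathspec default to an
arbitrary task of the latest run:

    python myflow.py spin train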