Skip to content

Commit db1cc4a

Browse files
committed
Add support for overriding artifacts
1 parent e92f139 commit db1cc4a

File tree

6 files changed

+139
-67
lines changed

6 files changed

+139
-67
lines changed

metaflow/cli.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -153,7 +153,7 @@ def show(obj):
153153
echo_always("\n%s" % obj.graph.doc)
154154
for node_name in obj.graph.sorted_nodes:
155155
node = obj.graph[node_name]
156-
echo_always("\nStep *%s* and type: *%s*" % (node.name, node.type), err=False)
156+
echo_always("\nStep *%s* " % node.name, err=False)
157157
echo_always(node.doc if node.doc else "?", indent=True, err=False)
158158
if node.type != "end":
159159
echo_always(
@@ -474,8 +474,8 @@ def start(
474474
# ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"](
475475
# flow=ctx.obj.flow, env=ctx.obj.environment
476476
# )
477-
# ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0]
478-
ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "s3"][0]
477+
ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0]
478+
# ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "s3"][0]
479479
if datastore_root is None:
480480
datastore_root = ctx.obj.spin_datastore_impl.get_datastore_root_from_config(
481481
ctx.obj.echo

metaflow/cli_components/run_cmds.py

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -398,10 +398,11 @@ def run(
398398

399399

400400
@click.command(help="Spins up a task for a given step from a previous run locally.")
401-
@click.argument(
402-
"step-name",
403-
required=True,
404-
type=str,
401+
@click.option(
402+
"--step-name",
403+
default=None,
404+
show_default=True,
405+
help="Step name to spin up. Must provide either step-name or spin-pathspec.",
405406
)
406407
@click.option(
407408
"--spin-pathspec",
@@ -418,6 +419,15 @@ def run(
418419
show_default=True,
419420
help="Skip decorators attached to the step.",
420421
)
422+
@click.option(
423+
"--artifacts-module",
424+
default=None,
425+
show_default=True,
426+
help="Path to a module that contains artifacts to be used in the spun step. The artifacts should "
427+
"be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the "
428+
"artifact values. The artifact values will overwrite the default values of the artifacts used in "
429+
"the spun step.",
430+
)
421431
@click.option(
422432
"--max-log-size",
423433
default=10,
@@ -430,20 +440,31 @@ def run(
430440
@click.pass_obj
431441
def spin(
432442
obj,
433-
step_name,
443+
step_name=None,
434444
spin_pathspec=None,
445+
artifacts_module=None,
435446
skip_decorators=False,
436447
max_log_size=None,
437448
run_id_file=None,
438449
runner_attribute_file=None,
439450
**kwargs
440451
):
441452
before_run(obj, [], [])
442-
if spin_pathspec is None:
453+
if step_name and spin_pathspec:
454+
raise CommandException(
455+
"Cannot specify both step-name and spin-pathspec. Please specify only one."
456+
)
457+
if not (step_name or spin_pathspec):
458+
raise CommandException(
459+
"Please specify either step-name or spin-pathspec to spin a task."
460+
)
461+
if step_name is not None:
443462
spin_pathspec = get_latest_task_pathspec(obj.flow.name, step_name)
463+
else:
464+
step_name = spin_pathspec.split("/")[2]
444465

445466
obj.echo(
446-
f"Spinning up step *{step_name}* locally with task pathspec *{spin_pathspec}*"
467+
f"Spinning up step *{step_name}* locally using previous task pathspec *{spin_pathspec}*"
447468
)
448469
obj.flow._set_constants(obj.graph, kwargs, obj.config_options)
449470
step_func = getattr(obj.flow, step_name)
@@ -464,6 +485,7 @@ def spin(
464485
step_func,
465486
spin_pathspec,
466487
skip_decorators,
488+
artifacts_module,
467489
max_log_size * 1024 * 1024,
468490
)
469491
_system_logger.log_event(

metaflow/cli_components/step_cmd.py

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from ..exception import CommandException
77
from ..task import MetaflowTask
88
from ..unbounded_foreach import UBF_CONTROL, UBF_TASK
9-
from ..util import decompress_list
9+
from ..util import decompress_list, read_artifacts_module
1010
import metaflow.tracing as tracing
1111

1212

@@ -232,6 +232,15 @@ def step(
232232
show_default=True,
233233
help="Skip decorators attached to the step.",
234234
)
235+
@click.option(
236+
"--artifacts-module",
237+
default=None,
238+
show_default=True,
239+
help="Path to a module that contains artifacts to be used in the spun step. The artifacts should "
240+
"be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the "
241+
"artifact values. The artifact values will overwrite the default values of the artifacts used in "
242+
"the spun step.",
243+
)
235244
@click.pass_context
236245
def spin_internal(
237246
ctx,
@@ -245,6 +254,7 @@ def spin_internal(
245254
max_user_code_retries=None,
246255
namespace=None,
247256
skip_decorators=False,
257+
artifacts_module=None,
248258
):
249259
import time
250260

@@ -259,16 +269,18 @@ def spin_internal(
259269
echo = echo_always
260270

261271
input_paths = decompress_list(input_paths) if input_paths else []
262-
# print(f"Input paths: {input_paths} and type: {type(input_paths)}")
263-
# print(f"Split index: {split_index} and type: {type(split_index)}")
264-
# print(f"Retry count: {retry_count} and type: {type(retry_count)}")
265-
# print(f"Max user code retries: {max_user_code_retries} and type: {type(max_user_code_retries)}")
266-
267272
echo(
268273
f"Spinning a task, *{step_name}* with previous task pathspec: {spin_pathspec}",
269274
fg="magenta",
270275
bold=False,
271276
)
277+
# if namespace is not None:
278+
# namespace(namespace or None)
279+
280+
spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {}
281+
spin_artifacts = spin_artifacts.get("ARTIFACTS", {})
282+
283+
print(f"spin_artifacts: {spin_artifacts}")
272284

273285
task = MetaflowTask(
274286
ctx.obj.flow,
@@ -298,19 +310,8 @@ def spin_internal(
298310
retry_count,
299311
max_user_code_retries,
300312
spin_pathspec,
313+
skip_decorators,
314+
spin_artifacts,
301315
)
302316

303-
#
304-
# task.run_spin_step(
305-
# step_name,
306-
# task_pathspec,
307-
# run_id,
308-
# task_id,
309-
# input_paths,
310-
# split_index,
311-
# retry_count,
312-
# max_user_code_retries,
313-
# namespace,
314-
# skip_decorators,
315-
# )
316-
# print("Time taken for the whole thing: ", time.time() - start)
317+
echo(f"Time taken for the whole thing: {time.time() - start}")

metaflow/runtime.py

Lines changed: 31 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ def __init__(
9191
step_func,
9292
task_pathspec,
9393
skip_decorators=False,
94+
artifacts_module=None,
9495
max_log_size=MAX_LOG_SIZE,
9596
):
9697
from metaflow import Task
@@ -109,7 +110,6 @@ def __init__(
109110
# Spin specific metadata, event_logger, monitor, and flow_datastore
110111
self._spin_metadata = spin_metadata
111112
self._spin_flow_datastore = spin_flow_datastore
112-
print(f"Spin Flow Datastore in SpinRuntime: {self._spin_flow_datastore}")
113113

114114
self._step_func = step_func
115115
self._task_pathspec = task_pathspec
@@ -119,6 +119,7 @@ def __init__(
119119
self._whitelist_decorators = None
120120
self._config_file_name = None
121121
self._skip_decorators = skip_decorators
122+
self._artifacts_module = artifacts_module
122123
self._max_log_size = max_log_size
123124
self._encoding = sys.stdout.encoding or "UTF-8"
124125

@@ -139,7 +140,7 @@ def split_index(self):
139140

140141
@property
141142
def input_paths(self):
142-
start_time = time.time()
143+
st = time.time()
143144

144145
def _format_input_paths(task_pathspec):
145146
_, run_id, step_name, task_id = task_pathspec.split("/")
@@ -158,13 +159,12 @@ def _format_input_paths(task_pathspec):
158159
).task
159160
self._input_paths = [_format_input_paths(task.pathspec)]
160161
else:
161-
# print("I am in else of input_paths")
162162
self._input_paths = [
163163
_format_input_paths(task_pathspec)
164164
for task_pathspec in self._task.parent_task_pathspecs
165165
]
166-
end_time = time.time()
167-
# print(f"Time taken to get input paths: {end_time - start_time}")
166+
et = time.time()
167+
print(f"Time taken to get input paths: {et - st}")
168168
return self._input_paths
169169

170170
@property
@@ -232,6 +232,8 @@ def _launch_and_monitor_task(self):
232232
self._max_log_size,
233233
self._config_file_name,
234234
spin_pathspec=self._task_pathspec,
235+
skip_decorators=self._skip_decorators,
236+
artifacts_module=self._artifacts_module,
235237
)
236238

237239
# print("Worker created")
@@ -1710,9 +1712,13 @@ class CLIArgs(object):
17101712
for step execution in StepDecorator.runtime_step_cli().
17111713
"""
17121714

1713-
def __init__(self, task, spin_pathspec=None):
1715+
def __init__(
1716+
self, task, spin_pathspec=None, skip_decorators=False, artifacts_module=None
1717+
):
17141718
self.task = task
17151719
self.spin_pathspec = spin_pathspec
1720+
self.skip_decorators = skip_decorators
1721+
self.artifacts_module = artifacts_module
17161722
self.entrypoint = list(task.entrypoint)
17171723
self.top_level_options = {
17181724
"quiet": False,
@@ -1779,6 +1785,8 @@ def spin_args(self):
17791785
"max-user-code-retries": self.task.user_code_retries,
17801786
"namespace": get_namespace() or "",
17811787
"spin-pathspec": self.spin_pathspec,
1788+
"skip-decorators": self.skip_decorators,
1789+
"artifacts-module": self.artifacts_module,
17821790
}
17831791
self.env = {}
17841792

@@ -1820,10 +1828,20 @@ def __str__(self):
18201828

18211829

18221830
class Worker(object):
1823-
def __init__(self, task, max_logs_size, config_file_name, spin_pathspec=None):
1831+
def __init__(
1832+
self,
1833+
task,
1834+
max_logs_size,
1835+
config_file_name,
1836+
spin_pathspec=None,
1837+
skip_decorators=False,
1838+
artifacts_module=None,
1839+
):
18241840
self.task = task
18251841
self._config_file_name = config_file_name
18261842
self.spin_pathspec = spin_pathspec
1843+
self.skip_decorators = skip_decorators
1844+
self.artifacts_module = artifacts_module
18271845
self._proc = self._launch()
18281846

18291847
if task.retries > task.user_code_retries:
@@ -1855,7 +1873,12 @@ def __init__(self, task, max_logs_size, config_file_name, spin_pathspec=None):
18551873
# not it is properly shut down)
18561874

18571875
def _launch(self):
1858-
args = CLIArgs(self.task, spin_pathspec=self.spin_pathspec)
1876+
args = CLIArgs(
1877+
self.task,
1878+
spin_pathspec=self.spin_pathspec,
1879+
skip_decorators=self.skip_decorators,
1880+
artifacts_module=self.artifacts_module,
1881+
)
18591882
env = dict(os.environ)
18601883

18611884
if self.task.clone_run_id:

0 commit comments

Comments
 (0)