-
Notifications
You must be signed in to change notification settings - Fork 525
/
Copy pathvisualization_utils.py
156 lines (129 loc) · 5.41 KB
/
visualization_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import subprocess
import time
from typing import Any, Callable, Type
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
from executorch.exir.program._program import _update_exported_program_graph_module
from torch._export.verifier import Verifier
from torch.export.exported_program import ExportedProgram
from torch.fx import GraphModule
try:
from model_explorer import config, consts, visualize_from_config # type: ignore
except ImportError:
print(
"Error: 'model_explorer' is not installed. Install using devtools/install_requirement.sh"
)
raise
class SingletonModelExplorerServer:
"""Singleton context manager for starting a model-explorer server.
If multiple ModelExplorerServer contexts are nested, a single
server is still used.
"""
server: None | subprocess.Popen = None
num_open: int = 0
wait_after_start = 3.0
def __init__(self, open_in_browser: bool = True, port: int | None = None):
if SingletonModelExplorerServer.server is None:
command = ["model-explorer"]
if not open_in_browser:
command.append("--no_open_in_browser")
if port is not None:
command.append("--port")
command.append(str(port))
SingletonModelExplorerServer.server = subprocess.Popen(command)
def __enter__(self):
SingletonModelExplorerServer.num_open = (
SingletonModelExplorerServer.num_open + 1
)
time.sleep(SingletonModelExplorerServer.wait_after_start)
return self
def __exit__(self, type, value, traceback):
SingletonModelExplorerServer.num_open = (
SingletonModelExplorerServer.num_open - 1
)
if SingletonModelExplorerServer.num_open == 0:
if SingletonModelExplorerServer.server is not None:
SingletonModelExplorerServer.server.kill()
try:
SingletonModelExplorerServer.server.wait(
SingletonModelExplorerServer.wait_after_start
)
except subprocess.TimeoutExpired:
SingletonModelExplorerServer.server.terminate()
SingletonModelExplorerServer.server = None
class ModelExplorerServer:
"""Context manager for starting a model-explorer server."""
wait_after_start = 2.0
def __init__(self, open_in_browser: bool = True, port: int | None = None):
command = ["model-explorer"]
if not open_in_browser:
command.append("--no_open_in_browser")
if port is not None:
command.append("--port")
command.append(str(port))
self.server = subprocess.Popen(command)
def __enter__(self):
time.sleep(self.wait_after_start)
def __exit__(self, type, value, traceback):
self.server.kill()
try:
self.server.wait(self.wait_after_start)
except subprocess.TimeoutExpired:
self.server.terminate()
def _get_exported_program(
visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
) -> ExportedProgram:
if isinstance(visualizable, ExportedProgram):
return visualizable
if isinstance(visualizable, (EdgeProgramManager, ExecutorchProgramManager)):
return visualizable.exported_program()
raise RuntimeError(f"Cannot get ExportedProgram from {visualizable}")
def visualize(
    visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
    reuse_server: bool = True,
    no_open_in_browser: bool = False,
    **kwargs,
):
    """Show a program in model-explorer by wrapping ``visualize_from_config``.

    For convenience, the ExportedProgram is extracted for you when an
    EdgeProgramManager or ExecutorchProgramManager is passed.

    Args:
        visualizable: the program to show; it is registered under the name
            "Executorch" using model-explorer's default settings.
        reuse_server: if True, ``set_reuse_server()`` is called on the config.
        no_open_in_browser: forwarded to ``visualize_from_config``.
        **kwargs: any further options, forwarded to ``visualize_from_config``.

    See https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#visualize-pytorch-models
    for full documentation.
    """
    explorer_config = config()
    default_settings = consts.DEFAULT_SETTINGS
    program = _get_exported_program(visualizable)
    explorer_config.add_model_from_pytorch(
        "Executorch", exported_program=program, settings=default_settings
    )

    if reuse_server:
        explorer_config.set_reuse_server()

    visualize_from_config(
        explorer_config, no_open_in_browser=no_open_in_browser, **kwargs
    )
def visualize_graph(
    graph_module: GraphModule,
    exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
    reuse_server: bool = True,
    no_open_in_browser: bool = False,
    **kwargs,
):
    """Visualize *graph_module* in the context of *exported_program*.

    The graph module of the ExportedProgram behind *exported_program* is
    replaced by *graph_module* before visualizing. Operator validation is also
    relaxed so that graphs containing custom ops can be shown. A typical use is
    after running passes, which return a GraphModule rather than an
    ExportedProgram.

    ``reuse_server``, ``no_open_in_browser`` and ``**kwargs`` are forwarded to
    ``visualize``.
    """

    class _any_op(Verifier):
        # Permissive verifier: reports ``Callable`` as the allowed op type so
        # that custom ops are not rejected.
        dialect = "ANY_OP"

        def allowed_op_types(self) -> tuple[Type[Any], ...]:
            return (Callable,)  # type: ignore

    base_program = _get_exported_program(exported_program)
    patched_program = _update_exported_program_graph_module(
        base_program,
        graph_module,
        override_verifiers=[_any_op],
    )
    visualize(
        patched_program,
        reuse_server=reuse_server,
        no_open_in_browser=no_open_in_browser,
        **kwargs,
    )