# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import importlib.resources
import os
import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence

# If this environment variable is set to a non-empty value other than "0",
# save the flatc input files when serialization fails.
_SAVE_FLATC_ENV: str = "ET_EXIR_SAVE_FLATC_INPUTS_ON_FAILURE"


def _is_valid_alignment(alignment: Optional[int]) -> bool:
    """Returns True if the alignment is None or a positive power of two."""
    if alignment is None:
        return True
    return alignment > 0 and (alignment & (alignment - 1)) == 0
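

# Illustrative examples (not exhaustive):
#   _is_valid_alignment(None) -> True   (None is allowed)
#   _is_valid_alignment(16)   -> True   (positive power of two)
#   _is_valid_alignment(24)   -> False  (not a power of two)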


def _patch_schema_alignment(
    schema: bytes,
    constant_tensor_alignment: Optional[int],
    delegate_alignment: Optional[int],
) -> bytes:
    """Modifies annotated "force_align" values in a flatbuffer schema.

    Args:
        schema: The flatbuffer schema to modify.
        constant_tensor_alignment: If provided, the alignment to use for lines
            annotated with "@executorch-tensor-alignment". If not provided,
            does not patch tensor alignment.
        delegate_alignment: If provided, the alignment to use for lines
            annotated with "@executorch-delegate-alignment". If not provided,
            does not patch delegate alignment.

    Returns:
        The possibly-modified flatbuffer schema.
    """

    def assert_valid_alignment(alignment: Optional[int], name: str) -> None:
        if not (alignment is None or _is_valid_alignment(alignment)):
            raise ValueError(f"Bad {name} {alignment}")

    assert_valid_alignment(constant_tensor_alignment, "constant_tensor_alignment")
    assert_valid_alignment(delegate_alignment, "delegate_alignment")

    def patch_alignment(line: bytes, alignment: int) -> bytes:
        """Replaces an existing alignment with a new alignment."""
        return re.sub(
            rb"\(\s*force_align\s*:\s*\d+\s*\)",
            f"(force_align: {alignment})".encode("utf-8"),
            line,
        )
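
    # For illustration (the schema line shown is hypothetical): with
    # alignment=32, patch_alignment would rewrite
    #   b'data: [ubyte] (force_align: 16);  // @executorch-tensor-alignment'
    # so that it reads "(force_align: 32)" instead.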
    lines = []
    for line in schema.splitlines():
        if constant_tensor_alignment and b"@executorch-tensor-alignment" in line:
            lines.append(patch_alignment(line, constant_tensor_alignment))
        elif delegate_alignment and b"@executorch-delegate-alignment" in line:
            lines.append(patch_alignment(line, delegate_alignment))
        else:
            lines.append(line)
    return b"\n".join(lines)


class _SchemaMaxAlignmentGetter:
    """Finds the largest (force_align: N) N value in flatbuffer schemas."""

    def __init__(self) -> None:
        self.max_alignment: int = 0

    def __call__(self, schema: bytes) -> bytes:
        """Finds all `(force_align: N)` instances and updates max_alignment.

        Returns the input schema unmodified.
        """
        regex = re.compile(rb"\(\s*force_align\s*:\s*(\d+)\s*\)")
        matches = regex.findall(schema)
        for alignment in [int(match) for match in matches]:
            if alignment > self.max_alignment:
                self.max_alignment = alignment
        return schema
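

# A minimal usage sketch (hypothetical schema bytes); the callable signature
# matches the patch_fn expected by _ResourceFiles.patch_files below:
#   getter = _SchemaMaxAlignmentGetter()
#   getter(b"a: [ubyte] (force_align: 16);")
#   getter(b"b: [ubyte] (force_align: 4);")
#   assert getter.max_alignment == 16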


class _ResourceFiles:
    """Manages a collection of python resources that will be written to files."""

    def __init__(self, resource_names: Sequence[str]) -> None:
        """Load the resources with the provided names."""
        # Map each name to its contents.
        self._files: Dict[str, bytes] = {}
        for name in resource_names:
            self._files[name] = importlib.resources.read_binary(__package__, name)

    def patch_files(self, patch_fn: Callable[[bytes], bytes]) -> None:
        """Uses the provided patching function to update the contents of all
        files. `patch_fn` takes the current contents of a file as input and
        returns the new contents.
        """
        for name in self._files.keys():
            self._files[name] = patch_fn(self._files[name])

    def write_to(self, out_dir: str) -> None:
        """Writes the files to the specified directory. File names are based on
        the original resource names.
        """
        for name, data in self._files.items():
            with open(os.path.join(out_dir, name), "wb") as fp:
                fp.write(data)
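

# A minimal usage sketch (the output directory is hypothetical):
#   files = _ResourceFiles(["program.fbs"])
#   files.patch_files(lambda data: data)  # identity patch; a no-op
#   files.write_to("/tmp/schema-out")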


@dataclass
class _SchemaInfo:
    # Path to a file containing the root schema. Other included schema files
    # may be present in the same directory.
    root_path: str

    # An alignment value that can satisfy all "force_align" entries found in
    # the schema files.
    max_alignment: int


def _prepare_schema(
    out_dir: str,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _SchemaInfo:
    """Returns the path to the program schema file after copying it and its
    deps into out_dir. May patch the schema contents depending on the
    parameters to this function.
    """
    program_schema = "program.fbs"
    # Included by the root program schema; must also be present.
    deps = ["scalar_type.fbs"]
    schemas = _ResourceFiles([program_schema] + deps)

    # Update annotated alignments in the schema files.
    schemas.patch_files(
        lambda data: _patch_schema_alignment(
            schema=data,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        ),
    )
    # Find the largest alignment used in the patched schema files.
    get_alignments = _SchemaMaxAlignmentGetter()
    schemas.patch_files(get_alignments)

    # Write the patched schema files to the filesystem.
    schemas.write_to(out_dir)

    return _SchemaInfo(
        root_path=os.path.join(out_dir, program_schema),
        max_alignment=get_alignments.max_alignment,
    )
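

# A minimal usage sketch (directory and alignment are hypothetical):
#   info = _prepare_schema("/tmp/schema-out", constant_tensor_alignment=16)
#   info.root_path       # ".../program.fbs" under the output directory
#   info.max_alignment   # largest (force_align: N) in the patched schemas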


@dataclass
class _FlatbufferResult:
    # Serialized flatbuffer data.
    data: bytes

    # The maximum "force_align" value from the schema used to serialize the data.
    max_alignment: int


# Name of an optional resource containing the `flatc` executable.
_FLATC_RESOURCE_NAME: str = "flatbuffers-flatc"


def _run_flatc(args: Sequence[str]) -> None:
    """Runs the `flatc` command with the provided args.

    If a resource matching _FLATC_RESOURCE_NAME exists, uses that executable.
    Otherwise, uses the binary named by the FLATC_EXECUTABLE environment
    variable, falling back to `flatc` on the system path if the variable is
    not set.
    """
    if importlib.resources.is_resource(__package__, _FLATC_RESOURCE_NAME):
        # Use the provided flatc binary.
        with importlib.resources.path(__package__, _FLATC_RESOURCE_NAME) as flatc_path:
            subprocess.run([flatc_path] + list(args), check=True)
    else:
        # Fall back to the FLATC_EXECUTABLE env var or the system path.
        flatc_path = os.getenv("FLATC_EXECUTABLE")
        if not flatc_path:
            flatc_path = "flatc"
        subprocess.run([flatc_path] + list(args), check=True)
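

# For example, callers without the bundled resource can select a specific
# binary via the environment (the path is illustrative):
#   export FLATC_EXECUTABLE=/opt/flatbuffers/bin/flatc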


def _flatc_compile(output_dir: str, schema_path: str, json_path: str) -> None:
    """Serializes JSON data to a binary flatbuffer file.

    Args:
        output_dir: Directory under which to create the binary flatbuffer file.
        schema_path: Path to the flatbuffer schema to use for serialization.
            If the schema includes other schema files, they must be present in
            the same directory.
        json_path: Path to the data to serialize, as JSON data whose structure
            matches the schema.
    """
    _run_flatc(
        [
            "--binary",
            "-o",
            output_dir,
            schema_path,
            json_path,
        ]
    )
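

# Equivalent to running (placeholders are illustrative):
#   flatc --binary -o <output_dir> <schema_path> <json_path>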


def _flatc_decompile(
    output_dir: str,
    schema_path: str,
    bin_path: str,
    flatc_additional_args: Optional[List[str]] = None,
) -> None:
    """Deserializes binary flatbuffer data to a JSON file.

    Args:
        output_dir: Directory under which to create the JSON file.
        schema_path: Path to the flatbuffer schema to use for deserialization.
            If the schema includes other schema files, they must be present in
            the same directory.
        bin_path: Path to the data to deserialize, as binary data compatible
            with the schema.
        flatc_additional_args: Optional extra arguments to pass to `flatc`
            before the standard ones.
    """
    flatc_additional_args = flatc_additional_args if flatc_additional_args else []
    _run_flatc(
        flatc_additional_args
        + [
            "--json",
            "--defaults-json",
            "--strict-json",
            "-o",
            output_dir,
            schema_path,
            "--",
            bin_path,
        ]
    )
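

# Equivalent to running (placeholders illustrative; additional args omitted):
#   flatc --json --defaults-json --strict-json \
#       -o <output_dir> <schema_path> -- <bin_path>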


def _program_json_to_flatbuffer(
    program_json: str,
    *,
    constant_tensor_alignment: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
) -> _FlatbufferResult:
    """Converts Program-compatible JSON into binary flatbuffer data.

    Args:
        program_json: The JSON to convert. Must be compatible with the root
            table type of //executorch/schema/program.fbs.
        constant_tensor_alignment: If provided, the alignment to use for tensor
            data embedded in the output flatbuffer data. If not provided, uses
            the alignment in the schema.
        delegate_alignment: If provided, the alignment to use for delegate
            data embedded in the output flatbuffer data. If not provided, uses
            the alignment in the schema.

    Returns: The flatbuffer data and associated metadata.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        schema_info = _prepare_schema(
            out_dir=temp_dir,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )
        file_stem = "data"
        json_path = os.path.join(temp_dir, file_stem + ".json")
        output_path = os.path.join(temp_dir, file_stem + ".pte")

        with open(json_path, "wb") as json_file:
            json_file.write(program_json.encode("ascii"))

        try:
            _flatc_compile(temp_dir, schema_info.root_path, json_path)
        except Exception as err:
            # It's helpful to save the breaking files for debugging. Optionally
            # move them out of the auto-deleting temporary directory. Don't do
            # this by default because some input files can be many GB in size,
            # and these copies won't be auto-deleted.
            should_save = os.getenv(_SAVE_FLATC_ENV, "").strip() not in {"", "0"}
            extra_message = ""
            if should_save:
                try:
                    saved_dir = tempfile.mkdtemp(prefix="exir-saved-flatc-")
                    for f in os.listdir(temp_dir):
                        shutil.move(src=os.path.join(temp_dir, f), dst=saved_dir)
                    extra_message += f" Moved input files to '{saved_dir}'."
                except Exception as err2:
                    extra_message += (
                        f" (Failed to save input files for debugging: {err2})"
                    )
            else:
                extra_message += (
                    f" Set {_SAVE_FLATC_ENV}=1 to save input files on failure."
                )
            raise RuntimeError(
                f"Failed to compile {json_path} to {output_path}." + extra_message
            ) from err
        with open(output_path, "rb") as output_file:
            return _FlatbufferResult(
                data=output_file.read(), max_alignment=schema_info.max_alignment
            )
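

# A minimal usage sketch (the JSON literal is hypothetical):
#   result = _program_json_to_flatbuffer('{"version": 0}')
#   result.data           # serialized flatbuffer bytes
#   result.max_alignment  # largest "force_align" in the schema used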


def _replace_infinity_in_json_file(content: bytes) -> bytes:
    """Replaces `inf` and `-inf` tokens with the strings `"inf"` and `"-inf"`
    in the JSON file. program.fbs is used to convert from flatbuffer to JSON.
    The +/-inf float values are not supported by JSON, so we replace them with
    the string equivalents. When converting from JSON to python dataclasses,
    the string is read as a Union of float and string (see schema.py).
    """
    content = re.sub(
        rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content
    )
    return content
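

# For example, the substitution maps:
#   b'"double_val": -inf'  ->  b'"double_val": "-inf"'
#   b'"double_val": inf'   ->  b'"double_val": "inf"'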


def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes:
    """Converts binary flatbuffer data into Program-compatible JSON.

    The binary is parsed using the schema in //executorch/schema/program.fbs.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # No need to patch the alignment when reading. "force_align" is only
        # used during serialization.
        schema_info = _prepare_schema(temp_dir)
        file_stem = "data"
        bin_path = os.path.join(temp_dir, file_stem + ".bin")
        json_path = os.path.join(temp_dir, file_stem + ".json")

        with open(bin_path, "wb") as bin_file:
            bin_file.write(program_flatbuffer)

        _flatc_decompile(temp_dir, schema_info.root_path, bin_path)
        with open(json_path, "rb") as output_file:
            json_data = output_file.read()
        return _replace_infinity_in_json_file(json_data)
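

# A minimal round-trip sketch (the JSON literal is hypothetical), pairing this
# with _program_json_to_flatbuffer above:
#   flat = _program_json_to_flatbuffer('{"version": 0}')
#   json_bytes = _program_flatbuffer_to_json(flat.data)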