Skip to content

Commit 4ea3288

Browse files
[ExecuTorch][#10375] Add extension.BundledModule to Wrap extension.Module with Bundled Program Logic
Pull Request resolved: #10449 #10375 # Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. The issue is that `pybindings.PyModule` is dependent on the `method` getter from `pybindings.Module`, which `extension.Module` do not have. Since we don't want to expose `method` getter in `extension.Module`, we have to protect the getter, wrap the functions that is dependent on it and use the protected getter there, ultimately decouple `pybindings` from a `method` getter. # Proposal Now that we have a protected `method` getter, we can introduce an `extension.BundledModule`, a child class inheriting `extension.Module` which wraps up bundled program logic that is dependent on the `method` getter. We are also introducing ctests and a file-path-based factory function constructor. ghstack-source-id: 281114538 Differential Revision: [D73564125](https://our.internmc.facebook.com/intern/diff/D73564125/)
1 parent f852a27 commit 4ea3288

12 files changed

+523
-23
lines changed

devtools/bundled_program/bundled_program.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,16 @@ ET_NODISCARD Error load_bundled_input(
260260
if (!method_test.ok()) {
261261
return method_test.error();
262262
}
263-
263+
auto test_cases = method_test.get()->test_cases();
264+
ET_CHECK_OR_RETURN_ERROR(
265+
testset_idx < test_cases->size(),
266+
InvalidArgument,
267+
"testset_idx %zu is out of range [0, %u]",
268+
testset_idx,
269+
test_cases->size());
264270
auto bundled_inputs =
265-
method_test.get()->test_cases()->Get(testset_idx)->inputs();
271+
test_cases->Get(static_cast<flatbuffers::uoffset_t>(testset_idx))
272+
->inputs();
266273

267274
for (size_t input_idx = 0; input_idx < method.inputs_size(); input_idx++) {
268275
auto bundled_input = bundled_inputs->GetMutableObject(input_idx);
@@ -359,8 +366,16 @@ ET_NODISCARD Error verify_method_outputs(
359366
return method_test.error();
360367
}
361368

369+
auto test_cases = method_test.get()->test_cases();
370+
ET_CHECK_OR_RETURN_ERROR(
371+
testset_idx < test_cases->size(),
372+
InvalidArgument,
373+
"testset_idx %zu is out of range [0, %u]",
374+
testset_idx,
375+
test_cases->size());
362376
auto bundled_expected_outputs =
363-
method_test.get()->test_cases()->Get(testset_idx)->expected_outputs();
377+
test_cases->Get(static_cast<flatbuffers::uoffset_t>(testset_idx))
378+
->expected_outputs();
364379

365380
if (bundled_expected_outputs->size() == 0) {
366381
// No bundled expected outputs, so we can't verify the method outputs.

devtools/bundled_program/schema/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def define_common_targets():
7474
visibility = [
7575
"//executorch/devtools/bundled_program/...",
7676
"//executorch/extension/pybindings/...",
77+
"//executorch/extension/module/...",
7778
],
7879
exported_headers = {
7980
OUTPUT_BUNDLED_HEADER: ":{}[{}]".format(BUNDLED_GEN_RULE_NAME, OUTPUT_BUNDLED_HEADER),

extension/module/bundled_module.cpp

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/module/bundled_module.h>
10+
11+
#include <executorch/devtools/bundled_program/bundled_program.h>
12+
#include <executorch/devtools/bundled_program/schema/bundled_program_schema_generated.h>
13+
#include <executorch/extension/data_loader/buffer_data_loader.h>
14+
#include <executorch/extension/data_loader/file_data_loader.h>
15+
16+
namespace executorch {
17+
namespace extension {
18+
namespace ET_BUNDLED_MODULE_NAMESPACE {
19+
20+
namespace {
21+
std::unique_ptr<BufferDataLoader> program_data_loader(
22+
const void* bundled_program_ptr) {
23+
auto bundled_program =
24+
bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr);
25+
// the program inside the bundled program
26+
auto program = bundled_program->program();
27+
return std::make_unique<BufferDataLoader>(program->data(), program->size());
28+
}
29+
} // namespace
30+
31+
BundledModule::BundledModule(
32+
const void* bundled_program_ptr,
33+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
34+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
35+
std::unique_ptr<runtime::EventTracer> event_tracer,
36+
std::unique_ptr<runtime::DataLoader> data_map_loader)
37+
: Module(
38+
program_data_loader(bundled_program_ptr),
39+
std::move(memory_allocator),
40+
std::move(temp_allocator),
41+
std::move(event_tracer),
42+
std::move(data_map_loader)),
43+
bundled_program_ptr_(bundled_program_ptr) {}
44+
45+
runtime::Result<std::unique_ptr<BundledModule>> BundledModule::from_file(
46+
const std::string& file_path,
47+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
48+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
49+
std::unique_ptr<runtime::EventTracer> event_tracer,
50+
std::unique_ptr<runtime::DataLoader> data_map_loader) {
51+
auto data_loader_result = FileDataLoader::from(file_path.c_str());
52+
if (!data_loader_result.ok()) {
53+
return data_loader_result.error();
54+
}
55+
56+
auto file_size_result = data_loader_result->size();
57+
if (!file_size_result.ok()) {
58+
return file_size_result.error();
59+
}
60+
61+
size_t file_size = file_size_result.get();
62+
auto file_data = std::make_unique<uint8_t[]>(file_size);
63+
auto buffer_result =
64+
data_loader_result->load_into(0, file_size, {}, file_data.get());
65+
if (buffer_result != runtime::Error::Ok) {
66+
return buffer_result;
67+
}
68+
69+
// Pass ownership of the data to BundledModule
70+
auto bm = std::make_unique<BundledModule>(
71+
file_data.release(),
72+
std::move(memory_allocator),
73+
std::move(temp_allocator),
74+
std::move(event_tracer),
75+
std::move(data_map_loader));
76+
77+
bm->is_loaded_from_file_ = true;
78+
79+
return bm;
80+
}
81+
82+
runtime::Result<std::vector<runtime::EValue>> BundledModule::execute(
83+
const std::string& method_name,
84+
const size_t testset_idx) {
85+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
86+
auto& method = methods_.at(method_name).method;
87+
88+
ET_CHECK_OK_OR_RETURN_ERROR(
89+
executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
90+
*method, bundled_program_ptr_, testset_idx));
91+
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
92+
93+
const auto outputs_size = method->outputs_size();
94+
std::vector<runtime::EValue> outputs(outputs_size);
95+
ET_CHECK_OK_OR_RETURN_ERROR(
96+
method->get_outputs(outputs.data(), outputs_size));
97+
98+
return outputs;
99+
}
100+
101+
runtime::Error BundledModule::verify_method_outputs(
102+
const std::string& method_name,
103+
const size_t testset_idx,
104+
double rtol,
105+
double atol) {
106+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
107+
auto& method = methods_.at(method_name).method;
108+
return executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
109+
*method, bundled_program_ptr_, testset_idx, rtol, atol);
110+
}
111+
112+
} // namespace ET_BUNDLED_MODULE_NAMESPACE
113+
} // namespace extension
114+
} // namespace executorch

extension/module/bundled_module.h

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/module/module.h>
12+
13+
#ifdef USE_ATEN_LIB
14+
#define ET_BUNDLED_MODULE_NAMESPACE bundled_module::aten
15+
#else // !USE_ATEN_LIB
16+
#define ET_BUNDLED_MODULE_NAMESPACE bundled_module
17+
#endif // USE_ATEN_LIB
18+
19+
using executorch::extension::ET_MODULE_NAMESPACE::Module;
20+
21+
namespace executorch {
22+
namespace extension {
23+
namespace ET_BUNDLED_MODULE_NAMESPACE {
24+
25+
/**
26+
* A facade class for loading bundled programs and executing methods within
27+
* them.
28+
*/
29+
class BundledModule : public Module {
30+
public:
31+
/**
32+
* Constructs an instance with the bundled program buffer pointer.
33+
*
34+
* This constructor reads the program from bundled program buffer to load the
35+
* module with data loader. The bundled program pointer is preserved so that
36+
* the portion outside of program is accessible.
37+
*
38+
* @param[in] bundled_program_ptr A DataLoader used for loading program data.
39+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
40+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
41+
* temporary data during kernel or delegate execution.
42+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
43+
* @param[in] data_map_loader A DataLoader used for loading external weights.
44+
*/
45+
explicit BundledModule(
46+
const void* bundled_program_ptr,
47+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
48+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
49+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
50+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
51+
52+
// Disallow copying
53+
BundledModule(const BundledModule&) = delete;
54+
BundledModule& operator=(const BundledModule&) = delete;
55+
// Disallow copying
56+
BundledModule(BundledModule&&) = delete;
57+
BundledModule& operator=(BundledModule&&) = delete;
58+
// Default destructor
59+
~BundledModule() {
60+
if (is_loaded_from_file_) {
61+
delete[] static_cast<const uint8_t*>(bundled_program_ptr_);
62+
}
63+
}
64+
65+
/**
66+
* Constructs an instance by loading a bundled program from a file with
67+
* specified memory locking behavior.
68+
*
69+
* @param[in] file_path The path to the ExecuTorch bundled program file to
70+
* load.
71+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
72+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
73+
* temporary data during kernel or delegate execution.
74+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
75+
* @param[in] data_map_loader A DataLoader used for loading external weights.
76+
*/
77+
ET_NODISCARD static runtime::Result<std::unique_ptr<BundledModule>> from_file(
78+
const std::string& file_path,
79+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
80+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
81+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
82+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
83+
84+
using Module::execute;
85+
86+
/**
87+
* Execute a specific method with the input value at the given `testset_idx`
88+
* from the bundle to the method. Loads the program and method before
89+
* executing if needed.
90+
*
91+
* This function is a wrapper of `load_bundled_input` in `bundled_program`.
92+
*
93+
* @param[in] method_name The name of the method to execute.
94+
* @param[in] testset_idx The index of the input value to be passed to the
95+
* method.
96+
*
97+
* @returns Return Error::Ok on a successful load, or the error happens during
98+
* execution.
99+
*/
100+
ET_NODISCARD
101+
runtime::Result<std::vector<runtime::EValue>> execute(
102+
const std::string& method_name,
103+
const size_t testset_idx);
104+
105+
/**
106+
* Verify the output of a specific method with the expected output from the
107+
* program bundle at the given `testset_idx`.
108+
*
109+
* This function is a wrapper of `verify_method_outputs` in `bundled_program`.
110+
*
111+
* @param[in] method_name The name of the method to extract outputs from.
112+
* @param[in] testset_idx The index of expected output needs to be compared.
113+
* @param[in] rtol Relative tolerance used for data comparsion.
114+
* @param[in] atol Absolute tolerance used for data comparsion.
115+
*
116+
* @returns Return Error::Ok if two outputs match, or the error happens during
117+
* execution.
118+
*/
119+
ET_NODISCARD
120+
runtime::Error verify_method_outputs(
121+
const std::string& method_name,
122+
const size_t testset_idx,
123+
double rtol = 1e-5,
124+
double atol = 1e-8);
125+
126+
private:
127+
const void* bundled_program_ptr_;
128+
bool is_loaded_from_file_ = false;
129+
};
130+
131+
} // namespace ET_BUNDLED_MODULE_NAMESPACE
132+
} // namespace extension
133+
} // namespace executorch

extension/module/module.cpp

-10
Original file line numberDiff line numberDiff line change
@@ -303,16 +303,6 @@ runtime::Error Module::set_output(
303303
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
304304
}
305305

306-
ET_NODISCARD inline runtime::Result<Method*> Module::get_method(
307-
const std::string& method_name) {
308-
ET_CHECK_OR_RETURN_ERROR(
309-
methods_.count(method_name) > 0,
310-
InvalidArgument,
311-
"no such method in program: %s",
312-
method_name.c_str());
313-
return methods_[method_name].method.get();
314-
}
315-
316306
} // namespace ET_MODULE_NAMESPACE
317307
} // namespace extension
318308
} // namespace executorch

extension/module/module.h

-10
Original file line numberDiff line numberDiff line change
@@ -498,16 +498,6 @@ class Module {
498498
std::unique_ptr<NamedDataMap> data_map_;
499499

500500
protected:
501-
/**
502-
* Get a method by method name.
503-
*
504-
* @param[in] method_name The name of the method to get.
505-
*
506-
* @returns A Result object containing either a pointer to the requested
507-
* method or an error to indicate failure.
508-
*/
509-
ET_NODISCARD inline runtime::Result<Method*> get_method(
510-
const std::string& method_name);
511501
std::unordered_map<std::string, MethodHolder> methods_;
512502

513503
friend class ExecuTorchJni;

extension/module/targets.bzl

+22
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,25 @@ def define_common_targets():
3131
"//executorch/runtime/executor:program" + aten_suffix,
3232
],
3333
)
34+
35+
runtime.cxx_library(
36+
name = "bundled_module" + aten_suffix,
37+
srcs = [
38+
"bundled_module.cpp",
39+
],
40+
exported_headers = [
41+
"bundled_module.h",
42+
],
43+
visibility = [
44+
"@EXECUTORCH_CLIENTS",
45+
],
46+
deps = [
47+
"//executorch/extension/data_loader:buffer_data_loader",
48+
"//executorch/extension/data_loader:file_data_loader",
49+
"//executorch/devtools/bundled_program:runtime" + aten_suffix,
50+
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
51+
],
52+
exported_deps = [
53+
"//executorch/extension/module:module" + aten_suffix,
54+
],
55+
)

0 commit comments

Comments
 (0)