Skip to content

Commit f829dfa

Browse files
[ExecuTorch][#10375] Add extension.BundledModule to Wrap extension.Module with Bundled Program Logic
Pull Request resolved: #10449 #10375 # Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. The issue is that `pybindings.PyModule` is dependent on the `method` getter from `pybindings.Module`, which `extension.Module` do not have. Since we don't want to expose `method` getter in `extension.Module`, we have to protect the getter, wrap the functions that is dependent on it and use the protected getter there, ultimately decouple `pybindings` from a `method` getter. # Proposal Now that we have a protected `method` getter, we can introduce a `extension.BundledModule`, a child class inheriting `extension.Module` which wraps up bundled program logic that is dependent on the `method` getter. ghstack-source-id: 280825735 Differential Revision: [D73564125](https://our.internmc.facebook.com/intern/diff/D73564125/)
1 parent df75088 commit f829dfa

File tree

11 files changed

+398
-20
lines changed

11 files changed

+398
-20
lines changed

devtools/bundled_program/schema/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def define_common_targets():
7474
visibility = [
7575
"//executorch/devtools/bundled_program/...",
7676
"//executorch/extension/pybindings/...",
77+
"//executorch/extension/module/...",
7778
],
7879
exported_headers = {
7980
OUTPUT_BUNDLED_HEADER: ":{}[{}]".format(BUNDLED_GEN_RULE_NAME, OUTPUT_BUNDLED_HEADER),

extension/module/bundled_module.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/module/bundled_module.h>
10+
11+
#include <executorch/devtools/bundled_program/bundled_program.h>
12+
#include <executorch/devtools/bundled_program/schema/bundled_program_schema_generated.h>
13+
#include <executorch/extension/data_loader/buffer_data_loader.h>
14+
#include <executorch/extension/data_loader/file_data_loader.h>
15+
16+
/**
17+
* Unwrap a Result to obtain its value (direct object, not a pointer).
18+
* If the Result contains an error, propagate the error via trivial function
19+
* return. The macro wraps the object into a unique_ptr.
20+
*
21+
* Note: A function using ET_UNWRAP_UNIQUE should itself return a Result or
22+
* Error.
23+
*
24+
* @param[in] result__ Expression yielding the result to unwrap.
25+
*/
26+
#define ET_UNWRAP_UNIQUE(result__) \
27+
({ \
28+
auto et_result__ = (result__); \
29+
if (!et_result__.ok()) { \
30+
return et_result__.error(); \
31+
} \
32+
std::make_unique<std::remove_reference_t<decltype(*et_result__)>>( \
33+
std::move(*et_result__)); \
34+
})
35+
36+
namespace executorch {
37+
namespace extension {
38+
39+
namespace {
40+
std::unique_ptr<BufferDataLoader> program_data_loader(
41+
const void* bundled_program_ptr) {
42+
auto bundled_program =
43+
bundled_program_flatbuffer::GetBundledProgram(bundled_program_ptr);
44+
// the program inside the bundled program
45+
auto program = bundled_program->program();
46+
return std::make_unique<BufferDataLoader>(program->data(), program->size());
47+
}
48+
} // namespace
49+
50+
BundledModule::BundledModule(
51+
const void* bundled_program_ptr,
52+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
53+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
54+
std::unique_ptr<runtime::EventTracer> event_tracer,
55+
std::unique_ptr<runtime::DataLoader> data_map_loader)
56+
: Module(
57+
program_data_loader(bundled_program_ptr),
58+
std::move(memory_allocator),
59+
std::move(temp_allocator),
60+
std::move(event_tracer),
61+
std::move(data_map_loader)),
62+
bundled_program_ptr_(bundled_program_ptr) {}
63+
64+
runtime::Result<std::vector<runtime::EValue>> BundledModule::execute(
65+
const std::string& method_name,
66+
const size_t testset_idx) {
67+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
68+
auto& method = methods_.at(method_name).method;
69+
auto& inputs = methods_.at(method_name).inputs;
70+
71+
ET_CHECK_OK_OR_RETURN_ERROR(
72+
executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
73+
*method, bundled_program_ptr_, testset_idx));
74+
ET_CHECK_OK_OR_RETURN_ERROR(method->get_inputs(inputs.data(), inputs.size()));
75+
ET_CHECK_OK_OR_RETURN_ERROR(method->execute());
76+
77+
const auto outputs_size = method->outputs_size();
78+
std::vector<runtime::EValue> outputs(outputs_size);
79+
ET_CHECK_OK_OR_RETURN_ERROR(
80+
method->get_outputs(outputs.data(), outputs_size));
81+
82+
return outputs;
83+
}
84+
85+
runtime::Error BundledModule::verify_method_outputs(
86+
const std::string& method_name,
87+
const size_t testset_idx,
88+
double rtol,
89+
double atol) {
90+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
91+
auto& method = methods_.at(method_name).method;
92+
return executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
93+
*method, bundled_program_ptr_, testset_idx, rtol, atol);
94+
}
95+
96+
} // namespace extension
97+
} // namespace executorch

extension/module/bundled_module.h

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/module/module.h>
12+
13+
namespace executorch {
14+
namespace extension {
15+
16+
/**
17+
* A facade class for loading bundled programs and executing methods within
18+
* them.
19+
*/
20+
class BundledModule : public Module {
21+
public:
22+
/**
23+
* Constructs an instance with the bundled program buffer pointer.
24+
*
25+
* This constructor reads the program from bundled program buffer to load the
26+
* module with data loader. The bundled program pointer is preserved so that
27+
* the portion outside of program is accessible.
28+
*
29+
* @param[in] bundled_program_ptr A DataLoader used for loading program data.
30+
* @param[in] memory_allocator A MemoryAllocator used for memory management.
31+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
32+
* temporary data during kernel or delegate execution.
33+
* @param[in] event_tracer A EventTracer used for tracking and logging events.
34+
* @param[in] data_map_loader A DataLoader used for loading external weights.
35+
*/
36+
explicit BundledModule(
37+
const void* bundled_program_ptr,
38+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
39+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
40+
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr,
41+
std::unique_ptr<runtime::DataLoader> data_map_loader = nullptr);
42+
43+
// Disallow copying
44+
BundledModule(const BundledModule&) = delete;
45+
BundledModule& operator=(const BundledModule&) = delete;
46+
// Disallow copying
47+
BundledModule(BundledModule&&) = delete;
48+
BundledModule& operator=(BundledModule&&) = delete;
49+
// Default destructor
50+
~BundledModule() = default;
51+
52+
/**
53+
* Execute a specific method with the input value at the given `testset_idx`
54+
* from the bundle to the method. Loads the program and method before
55+
* executing if needed.
56+
*
57+
* This function is a wrapper of `load_bundled_input` in `bundled_program`.
58+
*
59+
* @param[in] method_name The name of the method to execute.
60+
* @param[in] testset_idx The index of the input value to be passed to the
61+
* method.
62+
*
63+
* @returns Return Error::Ok on a successful load, or the error happens during
64+
* execution.
65+
*/
66+
ET_NODISCARD
67+
runtime::Result<std::vector<runtime::EValue>> execute(
68+
const std::string& method_name,
69+
const size_t testset_idx);
70+
71+
/**
72+
* Verify the output of a specific method with the expected output from the
73+
* program bundle at the given `testset_idx`.
74+
*
75+
* This function is a wrapper of `verify_method_outputs` in `bundled_program`.
76+
*
77+
* @param[in] method_name The name of the method to extract outputs from.
78+
* @param[in] testset_idx The index of expected output needs to be compared.
79+
* @param[in] rtol Relative tolerance used for data comparsion.
80+
* @param[in] atol Absolute tolerance used for data comparsion.
81+
*
82+
* @returns Return Error::Ok if two outputs match, or the error happens during
83+
* execution.
84+
*/
85+
ET_NODISCARD
86+
runtime::Error verify_method_outputs(
87+
const std::string& method_name,
88+
const size_t testset_idx,
89+
double rtol = 1e-5,
90+
double atol = 1e-8);
91+
92+
private:
93+
const void* bundled_program_ptr_;
94+
};
95+
96+
} // namespace extension
97+
} // namespace executorch

extension/module/module.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,15 +302,5 @@ runtime::Error Module::set_output(
302302
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
303303
}
304304

305-
ET_NODISCARD inline runtime::Result<Method*> Module::get_method(
306-
const std::string& method_name) {
307-
ET_CHECK_OR_RETURN_ERROR(
308-
methods_.count(method_name) > 0,
309-
InvalidArgument,
310-
"no such method in program: %s",
311-
method_name.c_str());
312-
return methods_[method_name].method.get();
313-
}
314-
315305
} // namespace extension
316306
} // namespace executorch

extension/module/module.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -493,16 +493,6 @@ class Module {
493493
std::unique_ptr<NamedDataMap> data_map_;
494494

495495
protected:
496-
/**
497-
* Get a method by method name.
498-
*
499-
* @param[in] method_name The name of the method to get.
500-
*
501-
* @returns A Result object containing either a pointer to the requested
502-
* method or an error to indicate failure.
503-
*/
504-
ET_NODISCARD inline runtime::Result<Method*> get_method(
505-
const std::string& method_name);
506496
std::unordered_map<std::string, MethodHolder> methods_;
507497

508498
friend class ExecuTorchJni;

extension/module/targets.bzl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,25 @@ def define_common_targets():
3131
"//executorch/runtime/executor:program" + aten_suffix,
3232
],
3333
)
34+
35+
runtime.cxx_library(
36+
name = "bundled_module" + aten_suffix,
37+
srcs = [
38+
"bundled_module.cpp",
39+
],
40+
exported_headers = [
41+
"bundled_module.h",
42+
],
43+
visibility = [
44+
"@EXECUTORCH_CLIENTS",
45+
],
46+
deps = [
47+
"//executorch/extension/data_loader:buffer_data_loader",
48+
"//executorch/extension/data_loader:file_data_loader",
49+
"//executorch/devtools/bundled_program:runtime" + aten_suffix,
50+
"//executorch/devtools/bundled_program/schema:bundled_program_schema_fbs",
51+
],
52+
exported_deps = [
53+
"//executorch/extension/module:module" + aten_suffix,
54+
],
55+
)
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/module/bundled_module.h>
10+
11+
#include <gtest/gtest.h>
12+
13+
#include <executorch/extension/data_loader/file_data_loader.h>
14+
15+
using namespace ::executorch::extension;
16+
using namespace ::executorch::runtime;
17+
18+
class BundledModuleTest : public ::testing::Test {
19+
protected:
20+
static void SetUpTestSuite() {
21+
std::string resources_path;
22+
if (const char* env = std::getenv("RESOURCES_PATH")) {
23+
resources_path = env;
24+
}
25+
bpte_path_ = resources_path + "/bundled_program.bpte";
26+
}
27+
28+
static inline std::string bpte_path_;
29+
};
30+
31+
#include <fstream>
32+
33+
std::vector<uint8_t> load_file_or_die(const char* path) {
34+
std::ifstream file(path, std::ios::binary | std::ios::ate);
35+
const size_t nbytes = file.tellg();
36+
file.seekg(0, std::ios::beg);
37+
auto file_data = std::vector<uint8_t>(nbytes);
38+
ET_CHECK_MSG(
39+
file.read(reinterpret_cast<char*>(file_data.data()), nbytes),
40+
"Could not load contents of file '%s'",
41+
path);
42+
return file_data;
43+
}
44+
45+
TEST_F(BundledModuleTest, TestExecute) {
46+
std::vector<uint8_t> file_data = load_file_or_die(bpte_path_.c_str());
47+
BundledModule bundled_module(reinterpret_cast<void*>(file_data.data()));
48+
49+
auto outputs = bundled_module.execute("forward", 0);
50+
EXPECT_EQ(outputs.error(), Error::Ok);
51+
auto status = bundled_module.verify_method_outputs(
52+
"forward", 0, 1e-3, 1e-5);
53+
EXPECT_EQ(status, Error::Ok);
54+
}

extension/module/test/resources/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,9 @@
2121
```
2222
python -m examples.portable.scripts.export --model_name="linear" -e
2323
```
24+
25+
### bundled_program.bpte
26+
27+
```
28+
python3 extension/module/test/resources/gen_bundled_program.py
29+
```
Binary file not shown.

0 commit comments

Comments
 (0)