Skip to content

Commit 5381184

Browse files
[ExecuTorch][#10447] Extend PyBundledModule with extension.BundledModule
Pull Request resolved: #10450 #10447 # Context This issue is a step of #9638. In #9638, we want to have `extension.Module` as the single source of implementation in `pybindings`, which means that `pybindings.PyModule` should use `extension.Module` rather than its own `pybindings.Module`. # Proposal Now that we have `extension.BundledModule` ready, we want to test it out by having our existing `PyBundledModule` to extend it, and let `verify_result_with_bundled_expected_output` to use it, so that we can test out the whole thing with https://github.com/pytorch/executorch/blob/fb45e19055a92d2a91a4d4b7008e135232cbb14b/devtools/bundled_program/test/test_end2end.py ghstack-source-id: 281084621 Differential Revision: [D73564127](https://our.internmc.facebook.com/intern/diff/D73564127/)
1 parent 89e9d19 commit 5381184

File tree

3 files changed

+16
-29
lines changed

3 files changed

+16
-29
lines changed

devtools/bundled_program/test/test_end2end.py

-16
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# flake8: noqa: F401
8-
import functools
9-
import inspect
10-
import os
11-
import random
128
import unittest
13-
from typing import Callable, Dict, Optional, Tuple, Type
14-
15-
import executorch.exir as exir
16-
17-
import executorch.exir.control_flow as control_flow
18-
19-
# @manual=//executorch/extension/pytree:pybindings
20-
import executorch.extension.pytree as pytree
21-
22-
import torch
239

2410
from executorch.devtools.bundled_program.core import BundledProgram
2511
from executorch.devtools.bundled_program.serialize import (
@@ -35,7 +21,6 @@
3521
try:
3622
from executorch.extension.pybindings.portable_lib import (
3723
_load_bundled_program_from_buffer,
38-
_load_for_executorch_from_buffer,
3924
_load_for_executorch_from_bundled_program,
4025
)
4126

@@ -47,7 +32,6 @@
4732
try:
4833
from executorch.extension.pybindings.aten_lib import ( # @manual=//executorch/extension/pybindings:aten_lib
4934
_load_bundled_program_from_buffer,
50-
_load_for_executorch_from_buffer,
5135
_load_for_executorch_from_bundled_program,
5236
)
5337

extension/pybindings/pybindings.cpp

+14-13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <executorch/extension/data_loader/buffer_data_loader.h>
2424
#include <executorch/extension/data_loader/mmap_data_loader.h>
2525
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
26+
#include <executorch/extension/module/bundled_module.h>
2627
#include <executorch/extension/threadpool/threadpool.h>
2728
#include <executorch/runtime/backend/interface.h>
2829
#include <executorch/runtime/core/data_loader.h>
@@ -96,6 +97,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
9697
using ::executorch::extension::BufferDataLoader;
9798
using ::executorch::extension::MallocMemoryAllocator;
9899
using ::executorch::extension::MmapDataLoader;
100+
using ::executorch::extension::BundledModule;
99101
using ::executorch::runtime::ArrayRef;
100102
using ::executorch::runtime::DataLoader;
101103
using ::executorch::runtime::Error;
@@ -442,11 +444,12 @@ inline std::unique_ptr<Module> load_module_from_file(
442444

443445
static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
444446

445-
struct PyBundledModule final {
447+
struct PyBundledModule : public BundledModule {
446448
explicit PyBundledModule(
447449
const py::bytes& buffer,
448450
uint32_t bundled_input_pool_size)
449-
: bundled_program_ptr_(buffer),
451+
: BundledModule(buffer.cast<std::string_view>().data()),
452+
bundled_program_ptr_(buffer),
450453
program_ptr_(static_cast<const void*>(
451454
bundled_program_flatbuffer::GetBundledProgram(
452455
get_bundled_program_ptr())
@@ -840,22 +843,20 @@ struct PyModule final {
840843
size_t testset_idx,
841844
double rtol = 1e-5,
842845
double atol = 1e-8) {
843-
const void* bundled_program_ptr = m.get_bundled_program_ptr();
844-
auto& method = module_->get_method(method_name);
845-
Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input(
846-
method, bundled_program_ptr, testset_idx);
846+
auto outputs = m.execute(method_name, testset_idx);
847+
847848
THROW_IF_ERROR(
848-
status,
849-
"load_bundled_input failed with status 0x%" PRIx32,
850-
static_cast<uint32_t>(status));
851-
py::list outputs = plan_execute(method_name);
852-
status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs(
853-
method, bundled_program_ptr, testset_idx, rtol, atol);
849+
outputs.error(),
850+
"Execution failed with status 0x%" PRIx32,
851+
static_cast<uint32_t>(outputs.error()));
852+
853+
auto status = m.verify_method_outputs(method_name, testset_idx, rtol, atol);
854854
THROW_IF_ERROR(
855855
status,
856856
"Result verification failed with status %" PRIu32,
857857
static_cast<uint32_t>(status));
858-
return outputs;
858+
859+
return get_outputs_as_py_list(outputs.get());
859860
}
860861

861862
py::list plan_execute(

shim_et/xplat/executorch/extension/pybindings/pybindings.bzl

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ PORTABLE_MODULE_DEPS = [
1616
"//executorch/extension/data_loader:buffer_data_loader",
1717
"//executorch/extension/data_loader:mmap_data_loader",
1818
"//executorch/extension/memory_allocator:malloc_memory_allocator",
19+
"//executorch/extension/module:bundled_module",
1920
"//executorch/runtime/executor/test:test_backend_compiler_lib",
2021
"//executorch/devtools/etdump:etdump_flatcc",
2122
] + get_all_cpu_backend_targets()
@@ -28,6 +29,7 @@ ATEN_MODULE_DEPS = [
2829
"//executorch/extension/data_loader:buffer_data_loader",
2930
"//executorch/extension/data_loader:mmap_data_loader",
3031
"//executorch/extension/memory_allocator:malloc_memory_allocator",
32+
"//executorch/extension/module:bundled_module_aten",
3133
"//executorch/devtools/bundled_program:runtime_aten",
3234
"//executorch/runtime/executor/test:test_backend_compiler_lib_aten",
3335
"//executorch/devtools/etdump:etdump_flatcc",

0 commit comments

Comments
 (0)