|
23 | 23 | #include <executorch/extension/data_loader/buffer_data_loader.h>
|
24 | 24 | #include <executorch/extension/data_loader/mmap_data_loader.h>
|
25 | 25 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
|
| 26 | +#include <executorch/extension/module/bundled_module.h> |
26 | 27 | #include <executorch/extension/threadpool/threadpool.h>
|
27 | 28 | #include <executorch/runtime/backend/interface.h>
|
28 | 29 | #include <executorch/runtime/core/data_loader.h>
|
@@ -96,6 +97,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Program;
|
96 | 97 | using ::executorch::extension::BufferDataLoader;
|
97 | 98 | using ::executorch::extension::MallocMemoryAllocator;
|
98 | 99 | using ::executorch::extension::MmapDataLoader;
|
| 100 | +using ::executorch::extension::ET_BUNDLED_MODULE_NAMESPACE::BundledModule; |
99 | 101 | using ::executorch::runtime::ArrayRef;
|
100 | 102 | using ::executorch::runtime::DataLoader;
|
101 | 103 | using ::executorch::runtime::Error;
|
@@ -442,11 +444,12 @@ inline std::unique_ptr<Module> load_module_from_file(
|
442 | 444 |
|
443 | 445 | static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
|
444 | 446 |
|
445 |
| -struct PyBundledModule final { |
| 447 | +struct PyBundledModule : public BundledModule { |
446 | 448 | explicit PyBundledModule(
|
447 | 449 | const py::bytes& buffer,
|
448 | 450 | uint32_t bundled_input_pool_size)
|
449 |
| - : bundled_program_ptr_(buffer), |
| 451 | + : BundledModule(buffer.cast<std::string_view>().data()), |
| 452 | + bundled_program_ptr_(buffer), |
450 | 453 | program_ptr_(static_cast<const void*>(
|
451 | 454 | bundled_program_flatbuffer::GetBundledProgram(
|
452 | 455 | get_bundled_program_ptr())
|
@@ -842,22 +845,20 @@ struct PyModule final {
|
842 | 845 | size_t testset_idx,
|
843 | 846 | double rtol = 1e-5,
|
844 | 847 | double atol = 1e-8) {
|
845 |
| - const void* bundled_program_ptr = m.get_bundled_program_ptr(); |
846 |
| - auto& method = module_->get_method(method_name); |
847 |
| - Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input( |
848 |
| - method, bundled_program_ptr, testset_idx); |
| 848 | + auto outputs = m.execute(method_name, testset_idx); |
| 849 | + |
849 | 850 | THROW_IF_ERROR(
|
850 |
| - status, |
851 |
| - "load_bundled_input failed with status 0x%" PRIx32, |
852 |
| - static_cast<uint32_t>(status)); |
853 |
| - py::list outputs = plan_execute(method_name); |
854 |
| - status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs( |
855 |
| - method, bundled_program_ptr, testset_idx, rtol, atol); |
| 851 | + outputs.error(), |
| 852 | + "Execution failed with status 0x%" PRIx32, |
| 853 | + static_cast<uint32_t>(outputs.error())); |
| 854 | + |
| 855 | + auto status = m.verify_method_outputs(method_name, testset_idx, rtol, atol); |
856 | 856 | THROW_IF_ERROR(
|
857 | 857 | status,
|
858 | 858 | "Result verification failed with status %" PRIu32,
|
859 | 859 | static_cast<uint32_t>(status));
|
860 |
| - return outputs; |
| 860 | + |
| 861 | + return get_outputs_as_py_list(outputs.get()); |
861 | 862 | }
|
862 | 863 |
|
863 | 864 | py::list plan_execute(
|
|
0 commit comments