|
23 | 23 | #include <executorch/extension/data_loader/buffer_data_loader.h>
|
24 | 24 | #include <executorch/extension/data_loader/mmap_data_loader.h>
|
25 | 25 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
|
| 26 | +#include <executorch/extension/module/bundled_module.h> |
26 | 27 | #include <executorch/extension/threadpool/threadpool.h>
|
27 | 28 | #include <executorch/runtime/backend/interface.h>
|
28 | 29 | #include <executorch/runtime/core/data_loader.h>
|
@@ -94,6 +95,7 @@ using ::executorch::ET_RUNTIME_NAMESPACE::Kernel;
|
94 | 95 | using ::executorch::ET_RUNTIME_NAMESPACE::Method;
|
95 | 96 | using ::executorch::ET_RUNTIME_NAMESPACE::Program;
|
96 | 97 | using ::executorch::extension::BufferDataLoader;
|
| 98 | +using ::executorch::extension::BundledModule; |
97 | 99 | using ::executorch::extension::MallocMemoryAllocator;
|
98 | 100 | using ::executorch::extension::MmapDataLoader;
|
99 | 101 | using ::executorch::runtime::ArrayRef;
|
@@ -442,11 +444,12 @@ inline std::unique_ptr<Module> load_module_from_file(
|
442 | 444 |
|
443 | 445 | static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
|
444 | 446 |
|
445 |
| -struct PyBundledModule final { |
| 447 | +struct PyBundledModule : public BundledModule { |
446 | 448 | explicit PyBundledModule(
|
447 | 449 | const py::bytes& buffer,
|
448 | 450 | uint32_t bundled_input_pool_size)
|
449 |
| - : bundled_program_ptr_(buffer), |
| 451 | + : BundledModule(buffer.cast<std::string_view>().data()), |
| 452 | + bundled_program_ptr_(buffer), |
450 | 453 | program_ptr_(static_cast<const void*>(
|
451 | 454 | bundled_program_flatbuffer::GetBundledProgram(
|
452 | 455 | get_bundled_program_ptr())
|
@@ -840,22 +843,20 @@ struct PyModule final {
|
840 | 843 | size_t testset_idx,
|
841 | 844 | double rtol = 1e-5,
|
842 | 845 | double atol = 1e-8) {
|
843 |
| - const void* bundled_program_ptr = m.get_bundled_program_ptr(); |
844 |
| - auto& method = module_->get_method(method_name); |
845 |
| - Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input( |
846 |
| - method, bundled_program_ptr, testset_idx); |
| 846 | + auto outputs = m.execute(method_name, testset_idx); |
| 847 | + |
847 | 848 | THROW_IF_ERROR(
|
848 |
| - status, |
849 |
| - "load_bundled_input failed with status 0x%" PRIx32, |
850 |
| - static_cast<uint32_t>(status)); |
851 |
| - py::list outputs = plan_execute(method_name); |
852 |
| - status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs( |
853 |
| - method, bundled_program_ptr, testset_idx, rtol, atol); |
| 849 | + outputs.error(), |
| 850 | + "Execution failed with status 0x%" PRIx32, |
| 851 | + static_cast<uint32_t>(outputs.error())); |
| 852 | + |
| 853 | + auto status = m.verify_method_outputs(method_name, testset_idx, rtol, atol); |
854 | 854 | THROW_IF_ERROR(
|
855 | 855 | status,
|
856 | 856 | "Result verification failed with status %" PRIu32,
|
857 | 857 | static_cast<uint32_t>(status));
|
858 |
| - return outputs; |
| 858 | + |
| 859 | + return get_outputs_as_py_list(outputs.get()); |
859 | 860 | }
|
860 | 861 |
|
861 | 862 | py::list plan_execute(
|
|
0 commit comments