|
23 | 23 | #include <executorch/extension/data_loader/buffer_data_loader.h>
|
24 | 24 | #include <executorch/extension/data_loader/mmap_data_loader.h>
|
25 | 25 | #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
|
| 26 | +#include <executorch/extension/module/bundled_module.h> |
26 | 27 | #include <executorch/extension/threadpool/threadpool.h>
|
27 | 28 | #include <executorch/runtime/backend/interface.h>
|
28 | 29 | #include <executorch/runtime/core/data_loader.h>
|
@@ -442,11 +443,12 @@ inline std::unique_ptr<Module> load_module_from_file(
|
442 | 443 |
|
443 | 444 | static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
|
444 | 445 |
|
445 |
| -struct PyBundledModule final { |
| 446 | +struct PyBundledModule : public BundledModule { |
446 | 447 | explicit PyBundledModule(
|
447 | 448 | const py::bytes& buffer,
|
448 | 449 | uint32_t bundled_input_pool_size)
|
449 |
| - : bundled_program_ptr_(buffer), |
| 450 | + : BundledModule(buffer.cast<std::string_view>().data()), |
| 451 | + bundled_program_ptr_(buffer), |
450 | 452 | program_ptr_(static_cast<const void*>(
|
451 | 453 | bundled_program_flatbuffer::GetBundledProgram(
|
452 | 454 | get_bundled_program_ptr())
|
@@ -842,22 +844,20 @@ struct PyModule final {
|
842 | 844 | size_t testset_idx,
|
843 | 845 | double rtol = 1e-5,
|
844 | 846 | double atol = 1e-8) {
|
845 |
| - const void* bundled_program_ptr = m.get_bundled_program_ptr(); |
846 |
| - auto& method = module_->get_method(method_name); |
847 |
| - Error status = executorch::BUNDLED_PROGRAM_NAMESPACE::load_bundled_input( |
848 |
| - method, bundled_program_ptr, testset_idx); |
| 847 | + auto outputs = m.execute(method_name, testset_idx); |
| 848 | + |
849 | 849 | THROW_IF_ERROR(
|
850 |
| - status, |
851 |
| - "load_bundled_input failed with status 0x%" PRIx32, |
852 |
| - static_cast<uint32_t>(status)); |
853 |
| - py::list outputs = plan_execute(method_name); |
854 |
| - status = executorch::BUNDLED_PROGRAM_NAMESPACE::verify_method_outputs( |
855 |
| - method, bundled_program_ptr, testset_idx, rtol, atol); |
| 850 | + outputs.error(), |
| 851 | + "Execution failed with status 0x%" PRIx32, |
| 852 | + static_cast<uint32_t>(outputs.error())); |
| 853 | + |
| 854 | + auto status = m.verify_method_outputs(method_name, testset_idx, rtol, atol); |
856 | 855 | THROW_IF_ERROR(
|
857 | 856 | status,
|
858 | 857 | "Result verification failed with status %" PRIu32,
|
859 | 858 | static_cast<uint32_t>(status));
|
860 |
| - return outputs; |
| 859 | + |
| 860 | + return get_outputs_as_py_list(outputs.get()); |
861 | 861 | }
|
862 | 862 |
|
863 | 863 | py::list plan_execute(
|
|
0 commit comments