#include <algorithm>

#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// clang-format off
auto chunk_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
    .pattern({"aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in = args[0].ITensorOrFreeze(ctx);
                auto chunks = args[1].unwrapToInt();
                auto dim = args[2].unwrapToInt();
                bool dynamic_shape = ctx->input_is_dynamic;
                // Normalize a negative dim before it is used to index the input dimensions
                if (dim < 0) {
                  dim = in->getDimensions().nbDims + dim;
                }
                int maxDim = static_cast<int32_t>(in->getDimensions().d[dim]);

                c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
                c10::TypePtr elementType = lt->getElementType();

                int offset = 0;
                if (dynamic_shape) {
                  TORCHTRT_ASSERT(in->getDimensions().d[dim] != -1, "Can't chunk on dynamic shape dimension!");
                }
                if (chunks > maxDim) {
                  LOG_WARNING(
                      "Requested " << chunks << " chunks along dimension " << dim << ", which only has size " << maxDim
                                   << "; the number of chunks will default to " << maxDim);
                  // Clamp so no zero- or negative-size slices are emitted (matches aten::chunk semantics)
                  chunks = maxDim;
                }
                // Each chunk spans ceil(maxDim / chunks) elements along dim; the last chunk may be smaller
                int step = (maxDim + chunks - 1) / chunks;
                nvinfer1::Dims start_, size_, stride_;
                int nbdims = in->getDimensions().nbDims;
                start_.nbDims = nbdims;
                size_.nbDims = nbdims;
                stride_.nbDims = nbdims;

                // Start from a slice covering the full extent of every dimension;
                // only the chunked dimension is adjusted per chunk below
                for (int i = 0; i < nbdims; i++) {
                  start_.d[i] = 0;
                  size_.d[i] = in->getDimensions().d[i];
                  stride_.d[i] = 1;
                }
                // Emit one slice layer per chunk and collect the outputs into a TorchScript list
                auto list = c10::impl::GenericList(elementType);
                list.reserve(chunks);
                if (!dynamic_shape) {
                  for (int chunk = 0; chunk < chunks; chunk++) {
                    start_.d[dim] = offset;
                    size_.d[dim] = std::min(step, maxDim - offset);
                    LOG_DEBUG("start_: " << start_);
                    LOG_DEBUG("size_: " << size_);
                    LOG_DEBUG("stride_: " << stride_);
                    auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
                    TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
                    auto tensor_holder = TensorContainer();
                    tensor_holder.hold_tensor(slice_layer->getOutput(0));
                    auto ival = c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
                    list.emplace_back(ival);
                    offset = offset + step;
                  }
                }
                auto split_output_ivalue = torch::jit::IValue(list);
                ctx->AssociateValueAndIValue(n->outputs()[0], split_output_ivalue);
                return true;
              }});
// clang-format on
} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt