Skip to content

Commit 82631fa

Browse files
authored
fix: Repair flaky TopK core test (#2022)
1 parent 780e398 commit 82631fa

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tests/core/conversion/converters/test_topk.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,18 @@ TEST(Converters, ATenTopKConvertsCorrectly) {
1717
auto g = std::make_shared<torch::jit::Graph>();
1818
torch::jit::parseIR(graph, g.get());
1919

20-
auto in = at::rand({10, 10, 100}, {at::kCUDA});
20+
auto dim0 = 10, dim1 = 10, dim2 = 100;
21+
22+
// Initialize zero tensor to be filled with random indices along the final dimension
23+
auto in = at::zeros({dim0, dim1, dim2}, {at::kCUDA});
24+
25+
// For each final dimension, fill it with random scramble of unique integers in the range [0, dim0*dim1*dim2)
26+
for (auto i = 0; i < dim0; i++) {
27+
for (auto j = 0; j < dim1; j++) {
28+
auto random_index_permutation = at::randperm(dim0 * dim1 * dim2, c10::kInt, {}, at::kCUDA, {}).slice(0, 0, dim2);
29+
in.slice(0, i, i + 1).slice(1, j, j + 1) = random_index_permutation;
30+
}
31+
}
2132

2233
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
2334
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

0 commit comments

Comments
 (0)