@@ -33,9 +33,10 @@ struct Tensor {
33
33
unsigned int data_index (unsigned int const indices[N]) const {
34
34
unsigned int index = 0 ;
35
35
for (unsigned int i = 0 ; i < N; ++i) {
36
- ASSERT (indices[i] < shape[i]);
36
+ ASSERT (indices[i] < shape[i], " Invalid index " );
37
37
// TODO: 计算 index
38
38
}
39
+ return index ;
39
40
}
40
41
};
41
42
@@ -48,10 +49,12 @@ int main(int argc, char **argv) {
48
49
unsigned int i0[]{0 , 0 , 0 , 0 };
49
50
tensor[i0] = 1 ;
50
51
ASSERT (tensor[i0] == 1 , " tensor[i0] should be 1" );
52
+ ASSERT (tensor.data [0 ] == 1 , " tensor[i0] should be 1" );
51
53
52
54
unsigned int i1[]{1 , 2 , 3 , 4 };
53
- tensor[i0] = 2 ;
54
- ASSERT (tensor[i0] == 2 , " tensor[i1] should be 2" );
55
+ tensor[i1] = 2 ;
56
+ ASSERT (tensor[i1] == 2 , " tensor[i1] should be 2" );
57
+ ASSERT (tensor.data [119 ] == 2 , " tensor[i1] should be 2" );
55
58
}
56
59
{
57
60
unsigned int shape[]{7 , 8 , 128 };
@@ -60,10 +63,12 @@ int main(int argc, char **argv) {
60
63
unsigned int i0[]{0 , 0 , 0 };
61
64
tensor[i0] = 1 .f ;
62
65
ASSERT (tensor[i0] == 1 .f , " tensor[i0] should be 1" );
66
+ ASSERT (tensor.data [0 ] == 1 .f , " tensor[i0] should be 1" );
63
67
64
68
unsigned int i1[]{3 , 4 , 99 };
65
- tensor[i0] = 2 .f ;
66
- ASSERT (tensor[i0] == 2 .f , " tensor[i1] should be 2" );
69
+ tensor[i1] = 2 .f ;
70
+ ASSERT (tensor[i1] == 2 .f , " tensor[i1] should be 2" );
71
+ ASSERT (tensor.data [3683 ] == 2 .f , " tensor[i1] should be 2" );
67
72
}
68
73
return 0 ;
69
74
}
0 commit comments