@@ -11,9 +11,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

- #include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
+ #include <torch/extension.h>

#include <vector>
@@ -49,59 +49,53 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
-  // batch index
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> old_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_h,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell) {
+  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < gates.size(2)){
+  if (c < gates.size(2)) {
    input_gate[n][c] = sigmoid(gates[n][0][c]);
    output_gate[n][c] = sigmoid(gates[n][1][c]);
    candidate_cell[n][c] = elu(gates[n][2][c]);
-    new_cell[n][c] =
-        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_cell[n][c] = old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
  }
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
-    torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
-    torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
-    const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
-    const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
-  // batch index
+    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> d_old_cell,
+    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> d_gates,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_h,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> grad_cell,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> new_cell,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> input_gate,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> output_gate,
+    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> candidate_cell,
+    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> gate_weights) {
+  // batch index
  const int n = blockIdx.y;
  // column index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;
-  if (c < d_gates.size(2)){
+  if (c < d_gates.size(2)) {
    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
-    const auto d_new_cell =
-        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
-
+    const auto d_new_cell = d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];

    d_old_cell[n][c] = d_new_cell;
    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;

-    d_gates[n][0][c] =
-        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
-    d_gates[n][1][c] =
-        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
-    d_gates[n][2][c] =
-        d_candidate_cell * d_elu(gate_weights[n][2][c]);
+    d_gates[n][0][c] = d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] = d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] = d_candidate_cell * d_elu(gate_weights[n][2][c]);
  }
}
} // namespace
@@ -128,16 +122,16 @@ std::vector<torch::Tensor> lltm_cuda_forward(
  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

-  AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
-    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
-  }));
+  AT_DISPATCH_FLOATING_TYPES(gates.scalar_type(), "lltm_forward_cuda", ([&] {
+    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>());
+  }));

  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}
@@ -161,18 +155,18 @@ std::vector<torch::Tensor> lltm_cuda_backward(
  const int threads = 1024;
  const dim3 blocks((state_size + threads - 1) / threads, batch_size);

-  AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
-    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
-        d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
-        grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
-        gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
-  }));
+  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "lltm_forward_cuda", ([&] {
+    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
+        d_old_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        d_gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),
+        grad_h.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        grad_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        new_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        input_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        output_gate.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        candidate_cell.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
+        gates.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>());
+  }));

  auto d_gate_weights = d_gates.flatten(1, 2);
  auto d_weights = d_gate_weights.t().mm(X);
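
Note on the pattern this diff migrates to: the commit replaces the deprecated `packed_accessor<scalar_t, N, torch::RestrictPtrTraits, size_t>` / `Tensor::type()` calls with `packed_accessor32<scalar_t, N, torch::RestrictPtrTraits>` / `Tensor::scalar_type()`, so that the accessor passed into the kernel carries 32-bit sizes and strides. The following is a minimal, standalone sketch of that pattern, not part of the commit; the kernel name `scale_kernel`, the launcher `scale_cuda`, and the tensor `t` are illustrative only.

#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

// Element-wise in-place scaling of a 2-D CUDA tensor, using the same
// PackedTensorAccessor32 / AT_DISPATCH pattern as the updated LLTM kernels.
template <typename scalar_t>
__global__ void scale_kernel(
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> t,
    scalar_t factor) {
  const int n = blockIdx.y;                             // row (batch) index
  const int c = blockIdx.x * blockDim.x + threadIdx.x;  // column index
  if (c < t.size(1)) {
    t[n][c] *= factor;  // accessor gives natural [i][j] indexing on device
  }
}

void scale_cuda(torch::Tensor t, double factor) {
  const int threads = 1024;
  const dim3 blocks((t.size(1) + threads - 1) / threads, t.size(0));
  // scalar_type() replaces the deprecated type() seen on the removed lines.
  AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale_cuda", ([&] {
    scale_kernel<scalar_t><<<blocks, threads>>>(
        t.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),
        static_cast<scalar_t>(factor));
  }));
}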