Skip to content

Commit 4a34847

Browse files
committed
remove custom asserts and default to ACTIVATION_RELU
1 parent de03f80 commit 4a34847

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

src/rnn.c

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,20 @@ void compute_dense(const DenseLayer *layer, float *output, const float *input)
9292
sum += layer->input_weights[j*stride + i]*input[j];
9393
output[i] = WEIGHTS_SCALE*sum;
9494
}
95-
if (layer->activation == ACTIVATION_SIGMOID) {
95+
switch (layer->activation) {
96+
case ACTIVATION_SIGMOID:
9697
for (i=0;i<N;i++)
9798
output[i] = sigmoid_approx(output[i]);
98-
} else if (layer->activation == ACTIVATION_TANH) {
99+
break;
100+
case ACTIVATION_TANH:
99101
for (i=0;i<N;i++)
100102
output[i] = tansig_approx(output[i]);
101-
} else if (layer->activation == ACTIVATION_RELU) {
103+
break;
104+
default:
105+
case ACTIVATION_RELU:
102106
for (i=0;i<N;i++)
103107
output[i] = relu(output[i]);
104-
} else {
105-
*(int*)0=0;
108+
break;
106109
}
107110
}
108111

@@ -145,14 +148,16 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
145148
sum += gru->input_weights[2*N + j*stride + i]*input[j];
146149
for (j=0;j<N;j++)
147150
sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
148-
if (gru->activation == ACTIVATION_SIGMOID) sum = sigmoid_approx(WEIGHTS_SCALE*sum);
149-
else if (gru->activation == ACTIVATION_TANH) sum = tansig_approx(WEIGHTS_SCALE*sum);
150-
else if (gru->activation == ACTIVATION_RELU) sum = relu(WEIGHTS_SCALE*sum);
151-
else *(int*)0=0;
152-
h[i] = z[i]*state[i] + (1-z[i])*sum;
153-
}
154-
for (i=0;i<N;i++)
155-
state[i] = h[i];
151+
switch (gru->activation) {
152+
case ACTIVATION_SIGMOID: sum = sigmoid_approx(WEIGHTS_SCALE*sum);break;
153+
case ACTIVATION_TANH: sum = tansig_approx(WEIGHTS_SCALE*sum); break;
154+
default:
155+
case ACTIVATION_RELU: sum = relu(WEIGHTS_SCALE*sum); break;
156+
}
157+
h[i] = z[i]*state[i] + (1-z[i])*sum;
158+
}
159+
for (i=0;i<N;i++)
160+
state[i] = h[i];
156161
}
157162

158163
#define INPUT_SIZE 42

0 commit comments

Comments
 (0)