Skip to content

Commit 96f8b96

Browse files
committed
Update bn_folding_test.py
1 parent d23596b commit 96f8b96

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

tests/bn_folding_test.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -464,17 +464,13 @@ def test_same_training_and_prediction(model_name):
464464
if model_name == "conv2d":
465465
x_shape = (2, 2, 1)
466466
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
467-
gamma = np.array([2., 1.])
468-
beta = np.array([0., 1.])
469-
moving_mean = np.array([1., 1.])
470-
moving_variance = np.array([1., 2.])
471467
elif model_name == "dense":
472468
x_shape = (4,)
473469
kernel = np.array([[1., 1.], [1., 0.], [1., 1.], [0., 1.]])
474-
gamma = np.array([2., 1.])
475-
beta = np.array([0., 1.])
476-
moving_mean = np.array([1., 1.])
477-
moving_variance = np.array([1., 2.])
470+
gamma = np.array([2., 1.])
471+
beta = np.array([0., 1.])
472+
moving_mean = np.array([1., 1.])
473+
moving_variance = np.array([1., 2.])
478474
iteration = np.array(-1)
479475

480476
train_ds = generate_dataset(train_size=10, batch_size=10, input_shape=x_shape,

0 commit comments

Comments
 (0)