Skip to content

Commit f4902a4

Browse files
authored
Add lstm sigmoid dropout test for dependency optimization (#2027)
* Add lstm sigmoid_dropout test for dependency optimization Signed-off-by: Deyu Huang <[email protected]>
1 parent 7e3fd01 commit f4902a4

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def group_nodes_by_type(graph):
454454

455455

456456
def check_op_count(graph, op_type, expected_count, disabled=True):
457-
# FIXME: after switching to grappler some of the op counts are off. Fix later.
457+
# The grappler optimization may change some of the op counts.
458458
return disabled or len(group_nodes_by_type(graph)[op_type]) == expected_count
459459

460460

tests/test_lstm.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tensorflow.python.ops import variable_scope
1111
from backend_test_base import Tf2OnnxBackendTestBase
1212
from common import check_tf_min_version, unittest_main, check_opset_after_tf_version, \
13-
skip_tf2, skip_tf_versions, check_op_count
13+
skip_tf2, skip_tf_versions, check_op_count, skip_tfjs
1414

1515
from tf2onnx.tf_loader import is_tf2
1616

@@ -51,6 +51,7 @@ def new_graph_validator(g):
5151
# Skip checks for tflite graphs (no ":" in outputs)
5252
return good
5353
good = good and check_op_count(g, "LSTM", require_lstm_count, disabled=False)
54+
# If LSTM op rewriter failed to work, Loop op will be shown in general.
5455
good = good and check_op_count(g, "Loop", 0, disabled=False)
5556
return good
5657
try:
@@ -774,5 +775,23 @@ def func(x):
774775
return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
775776
self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)
776777

778+
@check_tf_min_version("2.0")
779+
@skip_tfjs("TFJS converts model incorrectly")
780+
def test_keras_lstm_sigmoid_dropout(self):
781+
in_shape = [16, 16]
782+
batch_size = 2
783+
x_val = np.random.uniform(size=[batch_size] + in_shape).astype(np.float32)
784+
785+
model = tf.keras.models.Sequential()
786+
model_in = tf.keras.layers.Input(shape=tuple(in_shape), name="input")
787+
lstm = tf.keras.layers.LSTM(16, activation='sigmoid', dropout=0.1)
788+
model.add(model_in)
789+
model.add(lstm)
790+
791+
def func(x):
792+
y = model(x)
793+
return tf.identity(y[0], name="output")
794+
self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)
795+
777796
if __name__ == '__main__':
778797
unittest_main()

0 commit comments

Comments
 (0)