|
10 | 10 | from tensorflow.python.ops import variable_scope
|
11 | 11 | from backend_test_base import Tf2OnnxBackendTestBase
|
12 | 12 | from common import check_tf_min_version, unittest_main, check_opset_after_tf_version, \
|
13 |
| - skip_tf2, skip_tf_versions, check_op_count |
| 13 | + skip_tf2, skip_tf_versions, check_op_count, skip_tfjs |
14 | 14 |
|
15 | 15 | from tf2onnx.tf_loader import is_tf2
|
16 | 16 |
|
@@ -51,6 +51,7 @@ def new_graph_validator(g):
|
51 | 51 | # Skip checks for tflite graphs (no ":" in outputs)
|
52 | 52 | return good
|
53 | 53 | good = good and check_op_count(g, "LSTM", require_lstm_count, disabled=False)
|
| 54 | + # If LSTM op rewriter failed to work, Loop op will be shown in general. |
54 | 55 | good = good and check_op_count(g, "Loop", 0, disabled=False)
|
55 | 56 | return good
|
56 | 57 | try:
|
@@ -774,5 +775,23 @@ def func(x):
|
774 | 775 | return tf.identity(y[0], name="output"), tf.identity(y[1], name="output1")
|
775 | 776 | self.run_test_case(func, {"input:0": x_val}, [], ["output:0", "output1:0"], rtol=1e-05, atol=1e-06)
|
776 | 777 |
|
| 778 | + @check_tf_min_version("2.0") |
| 779 | + @skip_tfjs("TFJS converts model incorrectly") |
| 780 | + def test_keras_lstm_sigmoid_dropout(self): |
| 781 | + in_shape = [16, 16] |
| 782 | + batch_size = 2 |
| 783 | + x_val = np.random.uniform(size=[batch_size] + in_shape).astype(np.float32) |
| 784 | + |
| 785 | + model = tf.keras.models.Sequential() |
| 786 | + model_in = tf.keras.layers.Input(shape=tuple(in_shape), name="input") |
| 787 | + lstm = tf.keras.layers.LSTM(16, activation='sigmoid', dropout=0.1) |
| 788 | + model.add(model_in) |
| 789 | + model.add(lstm) |
| 790 | + |
| 791 | + def func(x): |
| 792 | + y = model(x) |
| 793 | + return tf.identity(y[0], name="output") |
| 794 | + self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06) |
| 795 | + |
777 | 796 | if __name__ == '__main__':
|
778 | 797 | unittest_main()
|
0 commit comments