diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py index fa36dd00c..adeb46574 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py @@ -667,7 +667,11 @@ def testStripClusteringSequentialModel(self): stripped_model = cluster.strip_clustering(clustered_model) self.assertEqual(self._count_clustered_layers(stripped_model), 0) - self.assertEqual(model.get_config(), stripped_model.get_config()) + model_config = model.get_config() + for layer in model_config['layers']: + # New serialization format includes `build_config` in wrapper + layer.pop('build_config', None) + self.assertEqual(model_config, stripped_model.get_config()) def testClusterStrippingFunctionalModel(self): """Verifies that stripping the clustering wrappers from a functional model produces the expected config.""" diff --git a/tensorflow_model_optimization/python/core/quantization/keras/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/BUILD index f7f400113..9e9da85d0 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/BUILD +++ b/tensorflow_model_optimization/python/core/quantization/keras/BUILD @@ -72,6 +72,7 @@ py_strict_test( visibility = ["//visibility:public"], deps = [ ":quantizers", + ":utils", # absl/testing:parameterized dep1, # numpy dep1, # tensorflow dep1, @@ -87,9 +88,10 @@ py_strict_library( srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ + ":quantizers", + ":utils", # six dep1, # tensorflow dep1, - "//tensorflow_model_optimization/python/core/quantization/keras:quantizers", ], ) @@ -125,6 +127,7 @@ py_strict_library( srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ + ":utils", # tensorflow dep1, ], ) @@ -152,6 +155,7 @@ py_strict_library( srcs_version = "PY3", visibility = ["//visibility:public"], deps = [ + ":utils", # tensorflow dep1, "//tensorflow_model_optimization/python/core/keras:utils", ], @@ -167,6 +171,7 @@ py_strict_test( deps = [ ":quantize_aware_activation", ":quantizers", + ":utils", # absl/testing:parameterized dep1, # numpy dep1, # tensorflow dep1, @@ -182,6 +187,7 @@ py_strict_library( visibility = ["//visibility:public"], deps = [ ":quantizers", + ":utils", # tensorflow dep1, "//tensorflow_model_optimization/python/core/keras:utils", ], @@ -211,6 +217,7 @@ py_strict_library( visibility = ["//visibility:public"], deps = [ ":quantize_aware_activation", + ":utils", # tensorflow dep1, # python/util tensorflow dep2, "//tensorflow_model_optimization/python/core/keras:metrics", @@ -249,6 +256,7 @@ py_strict_library( ":quantize_layer", ":quantize_wrapper", ":quantizers", + ":utils", # tensorflow dep1, "//tensorflow_model_optimization/python/core/keras:metrics", "//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry", diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py index e603c2d7f..64aa7421d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py @@ -25,14 +25,15 @@ import tensorflow as tf from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry keras = tf.keras K = tf.keras.backend l = tf.keras.layers -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = quantize_utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object class _TestHelper(object): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py index 35363ec89..498360c74 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py @@ -24,6 +24,7 @@ from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms @@ -67,13 +68,17 @@ def _get_params(conv_layer, bn_layer, relu_layer=None): list(conv_layer['config'].items()) + list(bn_layer['config'].items())) if relu_layer is not None: - params['post_activation'] = keras.layers.deserialize(relu_layer) + params['post_activation'] = quantize_utils.deserialize_layer( + relu_layer, use_legacy_format=True + ) return params def _get_layer_node(fused_layer, weights): - layer_config = keras.layers.serialize(fused_layer) + layer_config = quantize_utils.serialize_layer( + fused_layer, use_legacy_format=True + ) layer_config['name'] = layer_config['config']['name'] # This config tracks which layers get quantized, and whether they have a # custom QuantizeConfig. @@ -118,7 +123,10 @@ def _replace(self, bn_layer_node, conv_layer_node): return bn_layer_node conv_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()) @@ -180,7 +188,10 @@ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node): return relu_layer_node conv_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( default_8bit_quantize_configs.NoOpQuantizeConfig()) @@ -261,7 +272,10 @@ def _replace(self, bn_layer_node, dense_layer_node): return bn_layer_node dense_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( default_8bit_quantize_configs.Default8BitOutputQuantizeConfig()) @@ -297,7 +311,10 @@ def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node): return relu_layer_node dense_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( default_8bit_quantize_configs.NoOpQuantizeConfig()) @@ -408,7 +425,9 @@ def replacement(self, match_layer): else: spatial_dim = 2 - sepconv2d_layer_config = keras.layers.serialize(sepconv2d_layer) + sepconv2d_layer_config = quantize_utils.serialize_layer( + sepconv2d_layer, use_legacy_format=True + ) sepconv2d_layer_config['name'] = sepconv2d_layer.name # Needed to ensure these new layers are considered for quantization. @@ -420,7 +439,9 @@ def replacement(self, match_layer): expand_layer = tf.keras.layers.Lambda( lambda x: tf.expand_dims(x, spatial_dim), name=self._get_name('sepconv1d_expand')) - expand_layer_config = keras.layers.serialize(expand_layer) + expand_layer_config = quantize_utils.serialize_layer( + expand_layer, use_legacy_format=True + ) expand_layer_config['name'] = expand_layer.name expand_layer_metadata = { 'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()} @@ -428,7 +449,9 @@ def replacement(self, match_layer): squeeze_layer = tf.keras.layers.Lambda( lambda x: tf.squeeze(x, [spatial_dim]), name=self._get_name('sepconv1d_squeeze')) - squeeze_layer_config = keras.layers.serialize(squeeze_layer) + squeeze_layer_config = quantize_utils.serialize_layer( + squeeze_layer, use_legacy_format=True + ) squeeze_layer_config['name'] = squeeze_layer.name squeeze_layer_metadata = { 'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()} @@ -493,7 +516,9 @@ def replacement(self, match_layer): ) dconv_weights = collections.OrderedDict() dconv_weights['depthwise_kernel:0'] = sepconv_weights[0] - dconv_layer_config = keras.layers.serialize(dconv_layer) + dconv_layer_config = quantize_utils.serialize_layer( + dconv_layer, use_legacy_format=True + ) dconv_layer_config['name'] = dconv_layer.name # Needed to ensure these new layers are considered for quantization. dconv_metadata = {'quantize_config': None} @@ -521,7 +546,9 @@ def replacement(self, match_layer): conv_weights['kernel:0'] = sepconv_weights[1] if sepconv_layer['config']['use_bias']: conv_weights['bias:0'] = sepconv_weights[2] - conv_layer_config = keras.layers.serialize(conv_layer) + conv_layer_config = quantize_utils.serialize_layer( + conv_layer, use_legacy_format=True + ) conv_layer_config['name'] = conv_layer.name # Needed to ensure these new layers are considered for quantization. conv_metadata = {'quantize_config': None} @@ -588,7 +615,9 @@ def replacement(self, match_layer): quant_layer = quantize_layer.QuantizeLayer( quantizers.AllValuesQuantizer( num_bits=8, per_axis=False, symmetric=False, narrow_range=False)) - layer_config = keras.layers.serialize(quant_layer) + layer_config = quantize_utils.serialize_layer( + quant_layer, use_legacy_format=True + ) layer_config['name'] = quant_layer.name quant_layer_node = LayerNode( diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py index 9ab539287..f096ab963 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms_test.py @@ -707,4 +707,6 @@ def testConcatConcatTransformDisablesOutput(self): if __name__ == '__main__': + if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'): + tf.keras.__internal__.enable_unsafe_deserialization() tf.test.main() diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py index 01be17c6d..96eae7e39 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/quantize_numerical_test.py @@ -200,4 +200,6 @@ def testModelEndToEnd(self, model_fn): if __name__ == '__main__': + if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'): + tf.keras.__internal__.enable_unsafe_deserialization() tf.test.main() diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py index cc59fef52..231ef4695 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py @@ -25,14 +25,15 @@ import tensorflow as tf from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry as n_bit_registry keras = tf.keras K = tf.keras.backend l = tf.keras.layers -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = quantize_utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object class _TestHelper(object): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py index c2eeef955..320fb7267 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py @@ -24,6 +24,7 @@ from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_configs as configs from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms @@ -67,13 +68,17 @@ def _get_params(conv_layer, bn_layer, relu_layer=None): list(conv_layer['config'].items()) + list(bn_layer['config'].items())) if relu_layer is not None: - params['post_activation'] = keras.layers.deserialize(relu_layer) + params['post_activation'] = quantize_utils.deserialize_layer( + relu_layer, use_legacy_format=True + ) return params def _get_layer_node(fused_layer, weights): - layer_config = keras.layers.serialize(fused_layer) + layer_config = quantize_utils.serialize_layer( + fused_layer, use_legacy_format=True + ) layer_config['name'] = layer_config['config']['name'] # This config tracks which layers get quantized, and whether they have a # custom QuantizeConfig. @@ -118,7 +123,10 @@ def _replace(self, bn_layer_node, conv_layer_node): return bn_layer_node conv_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( configs.DefaultNBitOutputQuantizeConfig( num_bits_weight=self._num_bits_weight, @@ -190,7 +198,10 @@ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node): return relu_layer_node conv_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( configs.NoOpQuantizeConfig()) @@ -284,7 +295,10 @@ def _replace(self, bn_layer_node, dense_layer_node): return bn_layer_node dense_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( configs.DefaultNBitOutputQuantizeConfig( num_bits_weight=self._num_bits_weight, @@ -324,7 +338,10 @@ def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node): return relu_layer_node dense_layer_node.layer['config']['activation'] = ( - keras.activations.serialize(quantize_aware_activation.NoOpActivation())) + quantize_utils.serialize_activation( + quantize_aware_activation.NoOpActivation(), use_legacy_format=True + ) + ) bn_layer_node.metadata['quantize_config'] = ( configs.NoOpQuantizeConfig()) @@ -439,7 +456,9 @@ def replacement(self, match_layer): else: spatial_dim = 2 - sepconv2d_layer_config = keras.layers.serialize(sepconv2d_layer) + sepconv2d_layer_config = quantize_utils.serialize_layer( + sepconv2d_layer, use_legacy_format=True + ) sepconv2d_layer_config['name'] = sepconv2d_layer.name # Needed to ensure these new layers are considered for quantization. @@ -451,7 +470,9 @@ def replacement(self, match_layer): expand_layer = tf.keras.layers.Lambda( lambda x: tf.expand_dims(x, spatial_dim), name=self._get_name('sepconv1d_expand')) - expand_layer_config = keras.layers.serialize(expand_layer) + expand_layer_config = quantize_utils.serialize_layer( + expand_layer, use_legacy_format=True + ) expand_layer_config['name'] = expand_layer.name expand_layer_metadata = { 'quantize_config': @@ -460,7 +481,9 @@ def replacement(self, match_layer): squeeze_layer = tf.keras.layers.Lambda( lambda x: tf.squeeze(x, [spatial_dim]), name=self._get_name('sepconv1d_squeeze')) - squeeze_layer_config = keras.layers.serialize(squeeze_layer) + squeeze_layer_config = quantize_utils.serialize_layer( + squeeze_layer, use_legacy_format=True + ) squeeze_layer_config['name'] = squeeze_layer.name squeeze_layer_metadata = { 'quantize_config': @@ -530,7 +553,9 @@ def replacement(self, match_layer): ) dconv_weights = collections.OrderedDict() dconv_weights['depthwise_kernel:0'] = sepconv_weights[0] - dconv_layer_config = keras.layers.serialize(dconv_layer) + dconv_layer_config = quantize_utils.serialize_layer( + dconv_layer, use_legacy_format=True + ) dconv_layer_config['name'] = dconv_layer.name # Needed to ensure these new layers are considered for quantization. dconv_metadata = {'quantize_config': None} @@ -558,7 +583,9 @@ def replacement(self, match_layer): conv_weights['kernel:0'] = sepconv_weights[1] if sepconv_layer['config']['use_bias']: conv_weights['bias:0'] = sepconv_weights[2] - conv_layer_config = keras.layers.serialize(conv_layer) + conv_layer_config = quantize_utils.serialize_layer( + conv_layer, use_legacy_format=True + ) conv_layer_config['name'] = conv_layer.name # Needed to ensure these new layers are considered for quantization. conv_metadata = {'quantize_config': None} @@ -634,7 +661,9 @@ def replacement(self, match_layer): quantizers.AllValuesQuantizer( num_bits=self._num_bits_activation, per_axis=False, symmetric=False, narrow_range=False)) # activation/output - layer_config = keras.layers.serialize(quant_layer) + layer_config = quantize_utils.serialize_layer( + quant_layer, use_legacy_format=True + ) layer_config['name'] = quant_layer.name quant_layer_node = LayerNode( diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py index 6b52a6b46..e5ba016e3 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms_test.py @@ -580,4 +580,6 @@ def testConcatMultipleLevels(self): if __name__ == '__main__': + if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'): + tf.keras.__internal__.enable_unsafe_deserialization() tf.test.main() diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD index 2b2153549..c084c6451 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD +++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD @@ -66,5 +66,6 @@ py_strict_test( # absl/testing:parameterized dep1, # numpy dep1, # tensorflow dep1, + "//tensorflow_model_optimization/python/core/quantization/keras:utils", ], ) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py index 063c3d113..853d14ae1 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py @@ -23,6 +23,7 @@ import numpy as np import tensorflow as tf +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms @@ -159,7 +160,9 @@ def replacement(self, match_layer): match_layer_config = match_layer.layer['config'] my_dense_layer = self.MyDense(**match_layer_config) - replace_layer = keras.layers.serialize(my_dense_layer) + replace_layer = quantize_utils.serialize_layer( + my_dense_layer, use_legacy_format=True + ) replace_layer['name'] = replace_layer['config']['name'] return LayerNode(replace_layer, match_layer.weights, []) @@ -176,8 +179,11 @@ def testReplaceSingleLayerWithSingleLayer_OneOccurrence(self, model_type): # build_input_shape is a TensorShape object and the two objects are not # considered the same even though the shapes are the same. - self._assert_config(model.get_config(), transformed_model.get_config(), - ['class_name', 'build_input_shape']) + self._assert_config( + model.get_config(), + transformed_model.get_config(), + ['class_name', 'build_input_shape', 'module', 'registered_name'], + ) self.assertEqual( 'MyDense', @@ -209,8 +215,11 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences( # build_input_shape is a TensorShape object and the two objects are not # considered the same even though the shapes are the same. - self._assert_config(model.get_config(), transformed_model.get_config(), - ['class_name', 'build_input_shape']) + self._assert_config( + model.get_config(), + transformed_model.get_config(), + ['class_name', 'build_input_shape', 'module', 'registered_name'], + ) self.assertEqual( 'MyDense', @@ -268,7 +277,9 @@ def replacement(self, match_layer): match_layer_config['use_bias'] = False new_dense_layer = keras.layers.Dense(**match_layer_config) - replace_layer = keras.layers.serialize(new_dense_layer) + replace_layer = quantize_utils.serialize_layer( + new_dense_layer, use_legacy_format=True + ) replace_layer['name'] = replace_layer['config']['name'] return LayerNode(replace_layer, match_layer_weights, []) @@ -311,7 +322,9 @@ def replacement(self, match_layer): match_layer_config = match_layer.layer['config'] my_dense_layer = QuantizedCustomDense(**match_layer_config) - replace_layer = keras.layers.serialize(my_dense_layer) + replace_layer = quantize_utils.serialize_layer( + my_dense_layer, use_legacy_format=True + ) replace_layer['name'] = replace_layer['config']['name'] return LayerNode(replace_layer, match_layer.weights, []) @@ -355,7 +368,9 @@ def pattern(self): def replacement(self, match_layer): activation_layer = keras.layers.Activation('linear') - layer_config = keras.layers.serialize(activation_layer) + layer_config = quantize_utils.serialize_layer( + activation_layer, use_legacy_format=True + ) layer_config['name'] = activation_layer.name activation_layer_node = LayerNode( @@ -397,7 +412,9 @@ def pattern(self): def replacement(self, match_layer): activation_layer = keras.layers.Activation('linear') - layer_config = keras.layers.serialize(activation_layer) + layer_config = quantize_utils.serialize_layer( + activation_layer, use_legacy_format=True + ) layer_config['name'] = activation_layer.name activation_layer_node = LayerNode( @@ -435,7 +452,9 @@ def replacement(self, match_layer): new_dense_layer = keras.layers.Dense(**dense_layer_config) - replace_layer = keras.layers.serialize(new_dense_layer) + replace_layer = quantize_utils.serialize_layer( + new_dense_layer, use_legacy_format=True + ) replace_layer['name'] = replace_layer['config']['name'] return LayerNode(replace_layer, dense_layer_weights, []) @@ -569,7 +588,9 @@ def pattern(self): return LayerPattern('ReLU') def replacement(self, match_layer): - replace_layer = keras.layers.serialize(keras.layers.Softmax()) + replace_layer = quantize_utils.serialize_layer( + keras.layers.Softmax(), use_legacy_format=True + ) replace_layer['name'] = replace_layer['config']['name'] return LayerNode(replace_layer) @@ -579,7 +600,9 @@ def pattern(self): return LayerPattern('Softmax') def replacement(self, match_layer): - replace_layer = keras.layers.serialize(keras.layers.ELU()) + replace_layer = quantize_utils.serialize_layer( + keras.layers.ELU(), use_legacy_format=True + ) replace_layer['name'] = replace_layer['config']['name'] return LayerNode(replace_layer) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py index 38fdc6924..f2b4f41f7 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py @@ -24,6 +24,7 @@ from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer from tensorflow_model_optimization.python.core.quantization.keras import quantize_wrapper from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_scheme from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry @@ -525,7 +526,7 @@ def _wrap_fixed_range( 'init_min': init_min, 'init_max': init_max, 'narrow_range': narrow_range}) - return tf.keras.utils.serialize_keras_object(config) + return quantize_utils.serialize_keras_object(config) def _is_serialized_node_data(nested): @@ -601,8 +602,9 @@ def fix_input_output_range( init_min=input_min, init_max=input_max, narrow_range=narrow_range) - serialized_fixed_input_quantizer = tf.keras.utils.serialize_keras_object( - fixed_input_quantizer) + serialized_fixed_input_quantizer = quantize_utils.serialize_keras_object( + fixed_input_quantizer + ) if _is_functional_model(model): input_layer_list = _nested_to_flatten_node_data_list(config['input_layers']) @@ -685,8 +687,9 @@ def remove_input_range(model): """ config = model.get_config() no_input_quantizer = quantizers.NoQuantizer() - serialized_input_quantizer = tf.keras.utils.serialize_keras_object( - no_input_quantizer) + serialized_input_quantizer = quantize_utils.serialize_keras_object( + no_input_quantizer + ) if _is_functional_model(model): input_layer_list = _nested_to_flatten_node_data_list(config['input_layers']) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py index e41686221..8359aeebd 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py @@ -23,8 +23,10 @@ import tensorflow as tf -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils + +deserialize_keras_object = quantize_utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object class QuantizeAnnotate(tf.keras.layers.Wrapper): @@ -112,7 +114,12 @@ def from_config(cls, config): module_objects=globals(), custom_objects=None) - layer = tf.keras.layers.deserialize(config.pop('layer')) + layer_config = config.pop('layer') + use_legacy_format = 'module' not in layer_config + + layer = quantize_utils.deserialize_layer( + layer_config, use_legacy_format=use_legacy_format + ) return cls(layer=layer, quantize_config=quantize_config, **config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py index f52ba062b..8a4d58914 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation.py @@ -21,6 +21,7 @@ import tensorflow as tf from tensorflow_model_optimization.python.core.keras import utils +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils activations = tf.keras.activations @@ -39,6 +40,10 @@ def __call__(self, x): def get_config(self): return {} + @classmethod + def from_config(cls, config): + return cls(**config) + def __eq__(self, other): if not other or not isinstance(other, NoOpActivation): return False @@ -183,9 +188,13 @@ def quantizer_fn(x=x, @classmethod def from_config(cls, config): - return activations.deserialize(config['activation']) + return quantize_utils.deserialize_activation( + config['activation'], use_legacy_format=True + ) def get_config(self): return { - 'activation': activations.serialize(self.activation) + 'activation': quantize_utils.serialize_activation( + self.activation, use_legacy_format=True + ) } diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py index 2a3a5ad36..6c547ec67 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py @@ -25,12 +25,13 @@ from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils keras = tf.keras activations = tf.keras.activations K = tf.keras.backend -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = quantize_utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation MovingAverageQuantizer = quantizers.MovingAverageQuantizer @@ -154,13 +155,13 @@ def testSerializationReturnsWrappedActivation( 'config': activation_config } self.assertEqual(expected_config, serialized_quantize_activation) - - deserialized_activation = deserialize_keras_object( - serialized_quantize_activation, - custom_objects={ - 'QuantizeAwareActivation': QuantizeAwareActivation, - 'NoOpActivation': quantize_aware_activation.NoOpActivation - }) + with tf.keras.utils.custom_object_scope({ + 'QuantizeAwareActivation': QuantizeAwareActivation, + 'NoOpActivation': quantize_aware_activation.NoOpActivation, + }): + deserialized_activation = deserialize_keras_object( + serialized_quantize_activation + ) self.assertEqual(activation, deserialized_activation) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py index bf94e130d..5664ec4c5 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py @@ -19,6 +19,7 @@ import tensorflow as tf from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils @six.add_metaclass(abc.ABCMeta) @@ -259,13 +260,14 @@ def get_output_quantizers(self, layer): def get_config(self): return { - 'config': tf.keras.utils.serialize_keras_object(self.config), + 'config': quantize_utils.serialize_keras_object(self.config), 'num_bits': self.num_bits, 'init_min': self.init_min, 'init_max': self.init_max, - 'narrow_range': self.narrow_range} + 'narrow_range': self.narrow_range, + } @classmethod def from_config(cls, config): - config['config'] = tf.keras.utils.deserialize_keras_object(config['config']) + config['config'] = quantize_utils.deserialize_keras_object(config['config']) return cls(**config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py index 618da1234..4fe633b48 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_integration_test.py @@ -63,11 +63,28 @@ def _batch(self, dims, batch_size): dims[0] = batch_size return dims - def _assert_models_equal(self, model1, model2): + def _assert_models_equal(self, model1, model2, exclude_keys=None): + def _remove_keys(config): + """Removes keys specified in `exclude_keys`.""" + for key in exclude_keys: + if key in config: + del config[key] + + for _, v in config.items(): + if isinstance(v, dict): + _remove_keys(v) + + if isinstance(v, list): + for item in v: + if isinstance(item, dict): + _remove_keys(item) + model1_config = model1.get_config() - model1_config.pop('build_input_shape', None) model2_config = model2.get_config() - model2_config.pop('build_input_shape', None) + exclude_keys = exclude_keys or [] + exclude_keys += ['build_input_shape', 'build_config'] + _remove_keys(model1_config) + _remove_keys(model2_config) self.assertEqual(model1_config, model2_config) self.assertAllClose(model1.get_weights(), model2.get_weights()) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py index be59458ca..1393388e6 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py @@ -26,9 +26,10 @@ from tensorflow_model_optimization.python.core.keras import utils from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils -serialize_keras_object = tf.keras.utils.serialize_keras_object -deserialize_keras_object = tf.keras.utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object +deserialize_keras_object = quantize_utils.deserialize_keras_object class QuantizeLayer(tf.keras.layers.Layer): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py index 02aaeb958..db39a5317 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py @@ -33,9 +33,10 @@ from tensorflow_model_optimization.python.core.keras import metrics from tensorflow_model_optimization.python.core.keras import utils from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = quantize_utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object class QuantizeWrapper(tf.keras.layers.Wrapper): @@ -206,7 +207,12 @@ def from_config(cls, config): module_objects=globals(), custom_objects=None) - layer = tf.keras.layers.deserialize(config.pop('layer')) + layer_config = config.pop('layer') + use_legacy_format = 'module' not in layer_config + + layer = quantize_utils.deserialize_layer( + layer_config, use_legacy_format=use_legacy_format + ) return cls(layer=layer, quantize_config=quantize_config, **config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py index 7df0567f2..628749f60 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py @@ -25,9 +25,10 @@ from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.quantization.keras import quantizers +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = quantize_utils.deserialize_keras_object +serialize_keras_object = quantize_utils.serialize_keras_object @parameterized.parameters( diff --git a/tensorflow_model_optimization/python/core/quantization/keras/utils.py b/tensorflow_model_optimization/python/core/quantization/keras/utils.py index cf08bf13f..d5d7ac98d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/utils.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/utils.py @@ -15,6 +15,7 @@ # pylint: disable=protected-access """Quantization specific utilities for generating, saving, testing, and evaluating models.""" +import inspect import tempfile import tensorflow as tf @@ -22,6 +23,72 @@ from tensorflow_model_optimization.python.core.keras import compat +def serialize_keras_object(obj): + if hasattr(tf.keras.utils, "legacy"): + return tf.keras.utils.legacy.serialize_keras_object(obj) + else: + return tf.keras.utils.serialize_keras_object(obj) + + +def deserialize_keras_object( + config, module_objects=None, custom_objects=None, printable_module_name=None +): + if hasattr(tf.keras.utils, "legacy"): + return tf.keras.utils.legacy.deserialize_keras_object( + config, custom_objects, module_objects, printable_module_name + ) + else: + return tf.keras.utils.deserialize_keras_object( + config, custom_objects, module_objects, printable_module_name + ) + + +def serialize_layer(layer, use_legacy_format=False): + if ( + "use_legacy_format" + in inspect.getfullargspec(tf.keras.layers.serialize).args + ): + return tf.keras.layers.serialize(layer, use_legacy_format=use_legacy_format) + else: + return tf.keras.layers.serialize(layer) + + +def deserialize_layer(config, use_legacy_format=False): + if ( + "use_legacy_format" + in inspect.getfullargspec(tf.keras.layers.deserialize).args + ): + return tf.keras.layers.deserialize( + config, use_legacy_format=use_legacy_format + ) + else: + return tf.keras.layers.deserialize(config) + + +def serialize_activation(activation, use_legacy_format=False): + if ( + "use_legacy_format" + in inspect.getfullargspec(tf.keras.activations.serialize).args + ): + return tf.keras.activations.serialize( + activation, use_legacy_format=use_legacy_format + ) + else: + return tf.keras.activations.serialize(activation) + + +def deserialize_activation(config, use_legacy_format=False): + if ( + "use_legacy_format" + in inspect.getfullargspec(tf.keras.activations.deserialize).args + ): + return tf.keras.activations.deserialize( + config, use_legacy_format=use_legacy_format + ) + else: + return tf.keras.activations.deserialize(config) + + def convert_keras_to_tflite(model, output_path, custom_objects=None, diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD index 970764e4d..4097eab19 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD +++ b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD @@ -64,6 +64,7 @@ py_strict_library( deps = [ ":pruning_wrapper", # tensorflow dep1, + "//tensorflow_model_optimization/python/core/quantization/keras:utils", ], ) @@ -89,6 +90,7 @@ py_strict_test( # absl/testing:parameterized dep1, # tensorflow dep1, "//tensorflow_model_optimization/python/core/keras:compat", + "//tensorflow_model_optimization/python/core/quantization/keras:utils", ], ) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py index 340078e2e..39c5fe1fe 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py @@ -18,6 +18,7 @@ import abc import tensorflow as tf +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper layers = tf.keras.layers @@ -216,9 +217,9 @@ def _check_layer_support(self, layer): elif isinstance(layer, layers.UpSampling2D): return layer.interpolation == 'bilinear' elif isinstance(layer, layers.Activation): - return activations.serialize(layer.activation) in ('relu', 'relu6', - 'leaky_relu', 'elu', - 'sigmoid') + return quantize_utils.serialize_activation( + layer.activation, use_legacy_format=True + ) in ('relu', 'relu6', 'leaky_relu', 'elu', 'sigmoid') elif layer.__class__.__name__ == 'TFOpLambda': return layer.function in (tf.identity, tf.__operators__.add, tf.math.add, tf.math.subtract, tf.math.multiply) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py index 916d080ab..36562b792 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py @@ -19,6 +19,7 @@ # TODO(b/139939526): move to public API. from tensorflow_model_optimization.python.core.keras import compat +from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule @@ -242,12 +243,13 @@ def testSerializeDeserialize(self): sparsity = pruning_schedule.ConstantSparsity(0.7, 10, 20, 10) config = sparsity.get_config() - sparsity_deserialized = tf.keras.utils.deserialize_keras_object( + sparsity_deserialized = quantize_utils.deserialize_keras_object( config, custom_objects={ 'ConstantSparsity': pruning_schedule.ConstantSparsity, - 'PolynomialDecay': pruning_schedule.PolynomialDecay - }) + 'PolynomialDecay': pruning_schedule.PolynomialDecay, + }, + ) self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__) @@ -278,12 +280,13 @@ def testSerializeDeserialize(self): sparsity = pruning_schedule.PolynomialDecay(0.2, 0.6, 10, 20, 5, 10) config = sparsity.get_config() - sparsity_deserialized = tf.keras.utils.deserialize_keras_object( + sparsity_deserialized = quantize_utils.deserialize_keras_object( config, custom_objects={ 'ConstantSparsity': pruning_schedule.ConstantSparsity, - 'PolynomialDecay': pruning_schedule.PolynomialDecay - }) + 'PolynomialDecay': pruning_schedule.PolynomialDecay, + }, + ) self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py index 89ff765f7..35064a911 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py @@ -121,13 +121,17 @@ def testPruneModel(self): # Test serialization model_config = self.model.get_config() + for layer in model_config['layers']: + layer.pop('build_config', None) self.assertEqual( model_config, self.model.__class__.from_config( - model_config, + self.model.get_config(), custom_objects={ 'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude - }).get_config()) + }, + ).get_config(), + ) def testCustomLayerNonPrunable(self): layer = CustomLayer(input_dim=16, output_dim=32)