From 84c9456fd6addfdd092edf74d6b008df14e00f3f Mon Sep 17 00:00:00 2001 From: Trevor Bekolay Date: Mon, 5 Jun 2023 14:23:14 -0500 Subject: [PATCH] Extract _clone_model_with_weights from quantize_apply This aids downstream repos that implement fixes for various cloning issues by making this function able to be monkey-patched. --- .../python/core/quantization/keras/quantize.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py index 1d0c4154a..27bdb2057 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py @@ -285,6 +285,13 @@ def quantize_annotate_layer(to_annotate, quantize_config=None): layer=to_annotate, quantize_config=quantize_config) +def _clone_model_with_weights(model_to_clone): + cloned_model = keras.models.clone_model(model_to_clone) + cloned_model.set_weights(model_to_clone.get_weights()) + + return cloned_model + + @metrics.MonitorBoolGauge('quantize_apply_usage') def quantize_apply( model, @@ -361,12 +368,6 @@ def quantize_apply( 'been built yet. Please call `model.build(input_shape)` ' 'before quantizing your model.') - def _clone_model_with_weights(model_to_clone): - cloned_model = keras.models.clone_model(model_to_clone) - cloned_model.set_weights(model_to_clone.get_weights()) - - return cloned_model - def _extract_original_model(model_to_unwrap): """Extracts original model by removing wrappers.""" layer_quantize_map = {}