From b7d330f85103675c36f49c3d8964bd1bb1374464 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sat, 19 Apr 2025 19:10:45 +0800 Subject: [PATCH 1/3] Add support for sharded weights. --- keras_hub/src/utils/keras_utils.py | 14 ++++ keras_hub/src/utils/preset_utils.py | 86 ++++++++++++++++++++++-- keras_hub/src/utils/preset_utils_test.py | 46 +++++++++++++ 3 files changed, 140 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py index e1e40e489a..da65e33df0 100644 --- a/keras_hub/src/utils/keras_utils.py +++ b/keras_hub/src/utils/keras_utils.py @@ -1,3 +1,4 @@ +import inspect import sys import keras @@ -147,3 +148,16 @@ def get_gpu_names(): ] else: return [""] + + +def sharded_weights_available(): + """Whether sharded weights serialization is available. + + Returns: + `True` if sharded weights are available, `False` otherwise. + """ + save_weights_signature = inspect.signature(keras.saving.save_weights) + if "max_shard_size" in save_weights_signature.parameters: + return True + else: + return False diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 7fa4b3bb00..6bb21fb9eb 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -1,7 +1,9 @@ import collections import datetime +import functools import inspect import json +import math import os import re @@ -10,6 +12,7 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.utils.keras_utils import print_msg +from keras_hub.src.utils.keras_utils import sharded_weights_available try: import kagglehub @@ -48,6 +51,7 @@ # Weight file names. MODEL_WEIGHTS_FILE = "model.weights.h5" TASK_WEIGHTS_FILE = "task.weights.h5" +SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json" # HuggingFace filenames. README_FILE = "README.md" @@ -647,7 +651,7 @@ def load_backbone(self, cls, load_weights, **kwargs): backbone = self._load_serialized_object(self.config, **kwargs) if load_weights: jax_memory_cleanup(backbone) - backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE)) + self._load_backbone_weights(backbone) return backbone def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs): @@ -697,8 +701,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs): task.load_task_weights(task_weights) else: jax_memory_cleanup(task.backbone) - backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE) - task.backbone.load_weights(backbone_weights) + self._load_backbone_weights(task.backbone) return task def load_preprocessor( @@ -726,18 +729,64 @@ def _load_serialized_object(self, config, **kwargs): config["config"] = {**config["config"], **kwargs} return keras.saving.deserialize_keras_object(config) + def _get_sharded_filenames(self, config_path): + with open(config_path, encoding="utf-8") as config_file: + config = json.load(config_file) + weight_map = config["weight_map"] + return sorted(set(weight_map.values())) + + def _load_backbone_weights(self, backbone): + # Detect if the backbone is sharded or not. + has_single_file_weights = check_file_exists( + self.preset, MODEL_WEIGHTS_FILE + ) + if has_single_file_weights: + filepath = get_file(self.preset, MODEL_WEIGHTS_FILE) + else: + if not sharded_weights_available(): + raise RuntimeError( + "Sharded weights loading is not supported in the current " + f"Keras version {keras.__version__}. " + "Please update to a newer version." 
+ ) + filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE) + sharded_filenames = self._get_sharded_filenames(filepath) + for sharded_filename in sharded_filenames: + # Download the sharded weights. + _ = get_file(self.preset, sharded_filename) + backbone.load_weights(filepath) + class KerasPresetSaver: def __init__(self, preset_dir): os.makedirs(preset_dir, exist_ok=True) self.preset_dir = preset_dir - def save_backbone(self, backbone): + def save_backbone(self, backbone, max_shard_size=10): self._save_serialized_object(backbone, config_file=CONFIG_FILE) - backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE) - backbone.save_weights(backbone_weight_path) self._save_metadata(backbone) + # Save the weights. + backbone_size_in_bytes = self._get_variables_size_in_bytes( + backbone.variables + ) + backbone_size_in_gb = backbone_size_in_bytes / (1024**3) + # If the size of the backbone is larger than `max_shard_size`, save + # sharded weights. + if sharded_weights_available() and backbone_size_in_gb > max_shard_size: + backbone_sharded_weights_config_path = os.path.join( + self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE + ) + backbone.save_weights( + backbone_sharded_weights_config_path, + max_shard_size=max_shard_size, + ) + else: + backbone_weight_path = os.path.join( + self.preset_dir, MODEL_WEIGHTS_FILE + ) + backbone.save_weights(backbone_weight_path) + def save_tokenizer(self, tokenizer): config_file = TOKENIZER_CONFIG_FILE if hasattr(tokenizer, "config_file"): @@ -823,3 +872,28 @@ def _save_metadata(self, layer): metadata_path = os.path.join(self.preset_dir, METADATA_FILE) with open(metadata_path, "w") as metadata_file: metadata_file.write(json.dumps(metadata, indent=4)) + + def _get_variables_size_in_bytes(self, variables): + @functools.lru_cache(512) + def _compute_memory_size(shape, dtype): + weight_counts = math.prod(shape) + dtype = keras.backend.standardize_dtype(dtype) + dtype_size = int( + ( + dtype.replace("bfloat", "") + .replace("float", "") + .replace("uint", "") + .replace("int", "") + .replace("bool", "1") + ) + ) + return weight_counts * dtype_size + + unique_variables = {} + for v in variables: + if id(v) not in unique_variables: + unique_variables[id(v)] = (v.shape, v.dtype) + total_memory_size = 0 + for shape, dtype in unique_variables.values(): + total_memory_size += _compute_memory_size(shape, dtype) + return total_memory_size / 8 diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py index 998dcadfa9..14a12e8ba2 100644 --- a/keras_hub/src/utils/preset_utils_test.py +++ b/keras_hub/src/utils/preset_utils_test.py @@ -10,12 +10,58 @@ ) from keras_hub.src.models.bert.bert_backbone import BertBackbone from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer +from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone from keras_hub.src.tests.test_case import TestCase +from keras_hub.src.utils.keras_utils import sharded_weights_available from keras_hub.src.utils.preset_utils import CONFIG_FILE +from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.preset_utils import upload_preset class PresetUtilsTest(TestCase): + @pytest.mark.large + def test_sharded_weights(self): + if not sharded_weights_available(): + self.skipTest("Sharded weights are not available.") + + # Gemma2 config. 
+ init_kwargs = { + "vocabulary_size": 4096, # 256128 + "num_layers": 24, # 46 + "num_query_heads": 16, # 32 + "num_key_value_heads": 8, # 16 + "hidden_dim": 64, # 4608 + "intermediate_dim": 128, # 73728 + "head_dim": 8, # 128 + "sliding_window_size": 5, # 4096 + "attention_logit_soft_cap": 50, + "final_logit_soft_cap": 30, + "layer_norm_epsilon": 1e-6, + "query_head_dim_normalize": False, + "use_post_ffw_norm": True, + "use_post_attention_norm": True, + "use_sliding_window_attention": True, + } + backbone = GemmaBackbone(**init_kwargs) # ~4.4MB + + # Save the sharded weights. + preset_dir = self.get_temp_dir() + preset_saver = get_preset_saver(preset_dir) + preset_saver.save_backbone(backbone, max_shard_size=0.002) + self.assertTrue( + os.path.exists(os.path.join(preset_dir, "model.weights.json")) + ) + self.assertTrue( + os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5")) + ) + + # Load the sharded weights. + revived_backbone = GemmaBackbone.from_preset(preset_dir) + for v1, v2 in zip( + backbone.trainable_variables, revived_backbone.trainable_variables + ): + self.assertAllClose(v1, v2) + @pytest.mark.large def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "must be a string"): From 2b70a6f440913a709447084041f897e74c69e0a9 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 25 Apr 2025 14:08:25 +0800 Subject: [PATCH 2/3] Add `max_shard_size` to Backbone and Task. Simplify the test. --- keras_hub/src/models/backbone.py | 7 ++++-- keras_hub/src/models/task.py | 7 ++++-- keras_hub/src/utils/preset_utils.py | 30 +++++++++++------------- keras_hub/src/utils/preset_utils_test.py | 21 ++++++++--------- 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/keras_hub/src/models/backbone.py b/keras_hub/src/models/backbone.py index 85cde29b3c..a810f3430c 100644 --- a/keras_hub/src/models/backbone.py +++ b/keras_hub/src/models/backbone.py @@ -177,14 +177,17 @@ class like `keras_hub.models.Backbone.from_preset()`, or from ) return loader.load_backbone(backbone_cls, load_weights, **kwargs) - def save_to_preset(self, preset_dir): + def save_to_preset(self, preset_dir, max_shard_size=10): """Save backbone to a preset directory. Args: preset_dir: The path to the local model preset directory. + max_shard_size: `int` or `float`. Maximum size in GB for each + sharded file. If `None`, no sharding will be done. Defaults to + `10`. """ saver = get_preset_saver(preset_dir) - saver.save_backbone(self) + saver.save_backbone(self, max_shard_size=max_shard_size) def get_lora_target_names(self): """Returns list of layer names which are to be LoRA-fied. diff --git a/keras_hub/src/models/task.py b/keras_hub/src/models/task.py index 5920776232..d273759b46 100644 --- a/keras_hub/src/models/task.py +++ b/keras_hub/src/models/task.py @@ -236,14 +236,17 @@ def save_task_weights(self, filepath): objects_to_skip=backbone_layer_ids, ) - def save_to_preset(self, preset_dir): + def save_to_preset(self, preset_dir, max_shard_size=10): """Save task to a preset directory. Args: preset_dir: The path to the local model preset directory. + max_shard_size: `int` or `float`. Maximum size in GB for each + sharded file. If `None`, no sharding will be done. Defaults to + `10`. 
""" saver = get_preset_saver(preset_dir) - saver.save_task(self) + saver.save_task(self, max_shard_size=max_shard_size) @property def layers(self): diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 6bb21fb9eb..9cb4bd9889 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -1,6 +1,5 @@ import collections import datetime -import functools import inspect import json import math @@ -804,7 +803,7 @@ def save_audio_converter(self, converter): def save_image_converter(self, converter): self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE) - def save_task(self, task): + def save_task(self, task, max_shard_size=10): # Save task specific config and weights. self._save_serialized_object(task, TASK_CONFIG_FILE) if task.has_task_weights(): @@ -812,10 +811,12 @@ def save_task(self, task): task.save_task_weights(task_weight_path) # Save backbone. if hasattr(task.backbone, "save_to_preset"): - task.backbone.save_to_preset(self.preset_dir) + task.backbone.save_to_preset( + self.preset_dir, max_shard_size=max_shard_size + ) else: # Allow saving a `keras.Model` that is not a backbone subclass. - self.save_backbone(task.backbone) + self.save_backbone(task.backbone, max_shard_size=max_shard_size) # Save preprocessor. if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"): task.preprocessor.save_to_preset(self.preset_dir) @@ -874,20 +875,17 @@ def _save_metadata(self, layer): metadata_file.write(json.dumps(metadata, indent=4)) def _get_variables_size_in_bytes(self, variables): - @functools.lru_cache(512) def _compute_memory_size(shape, dtype): + def _get_dtype_size(dtype): + dtype = keras.backend.standardize_dtype(dtype) + # If dtype is bool, return 1 immediately. + if dtype == "bool": + return 1 + # Else, we extract the bit size from the string. + return int(re.sub(r"bfloat|float|uint|int", "", dtype)) + weight_counts = math.prod(shape) - dtype = keras.backend.standardize_dtype(dtype) - dtype_size = int( - ( - dtype.replace("bfloat", "") - .replace("float", "") - .replace("uint", "") - .replace("int", "") - .replace("bool", "1") - ) - ) - return weight_counts * dtype_size + return weight_counts * _get_dtype_size(dtype) unique_variables = {} for v in variables: diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py index 14a12e8ba2..92c763e052 100644 --- a/keras_hub/src/utils/preset_utils_test.py +++ b/keras_hub/src/utils/preset_utils_test.py @@ -14,7 +14,6 @@ from keras_hub.src.tests.test_case import TestCase from keras_hub.src.utils.keras_utils import sharded_weights_available from keras_hub.src.utils.preset_utils import CONFIG_FILE -from keras_hub.src.utils.preset_utils import get_preset_saver from keras_hub.src.utils.preset_utils import upload_preset @@ -26,13 +25,13 @@ def test_sharded_weights(self): # Gemma2 config. 
init_kwargs = { - "vocabulary_size": 4096, # 256128 - "num_layers": 24, # 46 - "num_query_heads": 16, # 32 - "num_key_value_heads": 8, # 16 - "hidden_dim": 64, # 4608 - "intermediate_dim": 128, # 73728 - "head_dim": 8, # 128 + "vocabulary_size": 1024, # 256128 + "num_layers": 12, # 46 + "num_query_heads": 8, # 32 + "num_key_value_heads": 4, # 16 + "hidden_dim": 32, # 4608 + "intermediate_dim": 64, # 73728 + "head_dim": 4, # 128 "sliding_window_size": 5, # 4096 "attention_logit_soft_cap": 50, "final_logit_soft_cap": 30, @@ -42,12 +41,12 @@ def test_sharded_weights(self): "use_post_attention_norm": True, "use_sliding_window_attention": True, } - backbone = GemmaBackbone(**init_kwargs) # ~4.4MB + backbone = GemmaBackbone(**init_kwargs) # ~422KB + backbone.summary() # Save the sharded weights. preset_dir = self.get_temp_dir() - preset_saver = get_preset_saver(preset_dir) - preset_saver.save_backbone(backbone, max_shard_size=0.002) + backbone.save_to_preset(preset_dir, max_shard_size=0.0002) self.assertTrue( os.path.exists(os.path.join(preset_dir, "model.weights.json")) ) From 0fff9d1edd6a75d3f39fb6bf04782aad97154820 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 29 Apr 2025 08:26:19 +0800 Subject: [PATCH 3/3] Split the functions from `KerasPresetSaver`. Clean up the comments. --- keras_hub/src/utils/keras_utils.py | 5 +---- keras_hub/src/utils/preset_utils.py | 16 ++------------ keras_hub/src/utils/preset_utils_test.py | 18 +++++++-------- keras_hub/src/utils/tensor_utils.py | 28 +++++++++++++++++++++++- 4 files changed, 38 insertions(+), 29 deletions(-) diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py index da65e33df0..21607ffccb 100644 --- a/keras_hub/src/utils/keras_utils.py +++ b/keras_hub/src/utils/keras_utils.py @@ -157,7 +157,4 @@ def sharded_weights_available(): `True` if sharded weights are available, `False` otherwise. """ save_weights_signature = inspect.signature(keras.saving.save_weights) - if "max_shard_size" in save_weights_signature.parameters: - return True - else: - return False + return "max_shard_size" in save_weights_signature.parameters diff --git a/keras_hub/src/utils/preset_utils.py b/keras_hub/src/utils/preset_utils.py index 9cb4bd9889..8423238b5c 100644 --- a/keras_hub/src/utils/preset_utils.py +++ b/keras_hub/src/utils/preset_utils.py @@ -2,7 +2,6 @@ import datetime import inspect import json -import math import os import re @@ -12,6 +11,7 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.utils.keras_utils import print_msg from keras_hub.src.utils.keras_utils import sharded_weights_available +from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits try: import kagglehub @@ -875,23 +875,11 @@ def _save_metadata(self, layer): metadata_file.write(json.dumps(metadata, indent=4)) def _get_variables_size_in_bytes(self, variables): - def _compute_memory_size(shape, dtype): - def _get_dtype_size(dtype): - dtype = keras.backend.standardize_dtype(dtype) - # If dtype is bool, return 1 immediately. - if dtype == "bool": - return 1 - # Else, we extract the bit size from the string. 
-                return int(re.sub(r"bfloat|float|uint|int", "", dtype))
-
-            weight_counts = math.prod(shape)
-            return weight_counts * _get_dtype_size(dtype)
-
         unique_variables = {}
         for v in variables:
             if id(v) not in unique_variables:
                 unique_variables[id(v)] = (v.shape, v.dtype)
         total_memory_size = 0
         for shape, dtype in unique_variables.values():
-            total_memory_size += _compute_memory_size(shape, dtype)
+            total_memory_size += get_tensor_size_in_bits(shape, dtype)
         return total_memory_size / 8
diff --git a/keras_hub/src/utils/preset_utils_test.py b/keras_hub/src/utils/preset_utils_test.py
index 92c763e052..738682a286 100644
--- a/keras_hub/src/utils/preset_utils_test.py
+++ b/keras_hub/src/utils/preset_utils_test.py
@@ -23,16 +23,15 @@ def test_sharded_weights(self):
         if not sharded_weights_available():
             self.skipTest("Sharded weights are not available.")
 
-        # Gemma2 config.
         init_kwargs = {
-            "vocabulary_size": 1024,  # 256128
-            "num_layers": 12,  # 46
-            "num_query_heads": 8,  # 32
-            "num_key_value_heads": 4,  # 16
-            "hidden_dim": 32,  # 4608
-            "intermediate_dim": 64,  # 73728
-            "head_dim": 4,  # 128
-            "sliding_window_size": 5,  # 4096
+            "vocabulary_size": 1024,
+            "num_layers": 12,
+            "num_query_heads": 8,
+            "num_key_value_heads": 4,
+            "hidden_dim": 32,
+            "intermediate_dim": 64,
+            "head_dim": 4,
+            "sliding_window_size": 5,
             "attention_logit_soft_cap": 50,
             "final_logit_soft_cap": 30,
             "layer_norm_epsilon": 1e-6,
@@ -42,12 +41,12 @@ def test_sharded_weights(self):
             "use_post_attention_norm": True,
             "use_sliding_window_attention": True,
         }
         backbone = GemmaBackbone(**init_kwargs)  # ~422KB
-        backbone.summary()
 
         # Save the sharded weights.
         preset_dir = self.get_temp_dir()
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index a602963cf0..5588328ead 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -1,6 +1,8 @@
 import contextlib
 import functools
 import inspect
+import math
+import re
 import threading
 
 import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)
 
 
+def get_dtype_size_in_bits(dtype):
+    """Get the size of a given dtype in bits."""
+    dtype = keras.backend.standardize_dtype(dtype)
+    # If dtype is bool, return 1 immediately.
+    if dtype == "bool":
+        return 1
+    # Else, we extract the bit size from the string.
+    return int(re.sub(r"bfloat|float|uint|int", "", dtype))
+
+
+def get_tensor_size_in_bits(shape, dtype):
+    """Calculate the size of a tensor in bits given its shape and dtype.
+
+    Args:
+        shape: An iterable of ints representing the shape of the tensor.
+        dtype: The dtype of the tensor.
+
+    Returns:
+        The size of the tensor in bits.
+    """
+    return math.prod(shape) * get_dtype_size_in_bits(dtype)
+
+
 def any_equal(inputs, values, padding_mask):
     """Return a mask that is True anywhere `inputs` has a value in `values`.
 
@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
     Returns:
         A tensor with `inputs` shape where each position is True if it
         contains a value from any `values`. Padding mask will be applied before
-        returning."""
+        returning.
+    """
     output = ops.equal(inputs, values[0])
     for value in values[1:]:
         value_equality = ops.equal(inputs, value)
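
Usage note (not part of the patches): a minimal sketch of how the sharded-weights
API introduced above is expected to be used. It assumes a Keras version whose
`keras.saving.save_weights` accepts `max_shard_size` (the capability probed by
`sharded_weights_available()`); the tiny Gemma config and the preset directory
name below are placeholders chosen only to keep the example small.

    import keras_hub

    # A deliberately tiny backbone, mirroring the config used in
    # `test_sharded_weights`; real checkpoints are far larger.
    backbone = keras_hub.models.GemmaBackbone(
        vocabulary_size=1024,
        num_layers=12,
        num_query_heads=8,
        num_key_value_heads=4,
        hidden_dim=32,
        intermediate_dim=64,
        head_dim=4,
    )

    # `max_shard_size` is expressed in GB. When the serialized variables
    # exceed it, `save_to_preset` writes `model.weights.json` (the shard
    # index with its `weight_map`) plus `model_00000.weights.h5`,
    # `model_00001.weights.h5`, ... instead of a single `model.weights.h5`.
    backbone.save_to_preset("./tiny_gemma_preset", max_shard_size=0.0002)

    # Loading is transparent: the preset loader detects the sharded config,
    # fetches every shard listed in its `weight_map`, and restores the
    # variables. On Keras versions without sharding support, loading a
    # sharded preset raises a `RuntimeError`.
    restored = keras_hub.models.GemmaBackbone.from_preset("./tiny_gemma_preset")

The same `max_shard_size` argument is forwarded by `Task.save_to_preset`, so task
presets shard their backbone weights in the same way.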