diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index ac156a63..7d8370c9 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -68,10 +68,8 @@ __all__ = [ - "is_module_offloaded", "get_execution_device", "get_offloaded_device", - "update_prefix_dict", "update_parameter_data", "register_offload_parameter", "update_offload_parameter", @@ -116,11 +114,6 @@ def fallback_fn(*args, **kwargs): """ Candidates for Depreciation """ -@check_accelerate(fallback=False) -def is_module_offloaded(module: torch.nn.Module) -> bool: - return has_offloaded_params(module) - - def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ :param module: module to check @@ -133,25 +126,6 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device: return next(module.parameters()).device -@check_accelerate(fallback=None) -def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): - """ - Updates the offloaded state dict for a given module. Parameter named key is replaced - by data. This is neccesary because parameter updates for offloaded modules do not - persist automatically between loads. This function only affects the offloaded - state dict and not the current state of the loaded module. - - :param module: module containing the parameter to update - :param key: name of parameter to update - :param data: tensor to update parameter with in the offloaded state dict - """ - if not has_offloaded_params(module): - raise ValueError("Prefix dict is only applicable to offloaded modules") - - weights_map = module._hf_hook.weights_map - offload_to_weights_map(weights_map, key, data) - - def update_parameter_data( module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str ):