
Commit 091f891

Revert "Mixed Precision batchnorm fix (pytorch#77089)"
This reverts commit bf61b79. Reverted pytorch#77089 on behalf of https://github.com/suo
1 parent 9ba3b83 · commit 091f891

5 files changed: +4, -254 lines

5 files changed

+4
-254
lines changed

Diff for: test/distributed/fsdp/test_fsdp_mixed_precision.py (-106 lines)

@@ -8,7 +8,6 @@
 import torch
 import torch.cuda.nccl as nccl
 import torch.nn as nn
-import torch.nn.functional as F
 from torch import distributed as dist
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -17,8 +16,6 @@
     BackwardPrefetch,
     ShardingStrategy,
 )
-from torch.distributed.fsdp.wrap import default_auto_wrap_policy
-from torch.nn.modules.batchnorm import _BatchNorm
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,
@@ -29,18 +26,9 @@
     parametrize,
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
-    sandcastle_skip_if,
 )
 from torch.testing._internal.common_cuda import CUDA11OrLater

-try:
-    import torchvision
-    HAS_TORCHVISION = True
-except ImportError:
-    HAS_TORCHVISION = False
-
-skipIfNoTorchVision = sandcastle_skip_if(not HAS_TORCHVISION, "no torchvision")
-

 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
@@ -517,100 +505,6 @@ def test_mp_embedding_params_and_reduce_diff(self):
         )
         self._test_mixed_precision_embedding_table(mp_config=params_and_reduce_different)

-    @skip_if_lt_x_gpu(2)
-    @skipIfNoTorchVision
-    def test_mixed_precision_resnet(self):
-        """
-        End to end test to ensure mixed precision + auto_wrap works
-        for ResNet model.
-        """
-        resnet_model = torchvision.models.resnet50().cuda()
-        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
-            resnet_model,
-            process_group=dist.distributed_c10d._get_default_group()
-        )
-        n_bn = sum(1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules())
-        inp = torch.ones(1, 3, 1000, 1000, device='cuda')
-        mp_config = MixedPrecision(
-            param_dtype=torch.float16,
-            reduce_dtype=torch.float16,
-            buffer_dtype=torch.float16,
-        )
-        fsdp = FSDP(
-            resnet_model,
-            auto_wrap_policy=default_auto_wrap_policy,
-            mixed_precision=mp_config
-        )
-        # Batchnorm units should be wrapped individually. Validate this by
-        # ensuring there are equal no. of FSDP units that are BN as BN units
-        # in original resnet model.
-        fsdp_bn = 0
-        for module in fsdp.fsdp_modules(fsdp):
-            wrapped_module = module.module.module
-            if isinstance(wrapped_module, _BatchNorm):
-                fsdp_bn += 1
-
-        self.assertEqual(fsdp_bn, n_bn)
-        # Would throw type mismatch issue without mixed precision autowrapping.
-        loss = fsdp(inp).sum()
-        loss.backward()
-
-    @skip_if_lt_x_gpu(2)
-    @parametrize("convert_sync_bn", [True, False])
-    def test_mp_batchnorm(self, convert_sync_bn):
-        class BatchNormNet(nn.Module):
-            def __init__(self, affine=True):
-                super(BatchNormNet, self).__init__()
-                self.fc1 = nn.Linear(2, 40, bias=False)
-                self.bn = nn.BatchNorm1d(4, affine=affine)
-                self.fc2 = nn.Linear(40, 4, bias=False)
-
-            def forward(self, x):
-                x = torch.reshape(self.fc1(x), (-1, 4, 10))
-                x = self.bn(x)
-                x = torch.reshape(x, (-1, 40))
-                x = self.fc2(x)
-                return F.softmax(x, dim=1)
-
-        def never_wrap_policy(*args, **kwargs):
-            return False
-
-        net = BatchNormNet().cuda()
-        if convert_sync_bn:
-            net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
-        # FSDP detects that mixed precision + batchnorm will cause issues
-        # and thus wrap batchnorm in a distinct FSDP unit that does not
-        # use mixed precision.
-        mp_config = MixedPrecision(
-            param_dtype=torch.float16,
-            reduce_dtype=torch.float16,
-            buffer_dtype=torch.float16,
-        )
-        with self.assertWarnsRegex(
-            expected_warning=UserWarning,
-            expected_regex="BatchNorm units will be wrapped as a separate"
-        ):
-            model = FSDP(
-                net,
-                mixed_precision=mp_config,
-                auto_wrap_policy=never_wrap_policy,
-            )
-
-        bn = model.bn
-        self.assertTrue(isinstance(bn, FSDP))
-        # policy should not have wrapped any other submodules
-        self.assertFalse(isinstance(model.fc1, FSDP))
-        self.assertFalse(isinstance(model.fc2, FSDP))
-        self.assertEqual(None, bn.mixed_precision)
-        self.assertNotEqual(None, model.mixed_precision)
-
-        inp = torch.randn((1, 2), device='cuda')
-        # Without FSDP BN mixed precision fix, this would result in
-        # RuntimeError: Expected counts to have type Half but got Float
-        # for syncBN
-        model(inp).sum().backward()
-
-
 class TestFSDPMixedPrecisionUnsharded(TestFSDPMixedPrecision):
     """
     Smaller test suite for unshared param (i.e. world_size == 1) case.
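With the automatic handling reverted, the situation that the deleted test_mp_batchnorm covered is back in the user's hands: since several BatchNorm kernels (SyncBatchNorm in particular) do not implement reduced-precision support, a BatchNorm submodule should be kept out of the fp16 configuration by wrapping it in its own FSDP unit first. A minimal sketch of that manual workaround, assuming at least one GPU and an already-initialized default process group (e.g. launched via torchrun); TinyBatchNormNet is a hypothetical stand-in modeled on the deleted test fixture, not part of the codebase:

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision


class TinyBatchNormNet(nn.Module):
    # Hypothetical model used only for illustration.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 40, bias=False)
        self.bn = nn.BatchNorm1d(4)
        self.fc2 = nn.Linear(40, 4, bias=False)

    def forward(self, x):
        x = torch.reshape(self.fc1(x), (-1, 4, 10))
        x = torch.reshape(self.bn(x), (-1, 40))
        return self.fc2(x)


net = TinyBatchNormNet().cuda()
# Wrap the BatchNorm submodule on its own, leaving mixed_precision at its
# default of None so its kernels keep running in full precision.
net.bn = FSDP(net.bn)
# Wrap the rest of the model with fp16 mixed precision.
model = FSDP(
    net,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)
loss = model(torch.randn((1, 2), device="cuda")).sum()
loss.backward()

This mirrors the guidance in the docstring paragraph removed further down ("If individually wrapping the model, users must take care to set mixed_precision=None for BatchNorm units"); it is a sketch of one workaround, not the library's prescribed recipe.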

Diff for: test/distributed/fsdp/test_wrap.py (-67 lines)

@@ -17,9 +17,7 @@
     always_wrap_policy,
     size_based_auto_wrap_policy,
     enable_wrap,
-    _or_policy,
     wrap,
-    _wrap_batchnorm_individually,
     transformer_auto_wrap_policy,
 )
 from torch.testing._internal.common_distributed import (
@@ -42,15 +40,6 @@
 )
 from torch.nn import TransformerEncoderLayer, TransformerDecoderLayer

-class BatchNormNet(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.lin = nn.Linear(10, 10, bias=False)
-        self.bn1 = nn.BatchNorm1d(10)
-        self.bn2 = nn.BatchNorm2d(10)
-        self.bn3 = nn.BatchNorm3d(10)
-        self.sync_bn = nn.SyncBatchNorm(10)
-
 class WrapMethod(Enum):
     FSDP_CTOR = auto()
     # FSDP_CTOR is the supported way forward, but keep WRAP_API in case we miss
@@ -148,62 +137,6 @@ def test_error_already_wrapped(self, nested, fsdp_init_mode):
         with self.assertRaisesRegex(ValueError, "to NOT be FullyShardedDataParallel"):
             mod = FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)

-    @skip_if_lt_x_gpu(2)
-    @parametrize("use_or_policy", [True, False])
-    def test_wrap_batchnorm_individually(self, use_or_policy):
-        def never_wrap_policy(*args, **kwargs):
-            return False
-
-        policy = (
-            functools.partial(
-                _or_policy,
-                policies=[never_wrap_policy, _wrap_batchnorm_individually]
-            ) if use_or_policy else _wrap_batchnorm_individually
-        )
-        model = BatchNormNet()
-        fsdp = FSDP(model, auto_wrap_policy=policy)
-        # Batchnorms should be wrapped
-        for layer in [fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn]:
-            self.assertTrue(isinstance(layer, FSDP))
-
-        self.assertFalse(isinstance(fsdp.lin, FSDP))
-
-    @skip_if_lt_x_gpu(2)
-    def test_bn_always_wrapped_individually(self):
-        """
-        Ensures that by using _or_policy with _wrap_batchnorm_individually, even
-        if the other policy results in a module containing a BN unit being
-        wrapped, the contained BN unit will still be individually wrapped.
-        """
-        class MyModule(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.bn_container = BatchNormNet()
-
-        def wrap_bn_container(module, recurse, *args, **kwargs):
-            if recurse:
-                return True
-            return isinstance(module, BatchNormNet)
-
-        my_policy = functools.partial(
-            _or_policy,
-            policies=[wrap_bn_container, _wrap_batchnorm_individually]
-        )
-        mod = MyModule()
-        fsdp = FSDP(mod, auto_wrap_policy=my_policy)
-
-        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
-        # and not FSDP(FSDP(BatchNormNet(BN))) (in the latter the inner
-        # BN is not individually wrapped.)
-
-        for bn in [
-            fsdp.bn_container.bn1,
-            fsdp.bn_container.bn2,
-            fsdp.bn_container.bn3,
-            fsdp.bn_container.sync_bn
-        ]:
-            self.assertTrue(isinstance(bn, FSDP))
-
     @skip_if_lt_x_gpu(2)
     @parametrize(
         "cpu_offload",

Diff for: torch/distributed/fsdp/_utils.py (-11 lines)

@@ -2,21 +2,10 @@
 from typing import Any, Callable, Dict, List, Set, Tuple, Union

 import torch
-from torch.nn.modules.batchnorm import _BatchNorm
-
 from torch.nn.utils.rnn import PackedSequence

 """Useful functions to deal with tensor types with other python container types."""

-def _contains_batchnorm(module):
-    return any(
-        isinstance(mod, _BatchNorm) for mod in module.modules()
-    )
-
-def _override_batchnorm_mixed_precision(module):
-    for mod in module.modules():
-        if isinstance(mod, _BatchNorm):
-            mod._wrap_overrides = {"mixed_precision": None}  # type: ignore[assignment]

 def _apply_to_tensors(
     fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]

Diff for: torch/distributed/fsdp/fully_sharded_data_parallel.py (+3 -30 lines)

@@ -50,17 +50,14 @@
     _process_pos_dim_tensor_state,
     _unflatten_optim_state,
 )
-from ._utils import (
-    _apply_to_modules, _apply_to_tensors, _replace_by_prefix,
-    _override_batchnorm_mixed_precision, _contains_batchnorm
-)
+from ._utils import _apply_to_modules, _apply_to_tensors, _replace_by_prefix
 from .flatten_params_wrapper import (
     FLAT_PARAM,
     FPW_MODULE,
     FlatParameter,
     FlattenParamsWrapper,
 )
-from .wrap import _recursive_wrap, _wrap_batchnorm_individually, _or_policy
+from .wrap import _recursive_wrap

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
@@ -499,14 +496,6 @@ class FullyShardedDataParallel(nn.Module):
             that only floating point data is cast to the reduced precision. This allows
             users potential memory saving and training speedup while trading off
             accuracy during model training. If ``None``, no mixed precision is applied.
-            Note that if ``mixed_precision`` is enabled for FSDP model that
-            contains ``BatchNorm`` with ``auto_wrap_policy``, FSDP will take
-            care to disable mixed precision for ``BatchNorm`` units by wrapping
-            them separately in their own FSDP unit with ``mixed_precision=None``.
-            This is done because several ``BatchNorm`` kernels do not implement
-            reduced type support at the moment. If individually wrapping the model,
-            users must take care to set ``mixed_precision=None`` for
-            ``BatchNorm`` units.
             (Default: ``None``)
         ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
             own parameters and child modules' parameters and buffers are
@@ -591,25 +580,9 @@ def __init__(
                 check_fn=lambda mod: not isinstance(mod, FullyShardedDataParallel),
                 err_fn=lambda mod: f"Expected {mod} to NOT be FullyShardedDataParallel if auto_wrap is enabled.",
             )
-            if mixed_precision is not None and _contains_batchnorm(module):
-                _override_batchnorm_mixed_precision(module)
-                policy_to_use = functools.partial(
-                    _or_policy,
-                    policies=[_wrap_batchnorm_individually, auto_wrap_policy]
-                )
-                warnings.warn(
-                    "Mixed precision was specified for FSDP module with"
-                    " batchnorm submodules wrapped via ``auto_wrap_policy``."
-                    " BatchNorm units will be wrapped as a separate FSDP unit,"
-                    " with mixed_precision disabled (i.e. set to ``None``)"
-                    " as several BatchNorm kernels would raise errors when"
-                    " operating on reduced precision inputs."
-                )
-            else:
-                policy_to_use = auto_wrap_policy
             _recursive_wrap(
                 module,
-                auto_wrap_policy=policy_to_use,
+                auto_wrap_policy=auto_wrap_policy,
                 wrapper_cls=FullyShardedDataParallel,
                 ignored_modules=ignored_modules,
                 ignored_params=ignored_params,
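With the constructor logic above reverted, the exclusion that the removed docstring paragraph described has to happen before the model reaches FSDP. A hedged sketch of one way to do that for an arbitrary model with many BatchNorm layers (the helper name is made up for illustration; assumes an initialized default process group). Because the surviving context lines show that auto_wrap rejects modules that already contain FSDP instances, this pre-wrapping approach fits manual wrapping rather than auto_wrap_policy:

import warnings

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.modules.batchnorm import _BatchNorm


def pre_wrap_batchnorm(module: nn.Module) -> None:
    # Replace every BatchNorm descendant with its own FSDP unit; leaving
    # mixed_precision at its default (None) keeps those kernels in full precision.
    for name, child in module.named_children():
        if isinstance(child, _BatchNorm):
            warnings.warn(f"Wrapping {name} as a separate FSDP unit without mixed precision")
            setattr(module, name, FSDP(child))
        else:
            pre_wrap_batchnorm(child)


# Usage sketch (manual wrapping, no auto_wrap_policy):
#   pre_wrap_batchnorm(net)
#   model = FSDP(net, mixed_precision=mp_config)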

Diff for: torch/distributed/fsdp/wrap.py (+1 -40 lines)

@@ -17,7 +17,6 @@
 )

 import torch.nn as nn
-from torch.nn.modules.batchnorm import _BatchNorm


 def always_wrap_policy(*args, **kwargs) -> bool:
@@ -29,6 +28,7 @@ def always_wrap_policy(*args, **kwargs) -> bool:
     """
     return True

+
 def transformer_auto_wrap_policy(
     module: nn.Module,
     recurse: bool,
@@ -72,37 +72,6 @@ def transformer_auto_wrap_policy(
     # if not recursing, decide whether we should wrap for the leaf node or reminder
     return isinstance(module, tuple(transformer_layer_cls))

-def _wrap_batchnorm_individually(
-    module: nn.Module,
-    recurse: bool,
-    *args,
-    **kwargs,
-) -> bool:
-    """
-    A policy that wraps ``BatchNorm`` instances in their own FSDP unit.
-    """
-    if recurse:
-        # always recurse
-        return True
-    else:
-        # if not recursing, decide whether we should wrap based on whether it is a
-        # BN layer or not.
-        return isinstance(module, _BatchNorm)
-
-def _or_policy(
-    module: nn.Module,
-    recurse: bool,
-    unwrapped_params: int,
-    policies,
-) -> bool:
-    """
-    A policy that wraps ``module`` if any policy in the passed in iterable of
-    ``policies`` returns ``True``.
-    """
-    return any(
-        policy(module, recurse, unwrapped_params) for policy in policies
-    )
-

 def size_based_auto_wrap_policy(
     module: nn.Module,
@@ -241,14 +210,6 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:

 def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
     assert wrapper_cls is not None
-    if hasattr(module, '_wrap_overrides'):
-        # If module has a _wrap_overrides attribute, we force overriding the
-        # FSDP config with these attributes for this module. Currently this
-        # is only used to disable mixed precision for BatchNorm when
-        # auto_wrapping.
-        overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]
-        return wrapper_cls(module, **overrides)
-
     return wrapper_cls(module, **kwargs)
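The deleted _wrap_batchnorm_individually and _or_policy helpers were private, so an equivalent can be kept in user code if the per-BatchNorm wrapping granularity is still wanted after this revert. A rough sketch, assuming the policy callback signature (module, recurse, unwrapped_params) that _recursive_wrap uses at this point in the codebase; the names below are illustrative, not library API. Note that, because the _wrap_overrides hook in _wrap is also reverted above, BatchNorm units wrapped this way still inherit the parent's mixed_precision settings rather than having mixed precision disabled:

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.modules.batchnorm import _BatchNorm


def wrap_batchnorm_individually(module: nn.Module, recurse: bool, unwrapped_params: int) -> bool:
    # Always recurse into children; at a leaf, wrap only BatchNorm layers.
    if recurse:
        return True
    return isinstance(module, _BatchNorm)


def or_policy(module: nn.Module, recurse: bool, unwrapped_params: int, policies) -> bool:
    # Wrap if any of the given policies votes to wrap this module.
    return any(policy(module, recurse, unwrapped_params) for policy in policies)


def never_wrap_policy(*args, **kwargs) -> bool:
    return False


combined_policy = functools.partial(
    or_policy,
    policies=[never_wrap_policy, wrap_batchnorm_individually],
)
# Usage sketch with a hypothetical model:
#   fsdp_model = FSDP(my_model, auto_wrap_policy=combined_policy)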
