diff --git a/monai/networks/blocks/activation_checkpointing.py b/monai/networks/blocks/activation_checkpointing.py
new file mode 100644
index 0000000000..283bcd19e1
--- /dev/null
+++ b/monai/networks/blocks/activation_checkpointing.py
@@ -0,0 +1,41 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import cast
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+
+class ActivationCheckpointWrapper(nn.Module):
+    """Wrapper applying activation checkpointing to a module during training.
+
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
+
+    def __init__(self, module: nn.Module) -> None:
+        super().__init__()
+        self.module = module
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with activation checkpointing (recomputation occurs only when gradients are enabled).
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor from the wrapped module.
+        """
+        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index eac0ddab39..a4995ef701 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -17,11 +17,12 @@
 import torch
 import torch.nn as nn
 
+from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
 from monai.networks.blocks.convolutions import Convolution, ResidualUnit
 from monai.networks.layers.factories import Act, Norm
 from monai.networks.layers.simplelayers import SkipConnection
 
-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
 
 
 class UNet(nn.Module):
@@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class CheckpointUNet(UNet):
+    """UNet variant that wraps its internal connection blocks with activation checkpointing.
+
+    See `UNet` for constructor arguments. During training with gradients enabled,
+    intermediate activations inside the encoder-decoder connections are recomputed in
+    the backward pass, reducing peak memory usage at the cost of extra compute.
+    """
+
+    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
+        """Returns a connection block with activation checkpointing applied to all of its components.
+
+        Args:
+            down_path: encoding half of the layer (will be wrapped with checkpointing).
+            up_path: decoding half of the layer (will be wrapped with checkpointing).
+            subblock: block defining the next layer (will be wrapped with checkpointing).
+
+        Returns:
+            Connection block with all components wrapped for activation checkpointing.
+        """
+        subblock = ActivationCheckpointWrapper(subblock)
+        down_path = ActivationCheckpointWrapper(down_path)
+        up_path = ActivationCheckpointWrapper(up_path)
+        return super()._get_connection_block(down_path, up_path, subblock)
+
+
 Unet = UNet
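For context, a minimal usage sketch (not part of the diff): `CheckpointUNet` is constructed exactly like `UNet`, and the checkpointing only takes effect when gradients are enabled, so a caller's training loop is unchanged. The channel settings and input shape below are illustrative only.

```python
import torch

from monai.networks.nets.unet import CheckpointUNet

# drop-in replacement for UNet; constructor arguments are identical
net = CheckpointUNet(
    spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
)
net.train()

x = torch.randn(2, 1, 64, 64)
loss = net(x).mean()
# activations inside the encoder-decoder connections are recomputed here
# instead of being kept alive from the forward pass
loss.backward()
```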
+ """ + subblock = ActivationCheckpointWrapper(subblock) + down_path = ActivationCheckpointWrapper(down_path) + up_path = ActivationCheckpointWrapper(up_path) + return super()._get_connection_block(down_path, up_path, subblock) + + Unet = UNet diff --git a/tests/networks/nets/test_checkpointunet.py b/tests/networks/nets/test_checkpointunet.py new file mode 100644 index 0000000000..d1de61df75 --- /dev/null +++ b/tests/networks/nets/test_checkpointunet.py @@ -0,0 +1,190 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.layers import Act, Norm +from monai.networks.nets.unet import CheckpointUNet, UNet +from tests.test_utils import test_script_save + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASE_0 = [ # single channel 2D, batch 16, no residual + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 0, + }, + (16, 1, 32, 32), + (16, 3, 32, 32), +] + +TEST_CASE_1 = [ # single channel 2D, batch 16 + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 1, + }, + (16, 1, 32, 32), + (16, 3, 32, 32), +] + +TEST_CASE_2 = [ # single channel 3D, batch 16 + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 1, + }, + (16, 1, 32, 24, 48), + (16, 3, 32, 24, 48), +] + +TEST_CASE_3 = [ # 4-channel 3D, batch 16 + { + "spatial_dims": 3, + "in_channels": 4, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 1, + }, + (16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization + { + "spatial_dims": 3, + "in_channels": 4, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 1, + "norm": Norm.BATCH, + }, + (16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation + { + "spatial_dims": 3, + "in_channels": 4, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 1, + "act": (Act.LEAKYRELU, {"negative_slope": 0.2}), + "adn_ordering": "NA", + }, + (16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit + { + "spatial_dims": 3, + "in_channels": 4, + "out_channels": 3, + "channels": (16, 32, 64), + "strides": (2, 2), + "num_res_units": 1, + "act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}), + }, + (16, 4, 32, 64, 48), + (16, 3, 32, 64, 48), +] + +CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] + + +class TestCheckpointUNet(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape): + 
+        net = CheckpointUNet(**input_param).to(device)
+        with eval_mode(net):
+            result = net.forward(torch.randn(input_shape).to(device))
+            self.assertEqual(result.shape, expected_shape)
+
+    def test_script(self):
+        """
+        TorchScript does not support activation checkpointing (torch.utils.checkpoint) calls inside a module.
+        To keep the test suite validating TorchScript compatibility, this scripts the plain UNet (which is
+        scriptable) rather than the CheckpointUNet wrapper that uses checkpointing internals.
+        """
+        net = UNet(
+            spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
+        )
+        test_data = torch.randn(16, 1, 32, 32)
+        test_script_save(net, test_data)
+
+    def test_checkpointing_equivalence_eval(self):
+        """Ensure that CheckpointUNet matches a standard UNet in eval mode (checkpointing inactive)."""
+        params = dict(
+            spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
+        )
+
+        x = torch.randn(2, 1, 32, 32, device=device)
+
+        torch.manual_seed(42)
+        net_plain = UNet(**params).to(device)
+
+        torch.manual_seed(42)
+        net_ckpt = CheckpointUNet(**params).to(device)
+
+        # with both networks in eval mode, the checkpointing logic is inactive
+        with eval_mode(net_ckpt), eval_mode(net_plain):
+            y_ckpt = net_ckpt(x)
+            y_plain = net_plain(x)
+
+        # check shape equality
+        self.assertEqual(y_ckpt.shape, y_plain.shape)
+
+        # check numerical similarity
+        diff = torch.mean(torch.abs(y_ckpt - y_plain)).item()
+        self.assertLess(diff, 1e-3, f"Eval-mode outputs differ more than expected (mean abs diff={diff:.6f})")
+
+    def test_checkpointing_activates_training(self):
+        """Ensure the training-mode forward/backward pass works through checkpointed blocks and gradients propagate."""
+        params = dict(
+            spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
+        )
+
+        net = CheckpointUNet(**params).to(device)
+        net.train()
+
+        x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)
+        y = net(x)
+        loss = y.mean()
+        loss.backward()
+
+        # gradient flow check: a nonzero total gradient shows backward passed through the checkpointed blocks
+        grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
+        self.assertGreater(grad_norm.item(), 0.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
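To observe the intended memory/compute tradeoff directly, something like the following sketch can be run on a CUDA machine. This is illustrative only and not part of the test suite; the model size and input resolution are arbitrary, and the actual savings depend on network depth and spatial dimensions.

```python
import torch

from monai.networks.nets.unet import CheckpointUNet, UNet

params = dict(spatial_dims=2, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2))


def peak_bytes(net_cls) -> int:
    """Run one forward/backward pass and report peak CUDA memory."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    net = net_cls(**params).cuda().train()
    net(torch.randn(4, 1, 256, 256, device="cuda")).mean().backward()
    return torch.cuda.max_memory_allocated()


if torch.cuda.is_available():
    print(f"UNet peak memory:           {peak_bytes(UNet) / 2**20:.1f} MiB")
    print(f"CheckpointUNet peak memory: {peak_bytes(CheckpointUNet) / 2**20:.1f} MiB")
```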