Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions monai/networks/blocks/activation_checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ActivationCheckpointWrapper(nn.Module):
"""Wrapper applying activation checkpointing to a module during training.

Args:
module: The module to wrap with activation checkpointing.
"""

def __init__(self, module: nn.Module) -> None:
super().__init__()
self.module = module

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass with optional activation checkpointing.

Args:
x: Input tensor.

Returns:
Output tensor from the wrapped module.
"""
return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
Comment on lines +32 to +41
Copy link
Contributor

@coderabbitai coderabbitai bot Nov 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Gate checkpointing to active training passes.
The docstring promises training-only checkpointing, but forward always recomputes, so eval/no-grad still pays the checkpoint dispatch. Wrap the call with self.training, torch.is_grad_enabled(), and an x.requires_grad check, falling back to the plain module call otherwise, to avoid needless recompute overhead while preserving the memory trade-off during training.(docs.pytorch.org)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass with optional activation checkpointing.
 
         Args:
             x: Input tensor.
 
         Returns:
             Output tensor from the wrapped module.
         """
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
🤖 Prompt for AI Agents
In monai/networks/blocks/activation_checkpointing.py around lines 32 to 41,
forward always calls checkpoint(self.module, x, use_reentrant=False) even during
eval/no-grad, causing unnecessary recompute; change it to only use
torch.utils.checkpoint when running training and gradients are enabled: check
self.training and torch.is_grad_enabled() and that input tensor x.requires_grad
before calling checkpoint(..., use_reentrant=False); otherwise call and return
self.module(x) directly to avoid unnecessary checkpoint overhead while
preserving training memory savings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think just change the docstring here instead of the code change.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericspod Understood. Update the docstrings to reflect unconditional checkpointing:

Class docstring:

"""Wrapper applying activation checkpointing to a module.

Checkpointing trades compute for memory by recomputing activations during 
the backward pass instead of storing them. Applied during both training and eval.

Args:
    module: The module to wrap with activation checkpointing.
"""

Forward method docstring:

"""Forward pass with activation checkpointing.

Args:
    x: Input tensor.

Returns:
    Output tensor from the wrapped module.
"""

Key changes: removed "during training" and "optional" language to match the unconditional behavior.

28 changes: 27 additions & 1 deletion monai/networks/nets/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import torch
import torch.nn as nn

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection

__all__ = ["UNet", "Unet"]
__all__ = ["UNet", "Unet", "CheckpointUNet"]


class UNet(nn.Module):
Expand Down Expand Up @@ -298,4 +299,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class CheckpointUNet(UNet):
"""UNet variant that wraps internal connection blocks with activation checkpointing.

See `UNet` for constructor arguments. During training with gradients enabled,
intermediate activations inside encoder-decoder connections are recomputed in
the backward pass to reduce peak memory usage at the cost of extra compute.
"""

def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
"""Returns connection block with activation checkpointing applied to all components.

Args:
down_path: encoding half of the layer (will be wrapped with checkpointing).
up_path: decoding half of the layer (will be wrapped with checkpointing).
subblock: block defining the next layer (will be wrapped with checkpointing).

Returns:
Connection block with all components wrapped for activation checkpointing.
"""
subblock = ActivationCheckpointWrapper(subblock)
down_path = ActivationCheckpointWrapper(down_path)
up_path = ActivationCheckpointWrapper(up_path)
return super()._get_connection_block(down_path, up_path, subblock)


Unet = UNet
190 changes: 190 additions & 0 deletions tests/networks/nets/test_checkpointunet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.layers import Act, Norm
from monai.networks.nets.unet import CheckpointUNet, UNet
from tests.test_utils import test_script_save

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_0 = [ # single channel 2D, batch 16, no residual
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 0,
},
(16, 1, 32, 32),
(16, 3, 32, 32),
]

TEST_CASE_1 = [ # single channel 2D, batch 16
{
"spatial_dims": 2,
"in_channels": 1,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
},
(16, 1, 32, 32),
(16, 3, 32, 32),
]

TEST_CASE_2 = [ # single channel 3D, batch 16
{
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
},
(16, 1, 32, 24, 48),
(16, 3, 32, 24, 48),
]

TEST_CASE_3 = [ # 4-channel 3D, batch 16
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
"norm": Norm.BATCH,
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
"act": (Act.LEAKYRELU, {"negative_slope": 0.2}),
"adn_ordering": "NA",
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit
{
"spatial_dims": 3,
"in_channels": 4,
"out_channels": 3,
"channels": (16, 32, 64),
"strides": (2, 2),
"num_res_units": 1,
"act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}),
},
(16, 4, 32, 64, 48),
(16, 3, 32, 64, 48),
]

CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]


class TestCheckpointUNet(unittest.TestCase):
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shape):
net = CheckpointUNet(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

def test_script(self):
"""
TorchScript doesn't support activation-checkpointing (torch.utils.checkpoint) calls inside the module.
To keep the test suite validating TorchScript compatibility, script the plain UNet (which is scriptable),
rather than the CheckpointUNet wrapper that uses checkpointing internals.
"""
net = UNet(
spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
)
test_data = torch.randn(16, 1, 32, 32)
test_script_save(net, test_data)

def test_checkpointing_equivalence_eval(self):
"""Ensure that CheckpointUNet matches standard UNet in eval mode (checkpointing inactive)."""
params = dict(
spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
)

x = torch.randn(2, 1, 32, 32, device=device)

torch.manual_seed(42)
net_plain = UNet(**params).to(device)

torch.manual_seed(42)
net_ckpt = CheckpointUNet(**params).to(device)

# Both in eval mode disables checkpointing logic
with eval_mode(net_ckpt), eval_mode(net_plain):
y_ckpt = net_ckpt(x)
y_plain = net_plain(x)

# Check shape equality
self.assertEqual(y_ckpt.shape, y_plain.shape)

# Check numerical similarity
diff = torch.mean(torch.abs(y_ckpt - y_plain)).item()
self.assertLess(diff, 1e-3, f"Eval-mode outputs differ more than expected (mean abs diff={diff:.6f})")

def test_checkpointing_activates_training(self):
"""Ensure checkpointing triggers recomputation under training and gradients propagate."""
params = dict(
spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
)

net = CheckpointUNet(**params).to(device)
net.train()

x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True)
y = net(x)
loss = y.mean()
loss.backward()

# gradient flow check
grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
self.assertGreater(grad_norm.item(), 0.0)


if __name__ == "__main__":
unittest.main()
Loading