-
Notifications
You must be signed in to change notification settings - Fork 1.3k
feat: add activation checkpointing to unet #8554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ferreirafabio80
wants to merge
19
commits into
Project-MONAI:dev
Choose a base branch
from
ferreirafabio80:feat/add_activation_checkpointing_to_unet
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+258
−1
Open
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
de2b6bd
feat: add optional gradient checkpointing to unet
66edcb5
fix: small ruff issue
e66e357
Update monai/networks/nets/unet.py
ferreirafabio80 feefcaa
docs: update docstrings
e112457
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f673ca1
fix: avoid BatchNorm subblocks
69540ff
fix: revert batch norm changes
42ec757
refactor: creates a subclass of UNet and overrides the get connection…
a2e8474
chore: remove use checkpointing from doc string
4c4782e
fix: linting issues
515c659
feat: add activation checkpointing to down and up paths to be more ef…
da5a3a4
refactor: move activation checkpointing wrapper to blocks
43dec88
chore: add docstrings to checkpointed unet
84c0f48
test: add checkpoint unet test
5805515
fix: change test name
1aa8e3c
fix: simplify test and make sure that checkpoint unet runs well in tr…
447d9f2
fix: set seed
b20a19e
fix: fix testing bugs
41f000f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import cast | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from torch.utils.checkpoint import checkpoint | ||
|
|
||
|
|
||
| class ActivationCheckpointWrapper(nn.Module): | ||
| """Wrapper applying activation checkpointing to a module during training. | ||
|
|
||
| Args: | ||
| module: The module to wrap with activation checkpointing. | ||
| """ | ||
|
|
||
| def __init__(self, module: nn.Module) -> None: | ||
| super().__init__() | ||
| self.module = module | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| """Forward pass with optional activation checkpointing. | ||
|
|
||
| Args: | ||
| x: Input tensor. | ||
|
|
||
| Returns: | ||
| Output tensor from the wrapped module. | ||
| """ | ||
| return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from parameterized import parameterized | ||
|
|
||
| from monai.networks import eval_mode | ||
| from monai.networks.layers import Act, Norm | ||
| from monai.networks.nets.unet import CheckpointUNet, UNet | ||
| from tests.test_utils import test_script_save | ||
|
|
||
| device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
|
||
| TEST_CASE_0 = [ # single channel 2D, batch 16, no residual | ||
| { | ||
| "spatial_dims": 2, | ||
| "in_channels": 1, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 0, | ||
| }, | ||
| (16, 1, 32, 32), | ||
| (16, 3, 32, 32), | ||
| ] | ||
|
|
||
| TEST_CASE_1 = [ # single channel 2D, batch 16 | ||
| { | ||
| "spatial_dims": 2, | ||
| "in_channels": 1, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 1, | ||
| }, | ||
| (16, 1, 32, 32), | ||
| (16, 3, 32, 32), | ||
| ] | ||
|
|
||
| TEST_CASE_2 = [ # single channel 3D, batch 16 | ||
| { | ||
| "spatial_dims": 3, | ||
| "in_channels": 1, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 1, | ||
| }, | ||
| (16, 1, 32, 24, 48), | ||
| (16, 3, 32, 24, 48), | ||
| ] | ||
|
|
||
| TEST_CASE_3 = [ # 4-channel 3D, batch 16 | ||
| { | ||
| "spatial_dims": 3, | ||
| "in_channels": 4, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 1, | ||
| }, | ||
| (16, 4, 32, 64, 48), | ||
| (16, 3, 32, 64, 48), | ||
| ] | ||
|
|
||
| TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization | ||
| { | ||
| "spatial_dims": 3, | ||
| "in_channels": 4, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 1, | ||
| "norm": Norm.BATCH, | ||
| }, | ||
| (16, 4, 32, 64, 48), | ||
| (16, 3, 32, 64, 48), | ||
| ] | ||
|
|
||
| TEST_CASE_5 = [ # 4-channel 3D, batch 16, LeakyReLU activation | ||
| { | ||
| "spatial_dims": 3, | ||
| "in_channels": 4, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 1, | ||
| "act": (Act.LEAKYRELU, {"negative_slope": 0.2}), | ||
| "adn_ordering": "NA", | ||
| }, | ||
| (16, 4, 32, 64, 48), | ||
| (16, 3, 32, 64, 48), | ||
| ] | ||
|
|
||
| TEST_CASE_6 = [ # 4-channel 3D, batch 16, LeakyReLU activation explicit | ||
| { | ||
| "spatial_dims": 3, | ||
| "in_channels": 4, | ||
| "out_channels": 3, | ||
| "channels": (16, 32, 64), | ||
| "strides": (2, 2), | ||
| "num_res_units": 1, | ||
| "act": (torch.nn.LeakyReLU, {"negative_slope": 0.2}), | ||
| }, | ||
| (16, 4, 32, 64, 48), | ||
| (16, 3, 32, 64, 48), | ||
| ] | ||
|
|
||
| CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6] | ||
|
|
||
|
|
||
| class TestCheckpointUNet(unittest.TestCase): | ||
| @parameterized.expand(CASES) | ||
| def test_shape(self, input_param, input_shape, expected_shape): | ||
| net = CheckpointUNet(**input_param).to(device) | ||
| with eval_mode(net): | ||
| result = net.forward(torch.randn(input_shape).to(device)) | ||
| self.assertEqual(result.shape, expected_shape) | ||
|
|
||
| def test_script(self): | ||
| """ | ||
| TorchScript doesn't support activation-checkpointing (torch.utils.checkpoint) calls inside the module. | ||
| To keep the test suite validating TorchScript compatibility, script the plain UNet (which is scriptable), | ||
| rather than the CheckpointUNet wrapper that uses checkpointing internals. | ||
| """ | ||
| net = UNet( | ||
| spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0 | ||
| ) | ||
| test_data = torch.randn(16, 1, 32, 32) | ||
| test_script_save(net, test_data) | ||
|
|
||
| def test_checkpointing_equivalence_eval(self): | ||
| """Ensure that CheckpointUNet matches standard UNet in eval mode (checkpointing inactive).""" | ||
| params = dict( | ||
| spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1 | ||
| ) | ||
|
|
||
| x = torch.randn(2, 1, 32, 32, device=device) | ||
|
|
||
| torch.manual_seed(42) | ||
| net_plain = UNet(**params).to(device) | ||
|
|
||
| torch.manual_seed(42) | ||
| net_ckpt = CheckpointUNet(**params).to(device) | ||
|
|
||
| # Both in eval mode disables checkpointing logic | ||
| with eval_mode(net_ckpt), eval_mode(net_plain): | ||
| y_ckpt = net_ckpt(x) | ||
| y_plain = net_plain(x) | ||
|
|
||
| # Check shape equality | ||
| self.assertEqual(y_ckpt.shape, y_plain.shape) | ||
|
|
||
| # Check numerical similarity | ||
| diff = torch.mean(torch.abs(y_ckpt - y_plain)).item() | ||
| self.assertLess(diff, 1e-3, f"Eval-mode outputs differ more than expected (mean abs diff={diff:.6f})") | ||
|
|
||
| def test_checkpointing_activates_training(self): | ||
| """Ensure checkpointing triggers recomputation under training and gradients propagate.""" | ||
| params = dict( | ||
| spatial_dims=2, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2), num_res_units=1 | ||
| ) | ||
|
|
||
| net = CheckpointUNet(**params).to(device) | ||
| net.train() | ||
|
|
||
| x = torch.randn(2, 1, 32, 32, device=device, requires_grad=True) | ||
| y = net(x) | ||
| loss = y.mean() | ||
| loss.backward() | ||
|
|
||
| # gradient flow check | ||
| grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None) | ||
| self.assertGreater(grad_norm.item(), 0.0) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gate checkpointing to active training passes.
The docstring promises training-only checkpointing, but
forwardalways recomputes, so eval/no-grad still pays the checkpoint dispatch. Wrap the call withself.training,torch.is_grad_enabled(), and anx.requires_gradcheck, falling back to the plain module call otherwise, to avoid needless recompute overhead while preserving the memory trade-off during training.(docs.pytorch.org)def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass with optional activation checkpointing. Args: x: Input tensor. Returns: Output tensor from the wrapped module. """ - return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + if self.training and torch.is_grad_enabled() and x.requires_grad: + return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False)) + return cast(torch.Tensor, self.module(x))🤖 Prompt for AI Agents
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think just change the docstring here instead of the code change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ericspod Understood. Update the docstrings to reflect unconditional checkpointing:
Class docstring:
Forward method docstring:
Key changes: removed "during training" and "optional" language to match the unconditional behavior.