36 changes: 31 additions & 5 deletions torch_sim/autobatching.py
@@ -253,6 +253,7 @@ def determine_max_batch_size(
max_atoms: int = 500_000,
start_size: int = 1,
scale_factor: float = 1.6,
+ oom_error_message: str | list[str] = "CUDA out of memory",
) -> int:
"""Determine maximum batch size that fits in GPU memory.

@@ -269,12 +270,16 @@ def determine_max_batch_size(
start_size (int): Initial batch size to test. Defaults to 1.
scale_factor (float): Factor to multiply batch size by in each iteration.
Defaults to 1.6.
+ oom_error_message (str | list[str]): String or list of strings to match
+     against raised error messages to identify out-of-memory errors.
+     Defaults to "CUDA out of memory".

Returns:
int: Maximum number of batches that fit in GPU memory.

Raises:
- RuntimeError: If any error other than CUDA out of memory occurs during testing.
+ Exception: Any raised error whose message does not match one of the
+     specified OOM error messages is re-raised unchanged.

Example::

@@ -287,6 +292,10 @@ def determine_max_batch_size(
The function returns a batch size slightly smaller than the actual maximum
(with a safety margin) to avoid operating too close to memory limits.
"""
+ # Convert oom_error_message to list if it's a string
+ if isinstance(oom_error_message, str):
+     oom_error_message = [oom_error_message]
+
# Create a geometric sequence of batch sizes
sizes = [start_size]
while (
@@ -300,10 +309,14 @@ def determine_max_batch_size(

try:
measure_model_memory_forward(concat_state, model)
- except RuntimeError as exc:
-     if "CUDA out of memory" in str(exc):
-         # Return the last successful size, with a safety margin
-         return sizes[max(0, sys_idx - 2)]
+ except Exception as exc:
+     exc_str = str(exc)
+     # Check if any of the OOM error messages match
+     for msg in oom_error_message:
+         if msg in exc_str:
+             return sizes[max(0, sys_idx - 2)]
+
+     # No OOM message matched - re-raise the error
raise

return sizes[-1]
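
To make the new control flow concrete, here is a minimal standalone sketch of what this hunk does: candidate batch sizes grow geometrically, any raised error message is compared against one or several OOM patterns, a match backs off two steps as a safety margin, and everything else is re-raised. The `try_batch` callable, the ladder formula, and the concrete numbers are illustrative assumptions rather than the real `measure_model_memory_forward` machinery.

```python
from collections.abc import Callable


def max_batch_size_sketch(
    try_batch: Callable[[int], None],  # raises when a batch of `size` does not fit
    max_size: int = 64,
    start_size: int = 1,
    scale_factor: float = 1.6,
    oom_error_message: str | list[str] = "CUDA out of memory",
) -> int:
    """Standalone sketch of the logic above; not the real torch_sim function."""
    # Normalize a single pattern to a list, as in the hunk above
    if isinstance(oom_error_message, str):
        oom_error_message = [oom_error_message]

    # Geometric ladder of candidate batch sizes (formula is illustrative)
    sizes = [start_size]
    while sizes[-1] * scale_factor < max_size:
        sizes.append(max(sizes[-1] + 1, round(sizes[-1] * scale_factor)))

    for sys_idx, size in enumerate(sizes):
        try:
            try_batch(size)
        except Exception as exc:
            exc_str = str(exc)
            if any(msg in exc_str for msg in oom_error_message):
                # Treated as OOM: back off two steps as a safety margin
                return sizes[max(0, sys_idx - 2)]
            raise  # not an OOM error, so propagate it
    return sizes[-1]


def fake_forward(size: int) -> None:
    """Fake backend that reports a ROCm-style OOM above 20 systems."""
    if size > 20:
        raise RuntimeError("HIP out of memory. Tried to allocate 2.00 GiB")


print(
    max_batch_size_sketch(
        fake_forward,
        oom_error_message=["CUDA out of memory", "HIP out of memory"],
    )
)
```

Running the sketch prints 8: the ladder reaches 21, the fake backend raises a HIP-style OOM that matches the second pattern, and the function falls back two steps on the ladder.
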
@@ -469,6 +482,7 @@ def __init__(
max_atoms_to_try: int = 500_000,
memory_scaling_factor: float = 1.6,
max_memory_padding: float = 1.0,
+ oom_error_message: str | list[str] = "CUDA out of memory",
) -> None:
"""Initialize the binning auto-batcher.

@@ -490,13 +504,17 @@ def __init__(
to 1.6.
max_memory_padding (float): Multiply the auto-determined max_memory_scaler
by this value to account for fluctuations in max memory. Defaults to 1.0.
+ oom_error_message (str | list[str]): String or list of strings to match
+     against raised error messages to identify out-of-memory errors.
+     Defaults to "CUDA out of memory".
"""
self.max_memory_scaler = max_memory_scaler
self.max_atoms_to_try = max_atoms_to_try
self.memory_scales_with = memory_scales_with
self.model = model
self.memory_scaling_factor = memory_scaling_factor
self.max_memory_padding = max_memory_padding
+ self.oom_error_message = oom_error_message

def load_states(self, states: T | Sequence[T]) -> float:
"""Load new states into the batcher.
@@ -542,6 +560,7 @@ def load_states(self, states: T | Sequence[T]) -> float:
self.memory_scalers,
max_atoms=self.max_atoms_to_try,
scale_factor=self.memory_scaling_factor,
+ oom_error_message=self.oom_error_message,
)
self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding

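For reference, a hypothetical construction sketch of the binning batcher with the new keyword: the class name `BinningAutoBatcher` is inferred from the docstring above, `model` and `states` stand for a torch_sim model wrapper and a list of states, and only `oom_error_message` itself is introduced by this diff.

```python
# Hypothetical usage sketch; names other than oom_error_message are
# assumptions based on the attributes set in __init__ above.
from torch_sim.autobatching import BinningAutoBatcher


def build_batcher(model, states):
    batcher = BinningAutoBatcher(
        model=model,
        oom_error_message=["CUDA out of memory", "HIP out of memory"],
    )
    # load_states now forwards oom_error_message to determine_max_batch_size
    max_memory_scaler = batcher.load_states(states)
    return batcher, max_memory_scaler
```
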
@@ -744,6 +763,7 @@ def __init__(
memory_scaling_factor: float = 1.6,
max_iterations: int | None = None,
max_memory_padding: float = 1.0,
+ oom_error_message: str | list[str] = "CUDA out of memory",
) -> None:
"""Initialize the hot-swapping auto-batcher.

@@ -768,6 +788,9 @@ def __init__(
infinite loops. Defaults to None (no limit).
max_memory_padding (float): Multiply the auto-determined max_memory_scaler
by this value to account for fluctuations in max memory. Defaults to 1.0.
+ oom_error_message (str | list[str]): String or list of strings to match
+     against raised error messages to identify out-of-memory errors.
+     Defaults to "CUDA out of memory".
"""
self.model = model
self.memory_scales_with = memory_scales_with
@@ -776,6 +799,7 @@ def __init__(
self.memory_scaling_factor = memory_scaling_factor
self.max_iterations = max_iterations
self.max_memory_padding = max_memory_padding
+ self.oom_error_message = oom_error_message

def load_states(self, states: Sequence[T] | Iterator[T] | T) -> None:
"""Load new states into the batcher.
@@ -911,6 +935,7 @@ def _get_first_batch(self) -> T:
self.model,
max_atoms=self.max_atoms_to_try,
scale_factor=self.memory_scaling_factor,
+ oom_error_message=self.oom_error_message,
)
self.max_memory_scaler = n_systems * first_metric * 0.8

@@ -923,6 +948,7 @@ def _get_first_batch(self) -> T:
self.current_scalers,
max_atoms=self.max_atoms_to_try,
scale_factor=self.memory_scaling_factor,
+ oom_error_message=self.oom_error_message,
)
self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding
newer_states = self._get_next_states()
1 change: 1 addition & 0 deletions torch_sim/runners.py
@@ -449,6 +449,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915
max_memory_scaler=autobatcher.max_memory_scaler,
memory_scales_with=autobatcher.memory_scales_with,
max_atoms_to_try=autobatcher.max_atoms_to_try,
+ oom_error_message=autobatcher.oom_error_message,
)
autobatcher.load_states(state)
if trajectory_reporter is not None:
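
With this one-line change, an `oom_error_message` configured on a user-supplied autobatcher survives the internal re-construction that `optimize` performs. Below is a hedged end-to-end sketch: apart from `autobatcher` and `oom_error_message`, the argument and optimizer names are assumptions not confirmed by this diff, and the listed message fragments are the strings commonly reported by PyTorch's CUDA, ROCm, and MPS allocators.

```python
# Hypothetical end-to-end usage; verify the optimize() signature and the OOM
# message fragments against your torch_sim / PyTorch versions.
import torch_sim as ts
from torch_sim.autobatching import BinningAutoBatcher

OOM_PATTERNS = [
    "CUDA out of memory",          # CUDA caching allocator
    "HIP out of memory",           # ROCm builds of PyTorch
    "MPS backend out of memory",   # Apple MPS backend
]


def relax_all(states, model):
    batcher = BinningAutoBatcher(model=model, oom_error_message=OOM_PATTERNS)
    # The runner rebuilds its internal batcher from these attributes and, with
    # this change, now carries oom_error_message across as well.
    return ts.optimize(
        system=states,
        model=model,
        optimizer=ts.unit_cell_fire,  # assumed optimizer; any torch_sim optimizer should work
        autobatcher=batcher,
    )
```
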