diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py
index a3c6680f..65bbbe12 100644
--- a/torch_sim/autobatching.py
+++ b/torch_sim/autobatching.py
@@ -253,6 +253,7 @@ def determine_max_batch_size(
     max_atoms: int = 500_000,
     start_size: int = 1,
     scale_factor: float = 1.6,
+    oom_error_message: str | list[str] = "CUDA out of memory",
 ) -> int:
     """Determine maximum batch size that fits in GPU memory.
 
@@ -269,12 +270,16 @@ def determine_max_batch_size(
         start_size (int): Initial batch size to test. Defaults to 1.
         scale_factor (float): Factor to multiply batch size by in each iteration.
             Defaults to 1.6.
+        oom_error_message (str | list[str]): String or list of strings to match in
+            RuntimeError messages to identify out-of-memory errors. Defaults to
+            "CUDA out of memory".
 
     Returns:
         int: Maximum number of batches that fit in GPU memory.
 
     Raises:
-        RuntimeError: If any error other than CUDA out of memory occurs during testing.
+        RuntimeError: If a RuntimeError occurs that doesn't match any of the
+            specified OOM error messages.
 
     Example::
 
@@ -287,6 +292,10 @@ def determine_max_batch_size(
     The function returns a batch size slightly smaller than the actual maximum
     (with a safety margin) to avoid operating too close to memory limits.
     """
+    # Convert oom_error_message to list if it's a string
+    if isinstance(oom_error_message, str):
+        oom_error_message = [oom_error_message]
+
     # Create a geometric sequence of batch sizes
     sizes = [start_size]
     while (
@@ -300,10 +309,14 @@ def determine_max_batch_size(
 
         try:
             measure_model_memory_forward(concat_state, model)
-        except RuntimeError as exc:
-            if "CUDA out of memory" in str(exc):
-                # Return the last successful size, with a safety margin
-                return sizes[max(0, sys_idx - 2)]
+        except Exception as exc:
+            exc_str = str(exc)
+            # Check if any of the OOM error messages match
+            for msg in oom_error_message:
+                if msg in exc_str:
+                    return sizes[max(0, sys_idx - 2)]
+
+            # No OOM message matched - re-raise the error
             raise
 
     return sizes[-1]
@@ -469,6 +482,7 @@ def __init__(
         max_atoms_to_try: int = 500_000,
         memory_scaling_factor: float = 1.6,
         max_memory_padding: float = 1.0,
+        oom_error_message: str | list[str] = "CUDA out of memory",
     ) -> None:
         """Initialize the binning auto-batcher.
 
@@ -490,6 +504,9 @@ def __init__(
                 to 1.6.
             max_memory_padding (float): Multiply the auto-determined max_memory_scaler
                 by this value to account for fluctuations in max memory. Defaults to 1.0.
+            oom_error_message (str | list[str]): String or list of strings to match in
+                RuntimeError messages to identify out-of-memory errors. Defaults to
+                "CUDA out of memory".
         """
         self.max_memory_scaler = max_memory_scaler
         self.max_atoms_to_try = max_atoms_to_try
@@ -497,6 +514,7 @@ def __init__(
         self.model = model
         self.memory_scaling_factor = memory_scaling_factor
         self.max_memory_padding = max_memory_padding
+        self.oom_error_message = oom_error_message
 
     def load_states(self, states: T | Sequence[T]) -> float:
         """Load new states into the batcher.
@@ -542,6 +560,7 @@ def load_states(self, states: T | Sequence[T]) -> float:
             self.memory_scalers,
             max_atoms=self.max_atoms_to_try,
             scale_factor=self.memory_scaling_factor,
+            oom_error_message=self.oom_error_message,
         )
         self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding
 
@@ -744,6 +763,7 @@ def __init__(
         memory_scaling_factor: float = 1.6,
         max_iterations: int | None = None,
         max_memory_padding: float = 1.0,
+        oom_error_message: str | list[str] = "CUDA out of memory",
     ) -> None:
         """Initialize the hot-swapping auto-batcher.
 
@@ -768,6 +788,9 @@ def __init__(
                 infinite loops. Defaults to None (no limit).
             max_memory_padding (float): Multiply the auto-determined max_memory_scaler
                 by this value to account for fluctuations in max memory. Defaults to 1.0.
+            oom_error_message (str | list[str]): String or list of strings to match in
+                RuntimeError messages to identify out-of-memory errors. Defaults to
+                "CUDA out of memory".
         """
         self.model = model
         self.memory_scales_with = memory_scales_with
@@ -776,6 +799,7 @@ def __init__(
         self.memory_scaling_factor = memory_scaling_factor
         self.max_iterations = max_iterations
         self.max_memory_padding = max_memory_padding
+        self.oom_error_message = oom_error_message
 
     def load_states(self, states: Sequence[T] | Iterator[T] | T) -> None:
         """Load new states into the batcher.
@@ -911,6 +935,7 @@ def _get_first_batch(self) -> T:
             self.model,
             max_atoms=self.max_atoms_to_try,
             scale_factor=self.memory_scaling_factor,
+            oom_error_message=self.oom_error_message,
         )
         self.max_memory_scaler = n_systems * first_metric * 0.8
 
@@ -923,6 +948,7 @@ def _get_first_batch(self) -> T:
             self.current_scalers,
             max_atoms=self.max_atoms_to_try,
             scale_factor=self.memory_scaling_factor,
+            oom_error_message=self.oom_error_message,
         )
         self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding
         newer_states = self._get_next_states()
diff --git a/torch_sim/runners.py b/torch_sim/runners.py
index 4ac6f3be..1cbbbde8 100644
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ -449,6 +449,7 @@ def optimize[T: OptimState]( # noqa: C901, PLR0915
             max_memory_scaler=autobatcher.max_memory_scaler,
             memory_scales_with=autobatcher.memory_scales_with,
             max_atoms_to_try=autobatcher.max_atoms_to_try,
+            oom_error_message=autobatcher.oom_error_message,
         )
         autobatcher.load_states(state)
         if trajectory_reporter is not None:
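Usage note: the new `oom_error_message` parameter accepts either a single substring or a list of substrings to match against the raised error text, so backends whose out-of-memory errors are not worded "CUDA out of memory" can still be caught; any non-matching error is re-raised unchanged. The snippet below is a minimal sketch, not part of this diff: it assumes the binning batcher is exposed as `BinningAutoBatcher` in `torch_sim.autobatching`, uses placeholder `model` and `states` objects, and the "HIP out of memory" pattern is only an illustrative example of a second OOM message.

```python
from torch_sim.autobatching import BinningAutoBatcher

# `model` and `states` are placeholders for a model implementation and the
# simulation states to be batched.
batcher = BinningAutoBatcher(
    model=model,
    max_atoms_to_try=500_000,
    # New in this diff: a single string or a list of substrings identifying an
    # out-of-memory RuntimeError; "HIP out of memory" is an illustrative extra.
    oom_error_message=["CUDA out of memory", "HIP out of memory"],
)
batcher.load_states(states)
```

The same keyword is threaded through to `determine_max_batch_size`, which checks the caught exception text against every entry and only backs off the batch size when one of them matches.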