
Conversation

@cw-tan (Contributor) commented Nov 5, 2025

Motivation

AOT Inductor models won't trigger OOM runtime errors with the same message, since the OOM is thrown from C++. The caught runtime error ends up looking like

RuntimeError: run_func_( container_handle_, input_handles.data(), input_handles.size(), output_handles.data(), output_handles.size(), reinterpret_cast<AOTInductorStreamHandle>(stream_handle), proxy_executor_handle_) API call failed at /pytorch/torch/csrc/inductor/aoti_runner/model_container_runner.cpp, line 145

So, in general, the check

if "CUDA out of memory" in str(exc):

might be too narrow a criterion.

Proposed Change

The proposed solution is to add a treat_runtime_error_as_oom argument that must be explicitly set when instantiating autobatchers and is eventually passed through to the determine_max_batch_size function. This doesn't change existing behavior, but it provides a knob to treat every runtime error as an OOM, covering future cases where models raise runtime errors that don't explicitly contain the "OOM" message. This should be safe, since a user has to configure this behavior very intentionally and is expected to know whether the error truly is an OOM before doing so.
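Roughly, the knob would gate the OOM check inside the batch-size search loop. A minimal sketch of the idea, where try_forward and candidate_sizes are illustrative stand-ins rather than the actual torch_sim internals:

def max_batch_size_sketch(
    try_forward,
    candidate_sizes,
    *,
    treat_runtime_error_as_oom: bool = False,
    oom_error_message: str = "CUDA out of memory",
) -> int:
    """Illustrative only: grow the batch until an OOM-like RuntimeError is raised."""
    max_ok = 0
    for batch_size in candidate_sizes:
        try:
            try_forward(batch_size)  # assumed to run one forward pass at this size
            max_ok = batch_size
        except RuntimeError as exc:
            if treat_runtime_error_as_oom or oom_error_message in str(exc):
                break  # treat as OOM: the previous size is the largest that fit
            raise  # unrelated runtime error: surface it to the user
    return max_ok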

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

@orionarcher (Collaborator) commented Nov 5, 2025

This is a good catch. Thanks @cw-tan!

I don't love the idea of ever blanket ignoring exceptions, and I'm wondering if there's a better middle ground. What do you think of allowing the user to modify the search string for OOM errors? So add a kwarg oom_error_message="CUDA out of memory"

@cw-tan (Contributor, Author) commented Nov 5, 2025

> I don't love the idea of ever blanket ignoring exceptions, and I'm wondering if there's a better middle ground. What do you think of allowing the user to modify the search string for OOM errors? So add a kwarg oom_error_message="CUDA out of memory"

Currently, the ignored exception is still printed out for the user to check and verify, but I do agree that a safer solution is to configure the oom_error_message.

The only scenario I can think of where that's not enough is when the OOM error message can change over the course of the run, i.e. when some model state changes across evaluations (e.g. some complicated inference mechanism that switches between calling the AOT Inductor model and a Python-only model depending on the data given). But that's pretty contrived and hopefully not a case to worry about in the immediate future. Anyway, the solution for that could be a list of oom_error_messages instead of a single str.

Happy to update the PR to provide the search string as a kwarg instead.

@orionarcher (Collaborator)

I was thinking of a list of strings too. Let's go with that to provide maximum flexibility. Let's keep the type as str | list[str] and then just make it a list if it's not already. So:

oom_error_message: str | list[str] = "CUDA out of memory"
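Internally, normalizing to a list would then be a one-liner, roughly along these lines (a sketch, not the merged code):

def normalize_oom_messages(oom_error_message: str | list[str]) -> list[str]:
    # Accept either a single string or a list of strings; always match against a list.
    if isinstance(oom_error_message, str):
        return [oom_error_message]
    return list(oom_error_message)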

@cw-tan (Contributor, Author) commented Nov 5, 2025

Updated the PR to use

oom_error_message: str | list[str] = "CUDA out of memory"

@cw-tan (Contributor, Author) commented Nov 5, 2025

Running something on Perlmutter now gives me errors that look like

Model Memory Estimation: Running forward pass on state with 100 atoms and 1 systems.
Model Memory Estimation: Running forward pass on state with 200 atoms and 2 systems.
Model Memory Estimation: Running forward pass on state with 300 atoms and 3 systems.
Model Memory Estimation: Running forward pass on state with 500 atoms and 5 systems.
Model Memory Estimation: Running forward pass on state with 800 atoms and 8 systems.
Model Memory Estimation: Running forward pass on state with 1300 atoms and 13 systems.
Model Memory Estimation: Running forward pass on state with 2100 atoms and 21 systems.
Model Memory Estimation: Running forward pass on state with 3400 atoms and 34 systems.
Model Memory Estimation: Running forward pass on state with 5400 atoms and 54 systems.
Error: CUDA driver error: an illegal memory access was encountered
Traceback (most recent call last):
  File "/pscratch/sd/c/cw-tan/matbench-discovery/models/nequip/torchsim_evals/test_nequip_discovery.py", line 91, in <module>
    final_state = ts.optimize(
                  ^^^^^^^^^^^^
  File "/pscratch/sd/c/cw-tan/torch-sim/torch_sim/runners.py", line 444, in optimize
    state = _chunked_apply(
            ^^^^^^^^^^^^^^^
  File "/pscratch/sd/c/cw-tan/torch-sim/torch_sim/runners.py", line 296, in _chunked_apply
    fn(model=model, state=system, **init_kwargs) for system, _indices in autobatcher
                                                                         ^^^^^^^^^^^
  File "/pscratch/sd/c/cw-tan/torch-sim/torch_sim/autobatching.py", line 651, in __next__
    next_batch, indices = self.next_batch()
                          ^^^^^^^^^^^^^^^^^
  File "/pscratch/sd/c/cw-tan/torch-sim/torch_sim/autobatching.py", line 613, in next_batch
    state = ts.concatenate_states(state_bin)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/pscratch/sd/c/cw-tan/torch-sim/torch_sim/state.py", line 891, in concatenate_states
    num_systems = state.n_systems
                  ^^^^^^^^^^^^^^^
  File "/pscratch/sd/c/cw-tan/torch-sim/torch_sim/state.py", line 184, in n_systems
    return torch.unique(self.system_idx).shape[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/u2/c/cw-tan/micromamba/envs/nequip/lib/python3.12/site-packages/torch/_jit_internal.py", line 627, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/u2/c/cw-tan/micromamba/envs/nequip/lib/python3.12/site-packages/torch/_jit_internal.py", line 627, in fn
    return if_false(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/u2/c/cw-tan/micromamba/envs/nequip/lib/python3.12/site-packages/torch/functional.py", line 1102, in _return_output
    output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/u2/c/cw-tan/micromamba/envs/nequip/lib/python3.12/site-packages/torch/functional.py", line 995, in _unique_impl
    output, inverse_indices, counts = torch._unique2(
                                      ^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Failed to destroy CUDA event in AOTInductor model: an illegal memory access was encountered
terminate called after throwing an instance of 'std::runtime_error'
  what():  CUDA error: an illegal memory access was encountered
Aborted

i.e. it's not even enough to catch RuntimeError. @orionarcher, what do you think of just catching all exceptions, not just RuntimeError? We check for specific strings anyway to ensure it's actually OOM-ing. Otherwise, we could pass another argument for the set of exception types to check, but that just bloats up the args for the autobatchers. We could also combine the exception types and the OOM strings into a single arg. What do you prefer?

@orionarcher (Collaborator) commented Nov 5, 2025

Just catching all exceptions makes sense to me! Exceptions as arguments feels bloated. I think checking the strings is sufficient validation.
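The agreed-on check then boils down to something like this (a sketch of the behavior discussed above, not the exact merged code):

def is_oom_error(exc: Exception, oom_error_messages: list[str]) -> bool:
    # Any exception type can count as an OOM, but only if its message contains
    # one of the configured substrings.
    return any(msg in str(exc) for msg in oom_error_messages)

In the batch-size search, a failed forward pass would be wrapped in a broad except Exception, passed through is_oom_error, and re-raised if no substring matches.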

@orionarcher orionarcher merged commit e0b33f1 into TorchSim:main Nov 6, 2025
139 checks passed