Autobatch oom #337
Conversation
This is a good catch, thanks @cw-tan! I don't love the idea of ever blanket-ignoring exceptions, and I am wondering if there's a better middle ground. What do you think of allowing the user to modify the search string for OOM errors, i.e. adding an `oom_error_message` kwarg?
Currently, the ignored exception is still printed out for the user to check and verify, but I do agree that a safer solution is to make the search string configurable. The only scenario I can think of where that's not enough is when the OOM error message can change over the course of the run, i.e. some model state that changes over several evaluations (e.g. a complicated inference mechanism that switches between calling the AOT Inductor model and the Python-only model depending on the data given). But that's pretty contrived and should hopefully not be a case to worry about in the immediate future. In any case, the solution for that could be a list of `oom_error_messages` instead of a single `str`. Happy to update the PR to provide the search string as a kwarg instead.
I was thinking of a list of strings too. Let's go with that for maximum flexibility, and keep the type as `oom_error_message: str | list[str] = "CUDA out of memory"`.
Updated the PR to use `oom_error_message: str | list[str]`.
Running something on Perlmutter now gives me errors that look like [error output elided], i.e. it's not even enough to catch runtime errors. @orionarcher, what do you think of just catching all exceptions, not just `RuntimeError`?
Just catching all exceptions makes sense to me! Passing exception types as arguments feels bloated, and I think checking the message strings is sufficient validation.
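The catch-all-plus-string-check idea discussed above can be sketched as follows. This is a hedged illustration, not the PR's actual code: `is_oom_error` and `run_trial_batch` are hypothetical names, and the stub simply simulates a C++-side allocation failure whose message lacks the usual Python wording.

```python
def is_oom_error(exc: BaseException, oom_error_messages: list[str]) -> bool:
    """Treat any exception whose message contains one of the configured
    search strings as an out-of-memory failure."""
    text = str(exc)
    return any(msg in text for msg in oom_error_messages)


def run_trial_batch() -> None:
    # Stand-in for one evaluation at the candidate batch size; here it
    # simulates an allocation failure surfaced from C++ code.
    raise RuntimeError("CUDA driver error: out of memory (allocator.cpp)")


messages = ["CUDA out of memory", "out of memory"]  # configurable search strings
try:
    run_trial_batch()
except Exception as exc:  # catch everything, not just RuntimeError
    if not is_oom_error(exc, messages):
        raise  # unrelated failure: surface it to the user
    # otherwise: treat as OOM, shrink the batch size, and retry
```

Catching `Exception` here is safe only because every non-matching error is immediately re-raised; the string check does the actual filtering.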
Motivation
AOT Inductor models won't trigger OOM runtime errors with the same message, since the OOM is thrown from C++. The caught runtime error will end up looking like [error message elided].
In general, matching a single search string such as `"CUDA out of memory"` might be too narrow a criterion.
Proposed Change
The proposed solution is to add a `treat_runtime_error_as_oom` argument that has to be explicitly set when instantiating autobatchers, and which is eventually passed to the `determine_max_batch_size` function. This won't change existing behavior, but it provides a knob to treat every runtime error as an OOM, accounting for future cases where models raise runtime errors that don't explicitly contain the "OOM" message. This should be safe, since a user has to configure this behavior very intentionally and is expected to know whether the error truly is an OOM before doing so.

Checklist
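Sketched end to end, the knob could thread through a batch-size search loop like this. The real `determine_max_batch_size` signature in the project may differ; `try_batch` is a hypothetical callable standing in for one evaluation at a candidate batch size.

```python
def determine_max_batch_size(
    try_batch,                       # hypothetical: runs one evaluation at a batch size
    start: int = 1,
    treat_runtime_error_as_oom: bool = False,
    oom_error_messages: tuple[str, ...] = ("CUDA out of memory",),
) -> int:
    """Double the batch size until a failure is recognized as OOM,
    then back off to the last size that succeeded."""
    batch_size = start
    while True:
        try:
            try_batch(batch_size)
        except RuntimeError as exc:
            is_oom = treat_runtime_error_as_oom or any(
                msg in str(exc) for msg in oom_error_messages
            )
            if is_oom:
                # back off to the last size that worked
                return max(batch_size // 2, start)
            raise  # not recognized as OOM: re-raise for the user
        batch_size *= 2
```

With `treat_runtime_error_as_oom=False` (the default), behavior is unchanged: only errors matching a search string are swallowed, and any other `RuntimeError` propagates to the user.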
Before a pull request can be merged, the following items must be checked: