Update Float8Tensor for GRPO training in unsloth #3158
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3158
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Unrelated Failure as of commit 82012af with merge base f856d36.
NEW FAILURES: the following jobs have failed.
BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Resolved review threads on now-outdated code:
- test/quantization/quantize_/workflows/float8/test_float8_tensor.py
- torchao/quantization/quantize_/workflows/float8/float8_tensor.py (several threads)
Please clean up `_float8_mm_impl`.
    input_tensor: Float8Tensor,
    weight_tensor: Float8Tensor,
    bias: Optional[torch.Tensor] = None,
    weight_is_already_transposed: bool = False,
Instead of this flag, just transpose at the call site to match the meaning of matmul.
The reason for this flag is to avoid an unnecessary double transpose when we call linear, since the fbgemm op expects the weight in the linear format (already transposed). Without this flag:
1. linear calls `_float8_mm_impl(input, weight.t())`
2. `_float8_mm_impl` calls `weight.t()` again before calling `torch.ops.fbgemm.f8f8bf16`

If we just transpose the weight at the call site for linear, we may end up slowing linear down; is that OK?
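To make the double-transpose concern concrete, here is a minimal sketch; it is not the actual torchao implementation, the tensors are plain float tensors standing in for quantized ones, and `_fp8_gemm` / `_float8_linear` are hypothetical placeholder names (the real kernel is the fbgemm op mentioned above):

```python
import torch


def _fp8_gemm(x, w):
    # Placeholder for the real float8 gemm; assumed (like the fbgemm op in this
    # thread) to take the weight in linear layout (out_features, in_features)
    # and compute x @ w.t(). A plain matmul keeps the sketch runnable.
    return x @ w.t()


def _float8_mm_impl(input_q, weight_q, weight_is_already_transposed=False):
    # torch.mm semantics: out = input @ weight, where weight is (in, out).
    # The kernel wants (out, in), so transpose unless the caller says the
    # weight is already in that layout.
    if not weight_is_already_transposed:
        weight_q = weight_q.t()
    return _fp8_gemm(input_q, weight_q)


def _float8_linear(input_q, weight_q, bias=None):
    # nn.Linear weights are already (out, in); the flag lets linear skip the
    # transpose-then-transpose-back round trip described in the comment above.
    out = _float8_mm_impl(input_q, weight_q, weight_is_already_transposed=True)
    return out if bias is None else out + bias
```

For example, a (4, 8) input against a (16, 8) linear weight gives a (4, 16) output on both paths, but the linear path performs no transpose at all.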
)


def _float8_matmul_impl(
How about:
- define `_float8_mm_impl` as the float8 version of `torch.mm`, i.e. the lowest-level shared code that maybe quantizes the input and then chooses a gemm
- have all other functions (matmul, linear, etc.) call `_float8_mm_impl`

It's a bit confusing to have two different paths for linear and matmul.
This was done to avoid a double transpose in the linear path (which doesn't happen today, see this comment). I agree that ideally everything should go through `_float8_mm_impl`, but doing so may add overhead for the linear path. Should I go ahead and merge the implementations anyway?
OK, I just benchmarked the double-transpose case and it didn't introduce much overhead. I refactored the code the way you suggested; please have another look.
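For reference, the refactor agreed on here can be sketched by reusing `_float8_mm_impl` from the hypothetical sketch above: linear expresses itself in `torch.mm` terms by transposing at the call site and simply pays the (benchmarked-as-cheap) extra transpose. Again, the names are illustrative, not the actual code:

```python
def _float8_linear_unified(input_q, weight_q, bias=None):
    # Single shared path: transpose the (out, in) weight at the call site so
    # the call matches torch.mm semantics; _float8_mm_impl transposes it back
    # for the kernel. The thread's benchmark found this round trip cheap.
    out = _float8_mm_impl(input_q, weight_q.t())
    return out if bias is None else out + bias
```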
**Summary:** Support a few extra ops called during the GRPO loop in unsloth/vllm for Float8Tensor.

**Test Plan:**
```
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_matmul_lora_variants
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_to_dtype_layout
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_has_compatible_shallow_copy_type
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_transpose
```
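For context, here is a rough usage sketch of the kind of calls these tests exercise. This is an assumption-laden illustration: the config name, whether it produces the new Float8Tensor subclass by default, and the hardware requirements (float8 kernels generally need a recent CUDA GPU) may differ from the actual setup in this PR.

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

# Quantize a linear layer so its weight becomes a float8 tensor subclass
# (assumed config; the exact config/version used by this workflow may differ).
model = nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")

# The kinds of ops this PR adds support for: transposing the quantized weight
# and matmul against it, as happens in LoRA/GRPO-style paths in unsloth/vllm.
w_t = model.weight.t()
out = torch.matmul(x, w_t)
```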