introduce new int8 quantization API #3241
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
quant_min=self.int8_min,
quant_max=self.int8_max,
nit: we can omit these two args if they are the same as the defaults (-128, 127)
)
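For illustration, a rough sketch of the equivalence using torchao's generic affine quant primitive (not the PR's actual code; the signature shown is from recent torchao and may differ by version):

```python
import torch
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

x = torch.randn(16, 32)
# Explicit bounds (what the PR passes through int8_min/int8_max) ...
scale_a, zp_a = choose_qparams_affine(
    x, MappingType.SYMMETRIC, block_size=x.shape, target_dtype=torch.int8,
    quant_min=-128, quant_max=127,
)
# ... produce the same qparams as relying on the int8 defaults.
scale_b, zp_b = choose_qparams_affine(
    x, MappingType.SYMMETRIC, block_size=x.shape, target_dtype=torch.int8,
)
assert torch.equal(scale_a, scale_b)
```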
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
this seems to be a combination of two tests, one for dynamic quant and one for static quant; can you use something like this:
| @common_utils.parametrize("mode", ["dynamic", "weight-only"]) |
also I feel it might be better not to add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR
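A rough sketch of what the merged, mode-parametrized test could look like (illustrative only; the layer sizes, assertions, and the version=2 argument on the dynamic config are assumptions based on the PR summary, not the PR's actual test):

```python
import torch
from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
def test_quantization_shapes(self, dtype, mode):
    # Build one linear layer and quantize it with the config selected by `mode`.
    m = torch.nn.Linear(128, 256, dtype=dtype, device="cuda")
    config = (
        Int8DynamicActivationInt8WeightConfig(version=2)
        if mode == "dynamic"
        else Int8WeightOnlyConfig(version=2)
    )
    quantize_(m, config)
    x = torch.randn(32, 128, dtype=dtype, device="cuda")
    y = m(x)
    # Output shape and dtype should be unchanged by quantization.
    self.assertEqual(y.shape, (32, 256))
    self.assertEqual(y.dtype, dtype)
```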
Okay, I wasn't sure about removing the static flags before (although static quant isn't fully implemented), but a smaller PR is always better, I feel. I will remove static_scale and all the related support.
if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static)
    scale = act_quant_kwargs.static_scale
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
I think the user should specify static_zero_point as well
but again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (no config)
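For illustration, a minimal sketch of a static path that takes a user-provided zero point (the helper name and arguments here are hypothetical, not the PR's API):

```python
import torch

def quantize_static_int8(x, static_scale, static_zero_point):
    # Use the user-provided zero point instead of hard-coding zeros,
    # so asymmetric static activation quantization is also expressible.
    q = torch.round(x / static_scale) + static_zero_point.to(torch.int32)
    return torch.clamp(q, -128, 127).to(torch.int8)

x = torch.randn(8, 16)
scale = torch.tensor(0.02)
zero_point = torch.tensor(3, dtype=torch.int8)  # chosen offline during calibration
x_int8 = quantize_static_int8(x, scale, zero_point)
```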
# Cast fp16 scale to float
intermediate_dtype = (
    torch.float if x_scales.dtype == torch.half else x_scales.dtype
)
# Note: CUDA doesn't support int32/int64 matmul, so we convert to float
# Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
# This may introduce minor numerical differences compared to int arithmetic
y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

# Apply activation scale
is_per_tensor_act = x_scales.numel() == 1
if is_per_tensor_act:
    y_dot.mul_(x_scales.to(intermediate_dtype))
else:
    # For block-wise activation scale, reshape to match y_dot
    x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
    y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

# Apply weight scale
is_per_tensor_weight = w_scales.numel() == 1
if is_per_tensor_weight:
    result = y_dot.mul_(w_scales.to(intermediate_dtype))
else:
    # Per-row weight scale - transpose and broadcast
    w_scales_broadcast = w_scales.t().expand_as(y_dot)
    result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

# Reshape back to original shape
result = result.view(*x_vals.shape[:-1], result.shape[-1])
result = result.to(activation_tensor.dtype)
this should follow:
ao/torchao/dtypes/uintx/plain_layout.py, line 281 in e9c7bea:
def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
I think we should
- split the static quant support into a separate PR
- follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for the quantized linear implementation
I think this should be a refactor PR, not a refactor plus some extra modifications and some feature implementations
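For reference, a simplified sketch of the pattern that _linear_int8_act_int8_weight_impl follows in plain_layout.py (an illustrative reimplementation, not the actual torchao code; per-row activation scales and per-channel weight scales are assumed):

```python
import torch

def int8_act_int8_weight_linear(x_int8, x_scales, w_int8, w_scales, bias=None):
    # Fold leading dims so the matmul is 2-D: (..., K) -> (B, K).
    x_2d = x_int8.reshape(-1, x_int8.shape[-1])
    # plain_layout uses an int8 matmul kernel where available; falling back to
    # float here keeps the sketch runnable on any device.
    y = torch.mm(x_2d.float(), w_int8.t().float())
    # Rescale: one activation scale per row, one weight scale per output channel.
    y = y * x_scales.reshape(-1, 1).float() * w_scales.reshape(1, -1).float()
    y = y.reshape(*x_int8.shape[:-1], y.shape[-1])
    if bias is not None:
        y = y + bias
    return y
```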
aten = torch.ops.aten

# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
is this change related?
It is updated to support slicing at different quantization granularities. Without this change, we can't slice tensors quantized per-tensor (0-D scale) or per-row (1-D scale).
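A small illustrative sketch of the branch the slice op needs (names are hypothetical, not the PR's actual code):

```python
import torch

def slice_scale_like_data(scale, dim, start, end):
    # Per-tensor quantization: a 0-D (or single-element) scale covers every
    # element, so slicing the int8 data leaves the scale untouched.
    if scale.ndim == 0 or scale.numel() == 1:
        return scale
    # Per-row quantization: the 1-D scale has one entry per row, so a slice
    # along dim 0 must slice the scale the same way.
    if dim == 0:
        return scale.narrow(0, start, end - start)
    return scale
```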
        Int8WeightOnlyConfig(version=2),
    ],
)
def test_per_row_scale_shape(self, dtype, config):
can you add a test like this one:
def test_fp8_linear_variants(
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("has_bias", [True, False])
def test_weight_only_linear_with_bias(self, dtype, has_bias):
this can probably be merged into the linear variants test as well
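One possible shape for the merged parametrization, folding has_bias (and compile, mentioned below) into the same variants test (a sketch, not the PR's actual test):

```python
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("has_bias", [True, False])
@common_utils.parametrize("compile", [True, False])
def test_int8_linear_variants(self, dtype, mode, has_bias, compile):
    ...
```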
thanks, I think the tensor changes look good, but we need a linear_variants test to make sure we cover different aspects of things (e.g. compile); see comments inline
can you also do an e2e perf check with https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py to make sure the performance is the same before and after the change for the ViT model?
also, adding a kernel check might be useful to make sure we don't regress things:
def test_expected_gpu_kernel_fbgemm(self):
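A hedged sketch of what such a kernel check could look like (the inductor utilities used here exist in PyTorch, but the exact kernel string to assert on, e.g. "_int_mm", is an assumption about the final implementation):

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

def assert_int8_mm_kernel_is_used(model, example_input):
    compiled = torch.compile(model)
    # run_and_get_code returns the model output plus the generated source code.
    _, code = run_and_get_code(compiled, example_input)
    # Fail if the generated code no longer calls the expected int8 matmul kernel.
    FileCheck().check("_int_mm").run(code[0])
```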
Updated logs:
Summary:
Introduce a new tensor subclass API. Main features:
- Int8Tensor: the main API, which handles quantization and dequantization operations.
- This API is integrated into the existing config variants (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) via version, and is not enabled by default.
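For context, a minimal opt-in example of how the new path would be selected through the existing config (a sketch based on the version=2 usage in the tests, not taken from the PR docs):

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
)
# version=2 opts into the new Int8Tensor subclass; the existing path
# remains the default when version is not set.
quantize_(model, Int8WeightOnlyConfig(version=2))
```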
Related Issue/PR:
This is a reopened PR for #3038.
Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Performance:
The following are the results of https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32:
(torch.compile results table omitted)