Add Int8Tensor for clearer interface #3038
Conversation
Introduce a new tensor subclass API for int8 quantization with a clearer interface. The main change can be summarized as follows:
- Old: complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: direct int8 tensor with qdata, scale, and zero_point attributes

Test plan: test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Future plan: implement block-wise quantization using the `block_size` parameter
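To give a feel for the new interface described above, here is a minimal sketch; the `from_hp` constructor and `block_size` argument are referenced elsewhere in this thread, but the exact signature and import path shown here are assumptions:

```python
import torch
from torchao.quantization import Int8Tensor  # import path is an assumption

w = torch.randn(256, 512, dtype=torch.bfloat16)
# Per-row granularity: one scale per output row.
w_int8 = Int8Tensor.from_hp(w, block_size=[1, 512])

# The quantized state lives directly on the subclass:
w_int8.qdata       # int8 values
w_int8.scale       # per-row scales, e.g. shape (256, 1)
w_int8.zero_point  # zeros under symmetric quantization
w_hp = w_int8.dequantize()  # back to high precision
```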
          
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038.
Note: links to docs will display an error until the docs builds have completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
can you add a version 2 and expose this tensor through
ao/torchao/quantization/quant_api.py, line 1497 in 8525185
ao/torchao/quantization/quant_api.py, line 1752 in 8525185
        result = result.to(scale.dtype) * scale
        result = result.view(*input_tensor.shape[:-1], -1)
    else:
        # FP × INT8 (static)
also this is the code for weight only quant I think:
ao/torchao/dtypes/uintx/plain_layout.py, line 250 in 122b307
def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
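For context, that fp-activation × int8-weight path boils down to a matmul against the int8 values followed by a per-row rescale. A rough self-contained sketch of the pattern (hypothetical helper, not the torchao implementation):

```python
import torch

def linear_fp_act_int8_weight(x, w_q, w_scale, bias=None):
    # x: (..., in_features) float activations
    # w_q: (out_features, in_features) int8 weights
    # w_scale: (out_features,) per-row weight scales
    res = torch.mm(x.reshape(-1, x.shape[-1]), w_q.t().to(x.dtype))
    res = res * w_scale.to(x.dtype).reshape(1, -1)  # rescale after the matmul
    res = res.reshape(*x.shape[:-1], w_q.shape[0])
    return res + bias if bias is not None else res
```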
Done at 9383550, thanks for pointing it out.
raise ValueError("Expected 2D tensor and block_size length 2")

# Rounding function from high precision dtype
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
looks like block_size is not used? why is that?
you can check out
ao/torchao/dtypes/uintx/plain_layout.py, line 232 in 8c5c33e
def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):
also this should be using these quant primitive ops:
ao/torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py, lines 79 to 97 in 8c5c33e
scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)
ao/torchao/quantization/quant_api.py, line 1566 in 8c5c33e
new_weight = to_affine_quantized_intx(
ao/torchao/dtypes/affine_quantized_tensor.py, line 325 in 8c5c33e
scale, zero_point = choose_qparams_affine(
this might require a bit too much context, let me know if you would like us to take over
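For illustration, the int8 path could call those primitives roughly like this; a sketch assuming the standard int8 range and per-row granularity (import paths and defaults may differ from the PR):

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

w = torch.randn(256, 512)
block_size = (1, 512)  # per-row granularity

scale, zero_point = choose_qparams_affine(
    input=w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    eps=1e-6,
)
w_q = quantize_affine(
    input=w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
)
```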
Thanks, I surely want to take it over! I drafted this PR for those updates, but will look into it today (in about 6 hours).
btw, version 2 is added at c53dad0 (version 1 is the default)
please rebase, and let me know when this is ready for review again @namgyu-youn
if not isinstance(activation_tensor, Int8Tensor):
    if weight_tensor.act_quant_kwargs.static_scale is not None:
        # INT8 × INT8 (static): symmetric quantization only
        static_scale = weight_tensor.act_quant_kwargs.static_scale
OK, if this is needed then I think it should be included in _choose_quant_func_and_quantize_tensor as well?
implements_torch_function = Int8Tensor.implements_torch_function

@implements([aten.dequantize.self])
is this needed? if not we should remove for now
if scale.numel() > 1 and scale.shape != qdata_fp.shape:
    scale = scale.view(*scale.shape, *[1] * (qdata_fp.ndim - scale.ndim))
is this needed?
It is needed for block-level granularity. For example:
- Row-wise: if the scale shape is (256, 1) and w_q (the quantized weight) is (256, 512), they broadcast naturally
- Channel-wise: if the scale shape is (512,) and w_q is (256, 512), they broadcast naturally
- Block-wise granularity: if the scale shape is (32, 64) and w_q is (256, 512), the scale has to be reshaped before it can broadcast
But we can also reuse _maybe_expand_scale_to_tensor_shape, similar to:
ao/torchao/quantization/quantize_/workflows/float8/float8_tensor.py, lines 149 to 154 in 4b79f9e
def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if output_dtype is None:
        output_dtype = self.dtype
    qdata, scale = self.qdata, self.scale
    return _dequantize_affine_float8(qdata, scale, output_dtype)
and
ao/torchao/quantization/quant_primitives.py, lines 2407 to 2421 in f3fc5e7
def _dequantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dequantizes the float8 tensor to high precision tensor.
    """
    fp8_tensor = tensor.to(torch.float32)
    # Expand scale to match tensor dimensions for block-wise quantization
    scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)
    hp_tensor = fp8_tensor * scale_expanded
    return hp_tensor.to(output_dtype)
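The block-wise case is the one that needs explicit expansion. A minimal sketch of the idea behind such an expansion helper (hypothetical, not the torchao function itself):

```python
import torch

def expand_block_scale(scale: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    # Repeat each scale entry across its block so it matches the tensor shape,
    # e.g. a (32, 64) scale for a (256, 512) tensor covers 8x8 blocks.
    for dim, (s, full) in enumerate(zip(scale.shape, shape)):
        if s != full:
            scale = scale.repeat_interleave(full // s, dim=dim)
    return scale

scale = torch.rand(32, 64)
w_q = torch.randint(-128, 128, (256, 512), dtype=torch.int8)
w_hp = w_q.to(torch.float32) * expand_block_scale(scale, w_q.shape)  # broadcasts cleanly
```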
cls: type,
qdata: torch.Tensor,
scale: torch.Tensor,
block_size: list[int],
nit: I remember the built-in list annotation has a higher Python version requirement, so it's probably better to change this to List from typing, I think
Thanks, does that apply only to List, and not to Dict, Tuple, etc.?
probably also for Dict and Tuple, I have only tried list before
Is there a use case on an old (< 3.9) Python version? I remember list, dict, and tuple have been natively supported in annotations since 3.9: https://docs.astral.sh/ruff/rules/non-pep585-annotation/.
Because PEP 585 (type hints) is tied to a pre-commit issue, I'd prefer to target newer versions (no need for from typing import List, Dict, Tuple). How about using list, dict, and tuple, focusing on newer Python versions? I feel there might be a new issue from PEP 585 if we go with Dict, List, Tuple.
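For reference, the difference under discussion (built-in generics per PEP 585 require Python >= 3.9 at runtime; the typing aliases work on older interpreters too):

```python
# Python >= 3.9 (PEP 585): built-in types can be used as generic annotations.
def as_tuple(block_size: list[int]) -> tuple[int, ...]:
    return tuple(block_size)

# On Python < 3.9 the typing aliases are required instead:
# from typing import List, Tuple
# def as_tuple(block_size: List[int]) -> Tuple[int, ...]: ...
```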
return module

def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
some rebase issue?
Sorry, I assumed 0a45f90 was the wrong move, and that it was the start of the rebase issue. The solution looks like dropping the relevant commits using an interactive rebase.
But rebasing away all those commits is overwhelming to me. I really don't want to open a duplicate PR, but may I reopen this as a new PR and link it back to this one? I just want to remove the unrelated changes from the log.
@namgyu-youn please feel free to close this and open a new one if it's hard to fix the rebase issue; it seems like it's still not fully fixed
Thanks, I just want to remove the unrelated change logs, as you mentioned.
Updated log:
        
          
torchao/utils.py (outdated)

from importlib.metadata import version
from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional
please fix rebase to not have these changes, or open a new PR if you don't know how to fix rebase
please fix rebase, otherwise seems mostly OK I think
@common_utils.parametrize(
    "sizes",
    [
        ((128,), 256, 128),
do 3D inputs work? e.g. ((32, 128), 256, 128)
No, a 3D input raises a ValueError from from_hp(): #3038 (comment)
assert error > 20, f"Quantization error is too high got a SQNR of {error}"

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_static_quantization(self, dtype):
can you add the static quant config Int8StaticActivationInt8WeightConfig and test that as well?
or maybe you can remove this for now and coordinate with @Xia-Weiwen (https://github.com/pytorch/ao/pull/3089/files#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726) to add the static quant support separately?
Both are OK with me. Thanks.
@common_utils.parametrize(
    "config",
    [
        Int8DynamicActivationInt8WeightConfig(version=2),
will need to test the static quant as well, if that is added
changes in this PR should be reverted as well
)

@implements(aten.select.int)
is this tested? if not, please remove for now
if tensor.scale.numel() == 1:
    # Per-tensor quantization - scale doesn't change
    sliced_scale = tensor.scale
elif dim < tensor.scale.ndim and tensor.scale.shape[dim] > 1:
    # Block-wise quantization - need to slice the scale appropriately
    sliced_scale = func(tensor.scale, dim, start, end, step)
else:
    sliced_scale = tensor.scale
can you match the implementation with Float8Tensor?
ao/torchao/quantization/quantize_/workflows/float8/float8_tensor.py, lines 449 to 456 in 53b5efd
if self.scale.numel() == 1:
    # Per-tensor quantization - scale doesn't change
    sliced_scale = self.scale
else:
    # Block-wise quantization - need to slice the scale appropriately
    sliced_scale = _slice_scale_for_dimension(
        self.scale, self.qdata.shape, dim, start, end, step
    )
Updated logs: in 680cec9
@jerryzh168 sorry for the multiple PRs again; the reopened PR (#3241) is copy-pasted from the last commit in this PR and resolves the rebase errors. Please check the comment above (change log; #3038 (comment)) first, and then take a look at #3241, thanks.
Summary:
Introduce a new tensor subclass API for int8 quantization with a clearer interface. The main change can be summarized as follows:
- Old: complex affine transform (AffineQuantizedTensor) with separate layout handling
- New: direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment), #2752

Test plan: test/quantization/quantize_/workflows/int8/test_int8_tensor.py