
Conversation

@namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Sep 21, 2025

Summary:
Introduce a new tensor subclass API for int8 quantization with a clearer interface.

The main change can be summarized as follows:

  • Old: Complex affine transform (AffineQuantizedTensor) with separate layout handling
  • New: Direct int8 tensor with qdata, scale, and zero_point attributes

Related Issue/PR: #3012 (comment) #2752

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Future plan:
Implement block-wise quantization using `block_size` parameter
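
Assuming the version-2 config path discussed in the review below, a minimal usage sketch of the intended workflow could look like this (illustrative only, not part of this PR's diff):

import torch
from torch import nn
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

model = nn.Sequential(nn.Linear(128, 256))
# version=2 selects the new Int8Tensor-based path; version=1 keeps the existing default
quantize_(model, Int8DynamicActivationInt8WeightConfig(version=2))

x = torch.randn(32, 128)
y = model(x)  # dynamic int8 activation x int8 weight linear
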
@pytorch-bot

pytorch-bot bot commented Sep 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3038

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 21, 2025
@jerryzh168
Contributor

can you add a version 2 and expose this tensor through `class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):`? similar to `class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):`
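
As a rough sketch of this suggestion (a stripped-down stand-in, not the real torchao config; the `version` field semantics here are an assumption):

from dataclasses import dataclass

from torchao.core.config import AOBaseConfig


@dataclass
class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
    # version=1: existing AffineQuantizedTensor path (default)
    # version=2: new Int8Tensor path introduced in this PR
    version: int = 1


config = Int8DynamicActivationInt8WeightConfig(version=2)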

@namgyu-youn namgyu-youn changed the title Add Int8PlainInt8Tensor for clearer interface Add Int8Tensor for clearer interface Sep 23, 2025
    result = result.to(scale.dtype) * scale
    result = result.view(*input_tensor.shape[:-1], -1)
else:
    # FP × INT8 (static)
Contributor

also this is the code for weight only quant I think:

def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):

Contributor Author

@namgyu-youn namgyu-youn Sep 24, 2025

Done at 9383550, thanks for pointing it out.

raise ValueError("Expected 2D tensor and block_size length 2")

# Rounding function from high precision dtype
scale = w.abs().max(dim=-1, keepdim=True)[0] / 127.0
Contributor

looks like block_size is not used? why is that?

Contributor

you can check out `def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias):` for the expected granularity

also, this should be using these quant primitive ops:

scale, zero_point = choose_qparams_affine(
    input=preprocessed_w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
    eps=1e-6,
)
wq = quantize_affine(
    input=preprocessed_w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=target_dtype,
    quant_min=quant_min,
    quant_max=quant_max,
)
Arguments can be found by tracing through the int8 code path in `new_weight = to_affine_quantized_intx(` and `scale, zero_point = choose_qparams_affine(`.

this might require a bit too much context, let me know if you would like us to take over
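
For reference, a minimal sketch of per-row int8 weight quantization with those primitives (the per-row block_size and the explicit int8 quant_min/quant_max are assumptions for illustration):

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

w = torch.randn(256, 512)
block_size = (1, w.shape[-1])  # each quantization block spans one full row

scale, zero_point = choose_qparams_affine(
    input=w,
    mapping_type=MappingType.SYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    eps=1e-6,
)
wq = quantize_affine(
    input=w,
    block_size=block_size,
    scale=scale,
    zero_point=zero_point,
    output_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
)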

Contributor Author

@namgyu-youn namgyu-youn Sep 29, 2025

Thanks, I definitely want to take it over! I've drafted this PR for those updates, but will look into it later today (in about 6 hours).

btw, version 2 is added in c53dad0 (version 1 is the default)

@namgyu-youn namgyu-youn marked this pull request as draft September 28, 2025 13:23
@namgyu-youn namgyu-youn marked this pull request as ready for review September 30, 2025 06:09
Contributor

@jerryzh168 jerryzh168 left a comment

please rebase, and let me know when this is ready for review again @namgyu-youn

if not isinstance(activation_tensor, Int8Tensor):
    if weight_tensor.act_quant_kwargs.static_scale is not None:
        # INT8 × INT8 (static): symmetric quantization only
        static_scale = weight_tensor.act_quant_kwargs.static_scale
Contributor

OK, if this is needed, I think it should be included in _choose_quant_func_and_quantize_tensor as well?

implements_torch_function = Int8Tensor.implements_torch_function


@implements([aten.dequantize.self])
Contributor

is this needed? if not, we should remove it for now

Comment on lines 142 to 143
if scale.numel() > 1 and scale.shape != qdata_fp.shape:
    scale = scale.view(*scale.shape, *[1] * (qdata_fp.ndim - scale.ndim))
Contributor

is this needed?

Contributor Author

@namgyu-youn namgyu-youn Oct 17, 2025

It is needed for block-level granularity. For example,

  1. Row-wise: if the scale shape is (256, 1) and w_q (the quantized weight) has shape (256, 512), they broadcast naturally
  2. Channel-wise: if the scale shape is (512,) and w_q is (256, 512), they broadcast naturally
  3. Block-wise granularity: if the scale shape is (32, 64) and w_q is (256, 512), the scale has to be reshaped/expanded before it can broadcast

But we can also reuse _maybe_expand_scale_to_tensor_shape, similar to:

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if output_dtype is None:
        output_dtype = self.dtype
    qdata, scale = self.qdata, self.scale
    return _dequantize_affine_float8(qdata, scale, output_dtype)

and

def _dequantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dequantizes the float8 tensor to high precision tensor.
    """
    fp8_tensor = tensor.to(torch.float32)
    # Expand scale to match tensor dimensions for block-wise quantization
    scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, tensor.shape)
    hp_tensor = fp8_tensor * scale_expanded
    return hp_tensor.to(output_dtype)
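
To make the block-wise case concrete, here is a small self-contained sketch (expand_scale_to is a hypothetical helper standing in for _maybe_expand_scale_to_tensor_shape):

import torch


def expand_scale_to(scale: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    # Repeat each block's scale so it broadcasts elementwise against qdata
    expanded = scale
    for dim, (s, t) in enumerate(zip(scale.shape, shape)):
        if s != t:
            expanded = expanded.repeat_interleave(t // s, dim=dim)
    return expanded


qdata = torch.randint(-128, 128, (256, 512), dtype=torch.int8)
scale = torch.rand(32, 64)  # one scale per 8x8 block of the weight
hp = qdata.to(torch.float32) * expand_scale_to(scale, qdata.shape)  # shape (256, 512)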

cls: type,
qdata: torch.Tensor,
scale: torch.Tensor,
block_size: list[int],
Contributor

nit: I remember list has a higher Python version requirement, so it's probably better to change this to List from typing, I think

Contributor Author

Thanks, is that only for List, and not for Dict, Tuple, etc.?

Contributor

probably also for Dict and Tuple, I have only tried list before

Contributor Author

@namgyu-youn namgyu-youn Oct 24, 2025

Is there a use case with an old (< 3.9) Python version? I remember list, dict, and tuple have been natively supported since Python 3.9: https://docs.astral.sh/ruff/rules/non-pep585-annotation/.

Because PEP 585 (type hints) is chained to the pre-commit issue, I'd prefer to focus on newer versions (no need for from typing import List, Dict, Tuple). How about using list, dict, and tuple and targeting newer Python versions? I feel there might be a new issue from PEP 585 if we go with Dict, List, Tuple.
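
For what it's worth, a small example of the two styles (the from __future__ import means the PEP 585 annotation is stored as a string and never evaluated at definition time, so it parses even on interpreters older than 3.9):

from __future__ import annotations

from typing import List


def make_block_size_builtin(rows: int, cols: int) -> list[int]:
    # PEP 585 built-in generic; fine at runtime on 3.9+, and as a lazy annotation on older versions
    return [rows, cols]


def make_block_size_typing(rows: int, cols: int) -> List[int]:
    # typing.List spelling; works everywhere but is the older style on newer Pythons
    return [rows, cols]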

return module


def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
Contributor

some rebase issue?

Contributor Author

@namgyu-youn namgyu-youn Oct 23, 2025

Sorry, I assume 0a45f90 was the wrong approach, and that was the start of the rebase issue. The solution looks like dropping the relevant commits using rebase.

But rebasing across all those commits is overwhelming to me. I really don't want to open a duplicate PR, but may I reopen this as a new PR and link it to this one? I just want to remove the irrelevant changes from the change log.

Contributor

@jerryzh168 jerryzh168 Oct 23, 2025

@namgyu-youn please feel free to close this and reopen a new one if it's hard to fix the rebase issue; it seems like it's still not fully fixed

Contributor Author

Thanks, I just want to remove the irrelevant change logs, as you mentioned.

@namgyu-youn
Contributor Author

namgyu-youn commented Oct 17, 2025

Updated log:

To reviewers:
Unfortunately, I can't build and run local tests because of #2919, even after trying downgrading and incremental installation. Please feel free to commit directly if test_int8_tensor.py fails.

torchao/utils.py Outdated
from importlib.metadata import version
from math import gcd
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Optional
Contributor

@jerryzh168 jerryzh168 Oct 23, 2025

please fix the rebase so it does not include these changes, or open a new PR if you don't know how to fix the rebase

Contributor

@jerryzh168 jerryzh168 left a comment

please fix the rebase, otherwise this seems mostly OK I think

@namgyu-youn namgyu-youn marked this pull request as draft October 23, 2025 12:27
@namgyu-youn namgyu-youn marked this pull request as ready for review October 23, 2025 17:37
@common_utils.parametrize(
    "sizes",
    [
        ((128,), 256, 128),
Contributor

do 3D inputs work? e.g. ((32, 128), 256, 128),

Contributor Author

No, a 3D input raises a ValueError triggered by from_hp(): #3038 (comment)

assert error > 20, f"Quantization error is too high got a SQNR of {error}"

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_static_quantization(self, dtype):
Contributor

@jerryzh168 jerryzh168 Oct 23, 2025

can you add the static quant config Int8StaticActivationInt8WeightConfig and test that as well?

or maybe you can remove this for now and coordinate with @Xia-Weiwen (https://github.com/pytorch/ao/pull/3089/files#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726) to add the static quant support separately?

Collaborator

Both are ok for me. Thanks.

@common_utils.parametrize(
    "config",
    [
        Int8DynamicActivationInt8WeightConfig(version=2),
Contributor

will need to test the static quant as well, if that is added

Contributor

changes in this PR should be reverted as well

)


@implements(aten.select.int)
Contributor

is this tested? if not, please remove it for now

Comment on lines 230 to 237
if tensor.scale.numel() == 1:
    # Per-tensor quantization - scale doesn't change
    sliced_scale = tensor.scale
elif dim < tensor.scale.ndim and tensor.scale.shape[dim] > 1:
    # Block-wise quantization - need to slice the scale appropriately
    sliced_scale = func(tensor.scale, dim, start, end, step)
else:
    sliced_scale = tensor.scale
Contributor

can you match the implementation with Float8Tensor?

if self.scale.numel() == 1:
    # Per-tensor quantization - scale doesn't change
    sliced_scale = self.scale
else:
    # Block-wise quantization - need to slice the scale appropriately
    sliced_scale = _slice_scale_for_dimension(
        self.scale, self.qdata.shape, dim, start, end, step
    )
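
As a sanity check of the requested behavior, a tiny self-contained example of the invariant (a per-row scale must be sliced along with qdata, while a per-tensor scale is left untouched):

import torch

qdata = torch.randint(-128, 128, (256, 512), dtype=torch.int8)
per_row_scale = torch.rand(256, 1)
per_tensor_scale = torch.tensor(0.02)

# Slicing rows of qdata must slice a per-row scale the same way...
sliced_qdata = qdata[0:64]
sliced_scale = per_row_scale[0:64]
assert sliced_qdata.shape[0] == sliced_scale.shape[0]

# ...while a per-tensor scale (numel() == 1) stays as-is.
assert per_tensor_scale.numel() == 1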

@namgyu-youn
Contributor Author

Updated logs:
In 062f3cc

In 680cec9

  • Update setUp for commonly used args
  • Split test_error_handling_and_dequant unit test into test_invalid_input_handling & test_dequantization_accuracy

@namgyu-youn
Contributor Author

@jerryzh168 sorry for the multiple PRs again; the reopened PR (#3241) is copy-pasted from the last commit of this PR and resolves the rebase errors. Please check the change-log comment above (#3038 (comment)) first, and then take a look at #3241, thanks.
