
Conversation

@namgyu-youn
Contributor

@namgyu-youn namgyu-youn commented Oct 24, 2025

Summary:
Introduce a new tensor subclass API. The main features are:

  • Int8Tensor: the main API, which handles quantization and dequantization operations
  • Utility op support: tensor slicing and index selection

This API is integrated into the existing configs (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) through their version argument, and is not the default.
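
A minimal usage sketch (version=2 follows the configs referenced later in this thread; exact argument names and defaults may differ):

import torch
from torchao.quantization import quantize_, Int8WeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)
# Opt in to the new Int8Tensor-based path via the config's version argument
quantize_(model, Int8WeightOnlyConfig(version=2))
out = model(torch.randn(32, 128, device="cuda", dtype=torch.bfloat16))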

Related Issue/PR:
This is a reopened PR for #3038.

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Performance:
The following are the results of https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32:

API     | With torch.compile | Without torch.compile
Old API | 65.47 ms           | 234.39 ms
New API | 63.30 ms           | 239.30 ms

@pytorch-bot

pytorch-bot bot commented Oct 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 24, 2025
Comment on lines 140 to 141
quant_min=self.int8_min,
quant_max=self.int8_max,
Contributor

nit: we can omit these two args if they are the same as the defaults (-128, 127)

)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025

this seems to be a combination of two tests, one for dynamic quant and one for static quant; can you use something like this:

@common_utils.parametrize("mode", ["dynamic", "weight-only"])

also, I feel it might be better to not add static quant in this PR, and instead add both the tensor support and config support for static quant in a separate PR
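
For illustration, a rough sketch of what the combined, parametrized test could look like (the shapes and assertions here are placeholders, not the PR's actual test):

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
def test_quantization_shapes(self, dtype, mode):
    config = (
        Int8DynamicActivationInt8WeightConfig(version=2)
        if mode == "dynamic"
        else Int8WeightOnlyConfig(version=2)
    )
    linear = torch.nn.Linear(128, 256, dtype=dtype, device="cuda")
    quantize_(linear, config)
    out = linear(torch.randn(32, 128, dtype=dtype, device="cuda"))
    self.assertEqual(out.shape, (32, 256))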

Contributor Author

Okay, I wasn't sure about removing the static flags before (although they're not fully implemented), but I feel a smaller PR is always better. I will remove static_scale and all related support.

if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
    # INT8 × INT8 (static)
    scale = act_quant_kwargs.static_scale
    zero_point = torch.zeros_like(scale, dtype=torch.int8)
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025

I think the user should specify static_zero_point as well

but again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (no config)

Comment on lines 196 to 225
# Cast fp16 scale to float
intermediate_dtype = (
    torch.float if x_scales.dtype == torch.half else x_scales.dtype
)
# Note: CUDA doesn't support int32/int64 matmul, so we convert to float
# Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
# This may introduce minor numerical differences compared to int arithmetic
y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

# Apply activation scale
is_per_tensor_act = x_scales.numel() == 1
if is_per_tensor_act:
    y_dot.mul_(x_scales.to(intermediate_dtype))
else:
    # For block-wise activation scale, reshape to match y_dot
    x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
    y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

# Apply weight scale
is_per_tensor_weight = w_scales.numel() == 1
if is_per_tensor_weight:
    result = y_dot.mul_(w_scales.to(intermediate_dtype))
else:
    # Per-row weight scale - transpose and broadcast
    w_scales_broadcast = w_scales.t().expand_as(y_dot)
    result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

# Reshape back to original shape
result = result.view(*x_vals.shape[:-1], result.shape[-1])
result = result.to(activation_tensor.dtype)
Contributor

this should follow:

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):

Contributor

@jerryzh168 jerryzh168 left a comment

I think we should

  1. split the static quant support into a separate PR
  2. follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for the quantized linear implementation

this should be a refactor PR, not a refactor + some extra modifications + some feature implementations, I think

aten = torch.ops.aten

# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
Contributor

is this change related?

Contributor Author

@namgyu-youn namgyu-youn Oct 31, 2025

It was updated to support slicing with different scale granularities. Without this change, we can't handle per-tensor (0D scale) and per-row (1D scale); a rough sketch of the idea is below.
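
A rough sketch of the idea (the helper name and signature here are hypothetical, not the PR's actual code):

import torch

def _slice_scale(scale: torch.Tensor, dim: int, start: int, end: int) -> torch.Tensor:
    # per-tensor (0D) scale: a single scalar applies to any slice unchanged
    if scale.ndim == 0:
        return scale
    # per-row (1D) scale: slicing along dim 0 must slice the row scales too
    if scale.ndim == 1 and dim == 0:
        return scale[start:end]
    # other granularities (e.g. block-wise) would need their own handling
    return scale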

Int8WeightOnlyConfig(version=2),
],
)
def test_per_row_scale_shape(self, dtype, config):
Contributor

can you add a test like this one to test all the variations and include the shape check there?


@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("has_bias", [True, False])
def test_weight_only_linear_with_bias(self, dtype, has_bias):
Contributor

this can probably be merged into the linear variants test as well

Contributor

@jerryzh168 jerryzh168 left a comment

thanks, I think the tensor changes look good, but we need a linear_variants test to make sure we cover different aspects (e.g. compile), see comments inline

can you also do an e2e perf check with https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py to make sure the performance is the same before and after the change for the vit model?

also, adding a kernel check might be useful to make sure we don't regress things:

def test_expected_gpu_kernel_fbgemm(self):
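
For reference, a hypothetical sketch of such a kernel check (the expected "int_mm" kernel substring and the exact profiling setup are assumptions, not the real test):

def test_expected_gpu_kernel(self):
    # Compile the quantized model and check that an int8 matmul kernel
    # actually appears in the profiler trace, so a refactor can't silently
    # fall back to a slower floating-point path.
    m = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)
    quantize_(m, Int8DynamicActivationInt8WeightConfig(version=2))
    m = torch.compile(m)
    x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
    with torch.profiler.profile() as prof:
        m(x)
    kernel_names = [evt.key for evt in prof.key_averages()]
    self.assertTrue(any("int_mm" in name for name in kernel_names), kernel_names)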

@namgyu-youn
Contributor Author

Updated logs:
