91 changes: 91 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -34,6 +34,7 @@
"_choose_qparams_and_quantize_affine_hqq",
"_choose_qparams_and_quantize_scale_only_hqq",
"_choose_qparams_and_quantize_affine_qqq",
"_choose_qparams_and_quantize_affine_sinq",
"_choose_scale_float8",
"_choose_qparams_gguf",
"_quantize_affine_no_zero_point",
@@ -2219,6 +2220,96 @@ def round_stoch(x: torch.Tensor) -> torch.Tensor:
    return qdata, scale


def _choose_qparams_and_quantize_affine_sinq(
    tensor: torch.Tensor,
    nbits: float = 4,
    group_size: int = 64,
    niter: int = 20,
    compute_dtype: torch.dtype = torch.float16,
    device: str = "cuda",
    verbose: bool = False,
) -> tuple:
    """
    SINQ: Sinkhorn-Normalized Quantization (https://www.arxiv.org/abs/2509.22944)

    Iteratively normalizes row and column standard deviations to minimize
    matrix imbalance before quantization with dual scales.

    Args:
        tensor: Input weight tensor
        nbits: Number of quantization bits (default: 4)
        group_size: Quantization group size (default: 64)
        niter: Number of Sinkhorn iterations (default: 20)
        compute_dtype: Target compute dtype (default: torch.float16)
        device: Target device for computation (default: "cuda")
        verbose: Verbose output flag (default: False)

    Returns:
        Tuple of (W_q, scale_row, zero, scale_col, shape)
    """
    if group_size is not None:
        assert _is_divisible(tensor.numel(), group_size), (
            f"group_size must divide tensor elements. shape: {tensor.shape}, group_size: {group_size}"
        )

    W = tensor.to(device=device, dtype=torch.float32)
    shape = W.shape

    # Reshape for 1D tiling
    W = W.reshape(-1, group_size)  # [N*num_groups, group_size]

    # Algorithm 1: Sinkhorn Normalization
    q_min = min(W.std(dim=0).min().item(), W.std(dim=1).min().item())
    q_min = max(q_min, 1e-8)

    W_hat = W.clone()
    q_col_acc = torch.ones(W.shape[1], device=device, dtype=torch.float32)
    q_row_acc = torch.ones(W.shape[0], device=device, dtype=torch.float32)

    for _ in range(niter):
        # Normalize columns (dim=0)
        q_col = W_hat.std(dim=0) / q_min
        q_col = torch.clamp(q_col, min=1e-8)
        W_hat = W_hat / q_col.unsqueeze(0)
        q_col_acc = q_col_acc * q_col

        # Normalize rows (dim=1)
        q_row = W_hat.std(dim=1) / q_min
        q_row = torch.clamp(q_row, min=1e-8)
        W_hat = W_hat / q_row.unsqueeze(1)
        q_row_acc = q_row_acc * q_row

    # RTN quantization
    _min = W_hat.min(dim=1, keepdim=True)[0]
    _max = W_hat.max(dim=1, keepdim=True)[0]

    max_v = 2**nbits - 1
    min_v = 0

    scale = (max_v / (_max - _min)).clamp(max=2e4)
    zero = -_min * scale

    if nbits == 4:
        zero = _Round.apply(zero)

    W_q = _Round.apply(W_hat * scale + zero).clamp(min_v, max_v)

    # Recover with Sinkhorn factors
    # W ≈ s (scale_row) ⊙ (Q − z) ⊙ t (scale_col)
    scale_row = (1.0 / scale) * q_row_acc.unsqueeze(1)  # [N*num_groups, 1]
    scale_col = q_col_acc  # [group_size]

    # Reshape to original dimensions
    W_q = W_q.reshape(shape).to(torch.uint8)
    scale_row = scale_row.reshape(shape[0], -1).to(compute_dtype)
    zero = zero.reshape(shape[0], -1).to(compute_dtype)

    # Expand scale_col to original column dimension
    num_groups = shape[1] // group_size
    scale_col = scale_col.repeat(num_groups).to(compute_dtype)

    return W_q, scale_row, zero, scale_col, shape
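
# Illustrative only, not part of this diff: a minimal dequantization sketch showing how
# the returned factors recover the weights, assuming W ≈ scale_row ⊙ (W_q − zero) ⊙ scale_col
# with scale_row/zero stored per (row, group) and scale_col per column. The helper name
# `_dequantize_affine_sinq_reference` is hypothetical.
def _dequantize_affine_sinq_reference(
    W_q: torch.Tensor,        # [N, M] quantized codes (uint8)
    scale_row: torch.Tensor,  # [N, M // group_size] per-group row scales
    zero: torch.Tensor,       # [N, M // group_size] per-group zero points
    scale_col: torch.Tensor,  # [M] accumulated Sinkhorn column factors
    group_size: int = 64,
) -> torch.Tensor:
    # Broadcast the per-group row scale and zero point across each group's columns
    s = scale_row.to(torch.float32).repeat_interleave(group_size, dim=1)
    z = zero.to(torch.float32).repeat_interleave(group_size, dim=1)
    # Undo the affine step, then re-apply the accumulated column factors
    return (W_q.to(torch.float32) - z) * s * scale_col.to(torch.float32).unsqueeze(0)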


def _choose_qparams_affine_floatx(
    tensor: torch.Tensor, ebits: int, mbits: int
) -> torch.Tensor: