Commit 710192d

[mxfp8 moe training] integrate triton quant/dequant kernels into mxfp8 all to all
stack-info: PR: #3197, branch: danielvegamyhre/stack/79
1 parent 82ded0b commit 710192d
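
This change swaps the eager to_mx / to_dtype quantize and dequantize calls in the mxfp8 all-to-all forward and backward passes for the fused Triton kernels triton_to_mxfp8_dim0 and triton_mxfp8_dequant_dim0, folding the element dtype, RCEIL scaling mode, and e8m0 scale-view bookkeeping into the kernels themselves.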

File tree: 1 file changed (+15 −21)

torchao/prototype/moe_training/kernels/mxfp8/comms.py (15 additions, 21 deletions)
@@ -11,7 +11,10 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import (
+    triton_mxfp8_dequant_dim0,
+    triton_to_mxfp8_dim0,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
@@ -473,11 +476,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -501,20 +502,17 @@ def forward(
         output_data = torch.ops._c10d_functional.wait_tensor(output_data)
 
         # Dequantize output
-        lowp_dtype = output_data.dtype
         hp_dtype = input.dtype
-        hp_output = to_dtype(
+        triton_hp_output = triton_mxfp8_dequant_dim0(
             output_data,
-            output_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            output_scales,
             hp_dtype,
+            block_size,
         )
-
         ctx.input_splits = input_splits
         ctx.output_splits = output_splits
         ctx.group = group
-        return hp_output
+        return triton_hp_output
 
     @staticmethod
     def backward(ctx, grad_output_hp):
@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):
 
         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -557,13 +553,11 @@ def backward(ctx, grad_output_hp):
         grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
-        lowp_dtype = grad_input_data.dtype
-        grad_input_hp = to_dtype(
+        grad_input_hp = triton_mxfp8_dequant_dim0(
             grad_input_data,
-            grad_input_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            grad_input_scales,
             hp_dtype,
+            block_size,
         )
         return grad_input_hp, None, None, None
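
For reference, a minimal round-trip sketch of the two Triton kernels this commit wires in, assuming only the call signatures visible in the diff (data-first return order for triton_to_mxfp8_dim0, positional hp_dtype and block_size arguments for triton_mxfp8_dequant_dim0); shapes, device, and dtype are illustrative:

    # Hedged round-trip sketch; signatures inferred from the diff above.
    import torch

    from torchao.prototype.mx_formats.kernels import (
        triton_mxfp8_dequant_dim0,
        triton_to_mxfp8_dim0,
    )

    block_size = 32
    x = torch.randn(128, 8192, dtype=torch.bfloat16, device="cuda")

    # Quantize: returns fp8 (e4m3) data plus one shared e8m0 exponent scale
    # per 32 contiguous elements, replacing the eager to_mx(..., RCEIL) call.
    x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)

    # Dequantize back to the original high-precision dtype, replacing to_dtype.
    x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, x.dtype, block_size)

    assert x_hp.shape == x.shape and x_hp.dtype == x.dtype

Note how the fused kernels also drop the per-call-site scale handling: the eager path had to view the scales as torch.float8_e8m0fnu and track lowp_dtype explicitly, while the Triton dequant kernel accepts the scales tensor as produced by the quant kernel.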
