@@ -11,7 +11,10 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import (
+    triton_mxfp8_dequant_dim0,
+    triton_to_mxfp8_dim0,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx


@@ -473,11 +476,9 @@ def forward(
473476 """
474477 # Quantize input
475478 block_size = 32
476- input_scales , input_data = to_mx (
479+ input_data , input_scales = triton_to_mxfp8_dim0 (
477480 input ,
478- elem_dtype = torch .float8_e4m3fn ,
479- block_size = block_size ,
480- scaling_mode = ScaleCalculationMode .RCEIL ,
481+ inner_block_size = block_size ,
481482 )
482483
483484 # Dispatch data (async)
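Reviewer note: a minimal sketch of the new quantization entry point, for anyone unfamiliar with the kernel API. Only what the diff itself shows is relied on here (tensor argument, inner_block_size keyword, (data, scales) return order); the dtypes in the comments are assumptions based on the to_mx call being replaced.

    import torch
    from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    # Returns (data, scales), the reverse of to_mx, which returned
    # (scales, data). Easy to miss when reviewing the call sites here.
    x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)
    # Assumed dtypes, matching the removed
    # to_mx(elem_dtype=torch.float8_e4m3fn) path: fp8 e4m3 element
    # data plus e8m0 per-block scales.
    print(x_data.dtype, x_scales.dtype)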
@@ -501,20 +502,17 @@ def forward(
         output_data = torch.ops._c10d_functional.wait_tensor(output_data)

         # Dequantize output
-        lowp_dtype = output_data.dtype
         hp_dtype = input.dtype
-        hp_output = to_dtype(
+        triton_hp_output = triton_mxfp8_dequant_dim0(
             output_data,
-            output_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            output_scales,
             hp_dtype,
+            block_size,
         )
-
         ctx.input_splits = input_splits
         ctx.output_splits = output_splits
         ctx.group = group
-        return hp_output
+        return triton_hp_output

     @staticmethod
     def backward(ctx, grad_output_hp):
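Reviewer note: the matching dequantization call, sketched under the same caveat. The positional order (data, scales, target dtype, block size) is taken from the call sites in this diff; the notable change is that the scales are passed through as-is, where the old to_dtype path needed an explicit .view(torch.float8_e8m0fnu), presumably because the Triton kernel already produces e8m0-typed scales.

    from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0

    # data, scales, target high-precision dtype, block size; no
    # .view(torch.float8_e8m0fnu) cast on the scales, unlike to_dtype.
    x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, 32)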
@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):

         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )

         # Dispatch data (async)
@@ -557,13 +553,11 @@ def backward(ctx, grad_output_hp):
         grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)

         hp_dtype = grad_output_hp.dtype
-        lowp_dtype = grad_input_data.dtype
-        grad_input_hp = to_dtype(
+        grad_input_hp = triton_mxfp8_dequant_dim0(
             grad_input_data,
-            grad_input_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            grad_input_scales,
             hp_dtype,
+            block_size,
         )
         return grad_input_hp, None, None, None

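Reviewer note: a hypothetical round-trip sanity check (not part of this PR) that exercises both kernels the way forward and backward now use them; the tolerances are guesses, since mxfp8 quantization is lossy.

    import torch
    from torchao.prototype.mx_formats.kernels import (
        triton_mxfp8_dequant_dim0,
        triton_to_mxfp8_dim0,
    )

    x = torch.randn(64, 128, device="cuda", dtype=torch.bfloat16)
    data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)
    x_rt = triton_mxfp8_dequant_dim0(data, scales, x.dtype, 32)
    # Expect close-but-not-exact reconstruction after the fp8 round trip.
    torch.testing.assert_close(x_rt, x, atol=0.25, rtol=0.25)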