diff --git a/torchao/float8/config.py b/torchao/float8/config.py index b362390946..d87e97f26d 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -187,7 +187,7 @@ class Float8LinearConfig: # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls # _scaled_mm since it has the strong constraint that for M,N,K N, K must be a multiple of 16. - # This can cause a memory spike however so we keep this off by default. - pad_inner_dim: bool = False + # This can cause a memory spike; note that padding is nevertheless enabled by default. + pad_inner_dim: bool = True # If True, emulation is used instead of hardware accelerated gemm emulate: bool = False