Commit 1e473ed

move float8 blockwise kernels out of prototype (#3256)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent d3bec87 commit 1e473ed
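
For downstream code, the net effect of this commit (per the diffs below) is an import-path change: the blockwise FP8 quantization helpers move from torchao.prototype.blockwise_fp8_inference.blockwise_quantization to torchao.kernel.blockwise_quantization. A minimal sketch of the updated import, assuming the public names stay exactly as they appear in these diffs:

    # New home of the blockwise FP8 kernels after this commit
    from torchao.kernel.blockwise_quantization import (
        blockwise_fp8_gemm,
        fp8_blockwise_act_quant,
        fp8_blockwise_weight_dequant,
        fp8_blockwise_weight_quant,
    )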

File tree

6 files changed: +7 additions, -5 deletions

.github/workflows/1xL4_tests.yml

Lines changed: 1 addition & 0 deletions
@@ -51,3 +51,4 @@ jobs:
 pytest test/dtypes/test_affine_quantized_float.py --verbose -s
 ./test/float8/test_everything_single_gpu.sh
 python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+python test/kernel/test_blockwise_triton.py --verbose -s

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 from triton.testing import do_bench
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,

test/prototype/test_blockwise_triton.py renamed to test/kernel/test_blockwise_triton.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 
 triton = pytest.importorskip("triton", reason="Triton required to run this test")
 
-from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_dequant,

torchao/prototype/blockwise_fp8_inference/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -1,11 +1,12 @@
-from .blockwise_linear import BlockwiseQuantLinear
-from .blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_dequant,
     fp8_blockwise_weight_quant,
 )
 
+from .blockwise_linear import BlockwiseQuantLinear
+
 __all__ = [
     "blockwise_fp8_gemm",
     "BlockwiseQuantLinear",

torchao/prototype/blockwise_fp8_inference/blockwise_linear.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import torch
 from torch import nn
 
-from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
+from torchao.kernel.blockwise_quantization import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
 )
