
Conversation

@JohannesGaessler
Collaborator

This PR adds support for Volta tensor cores to the mul_mat_f kernel. The longer-term goal is to enable these tensor cores also for the MMA FlashAttention kernel.

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s 6537cc8 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| V100-PCIE-32GB | deepseek2 16B F16 | 1 | pp512 | 88.54 | 88.39 | 1.00 |
| V100-PCIE-32GB | deepseek2 16B F16 | 2 | pp512 | 72.62 | 141.05 | 1.94 |
| V100-PCIE-32GB | deepseek2 16B F16 | 3 | pp512 | 89.10 | 174.72 | 1.96 |
| V100-PCIE-32GB | deepseek2 16B F16 | 4 | pp512 | 101.98 | 229.25 | 2.25 |
| V100-PCIE-32GB | deepseek2 16B F16 | 5 | pp512 | 114.94 | 266.92 | 2.32 |
| V100-PCIE-32GB | deepseek2 16B F16 | 6 | pp512 | 130.69 | 301.50 | 2.31 |
| V100-PCIE-32GB | deepseek2 16B F16 | 7 | pp512 | 142.80 | 331.81 | 2.32 |
| V100-PCIE-32GB | deepseek2 16B F16 | 8 | pp512 | 154.34 | 363.39 | 2.35 |
| V100-PCIE-32GB | deepseek2 16B F16 | 9 | pp512 | 163.69 | 376.59 | 2.30 |
| V100-PCIE-32GB | deepseek2 16B F16 | 10 | pp512 | 179.43 | 409.79 | 2.28 |
| V100-PCIE-32GB | deepseek2 16B F16 | 11 | pp512 | 190.28 | 437.53 | 2.30 |
| V100-PCIE-32GB | deepseek2 16B F16 | 12 | pp512 | 205.59 | 468.86 | 2.28 |
| V100-PCIE-32GB | deepseek2 16B F16 | 13 | pp512 | 206.84 | 477.87 | 2.31 |
| V100-PCIE-32GB | deepseek2 16B F16 | 14 | pp512 | 223.11 | 511.71 | 2.29 |
| V100-PCIE-32GB | deepseek2 16B F16 | 15 | pp512 | 233.90 | 531.07 | 2.27 |
| V100-PCIE-32GB | deepseek2 16B F16 | 16 | pp512 | 241.63 | 548.72 | 2.27 |
| V100-PCIE-32GB | deepseek2 16B F16 | 32 | pp512 | 388.11 | 847.36 | 2.18 |
| V100-PCIE-32GB | deepseek2 16B F16 | 64 | pp512 | 653.26 | 1267.11 | 1.94 |
| V100-PCIE-32GB | deepseek2 16B F16 | 128 | pp512 | 1032.66 | 1690.62 | 1.64 |
| V100-PCIE-32GB | deepseek2 16B F16 | 256 | pp512 | 1651.41 | 1662.73 | 1.01 |
| V100-PCIE-32GB | deepseek2 16B F16 | 512 | pp512 | 2276.55 | 2280.14 | 1.00 |
| V100-PCIE-32GB | llama 8B F16 | 1 | pp512 | 56.56 | 56.55 | 1.00 |
| V100-PCIE-32GB | llama 8B F16 | 2 | pp512 | 106.37 | 106.35 | 1.00 |
| V100-PCIE-32GB | llama 8B F16 | 3 | pp512 | 121.33 | 121.35 | 1.00 |
| V100-PCIE-32GB | llama 8B F16 | 4 | pp512 | 147.60 | 194.89 | 1.32 |
| V100-PCIE-32GB | llama 8B F16 | 5 | pp512 | 183.08 | 240.10 | 1.31 |
| V100-PCIE-32GB | llama 8B F16 | 6 | pp512 | 218.26 | 284.60 | 1.30 |
| V100-PCIE-32GB | llama 8B F16 | 7 | pp512 | 252.91 | 327.78 | 1.30 |
| V100-PCIE-32GB | llama 8B F16 | 8 | pp512 | 289.76 | 375.19 | 1.29 |
| V100-PCIE-32GB | llama 8B F16 | 9 | pp512 | 330.82 | 410.80 | 1.24 |
| V100-PCIE-32GB | llama 8B F16 | 10 | pp512 | 362.85 | 448.04 | 1.23 |
| V100-PCIE-32GB | llama 8B F16 | 11 | pp512 | 387.93 | 489.95 | 1.26 |
| V100-PCIE-32GB | llama 8B F16 | 12 | pp512 | 422.32 | 532.48 | 1.26 |
| V100-PCIE-32GB | llama 8B F16 | 13 | pp512 | 452.50 | 567.58 | 1.25 |
| V100-PCIE-32GB | llama 8B F16 | 14 | pp512 | 486.98 | 609.92 | 1.25 |
| V100-PCIE-32GB | llama 8B F16 | 15 | pp512 | 517.35 | 639.23 | 1.24 |
| V100-PCIE-32GB | llama 8B F16 | 16 | pp512 | 558.80 | 691.65 | 1.24 |

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Oct 29, 2025
@JohannesGaessler
Collaborator Author

How about adding something like this to the tiles:

    // Compile-time check: is this combination of tile dimensions supported?
    static constexpr __device__ bool supported() {
        if (I ==  8 && J ==  4) return true;
        if (I ==  8 && J ==  8) return true;
        if (I == 16 && J ==  8) return true;
        if (I == 16 && J == 16) return true;
        if (I == 32 && J ==  8) return true;
        return false;
    }

We can use that to check whether specific combinations of tile sizes/types are supported and to determine which code, if any, to run; a rough sketch follows below. It would then be the kernel's responsibility to check for support using these methods. I'm not sure whether we should keep the static asserts or use NO_DEVICE_CODE when an unsupported configuration is used.
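To make the idea concrete, here is a rough, hypothetical sketch (not the code in this PR) of how a kernel could gate its tensor core path on such a check; `tile_sketch`, `mul_mat_sketch`, and the parameter list are made up for illustration, and `NO_DEVICE_CODE` refers to the existing macro in the ggml CUDA backend:

```cuda
#include <cuda_fp16.h>

// Hypothetical tile type exposing a compile-time supported() check.
template <int I, int J, typename T>
struct tile_sketch {
    static constexpr __device__ bool supported() {
        return (I ==  8 && J ==  4) || (I ==  8 && J ==  8) ||
               (I == 16 && J ==  8) || (I == 16 && J == 16) ||
               (I == 32 && J ==  8);
    }
};

// Kernel-side gate: instantiate the MMA path only for supported tile shapes.
template <int I, int J>
static __global__ void mul_mat_sketch(const half2 * x, const half2 * y, float * dst) {
    if constexpr (tile_sketch<I, J, half2>::supported()) {
        // Tensor core path: safe to use the MMA primitives for this shape.
        // ... load tiles, issue mma instructions, write results to dst ...
    } else {
        // Unsupported tile configuration for the target architecture:
        // fall back to another path, or compile to nothing
        // (e.g. via NO_DEVICE_CODE as discussed above).
    }
}
```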

@am17an
Collaborator

am17an commented Oct 30, 2025

This function would return values based on which architecture this is being compiled for, right? If so, this makes sense to me.
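For illustration only (a sketch, not this PR's code), such a check could branch on `__CUDA_ARCH__` so that the same tile template reports different support depending on the target architecture; the shape sets below are invented for the example and do not reflect the actual Volta/Turing support matrix:

```cuda
// Sketch of an architecture-dependent supported() check, as a member of a
// hypothetical tile template.
template <int I, int J, typename T>
struct tile_example {
    static constexpr __device__ bool supported() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 700    // Volta (sm_70)
        return (I == 16 && J == 8) || (I == 32 && J == 8);
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750  // Turing and newer
        return (I ==  8 && J == 4) || (I == 16 && J == 8) || (I == 16 && J == 16);
#else
        return false;  // host compilation pass, or no tensor core support
#endif
    }
};
```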

         return threadIdx.x / 4;
     } else if constexpr (I == 16 && J == 4) {
-        return l * 8 + threadIdx.x / 4;
+        return (l * 8) | (threadIdx.x / 4);
@am17an
Collaborator
is this more performant? I find the earlier version easier to read

@JohannesGaessler
Collaborator Author

It's a good thing you're asking that, because I misremembered the table in the CUDA documentation showing instruction throughput. The way I remembered it, integer additions and binary operations had the same throughput, but at the silicon level you would get lower power draw. In actuality, though, the throughput of additions is twice that of binary operations.
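For reference, the two variants discussed in this thread as a minimal standalone sketch (hypothetical helper names); when `threadIdx.x` is a lane index within one warp (< 32), `threadIdx.x / 4` fits in the low three bits that `l * 8` leaves clear, so both expressions always yield the same value:

```cuda
// Equivalent index computations from the diff above. They coincide because
// (threadIdx.x / 4) < 8 for a warp lane index, so the OR never carries into
// the bits set by l * 8. Per the comment above, integer additions have twice
// the throughput of bitwise operations, so the addition form is preferable.
static __device__ __forceinline__ int row_index_add(const int l) {
    return l * 8 + threadIdx.x / 4;
}

static __device__ __forceinline__ int row_index_or(const int l) {
    return (l * 8) | (threadIdx.x / 4);
}
```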

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
@JohannesGaessler merged commit 31c511a into ggml-org:master on Oct 31, 2025
72 checks passed