CUDA: Volta tensor core support for MMF #16843
Conversation
force-pushed from 6537cc8 to 12e61ec
force-pushed from 12e61ec to 5d14386
How about adding something like this to the tiles:

```cpp
    static constexpr __device__ bool supported() {
        if (I ==  8 && J ==  4) return true;
        if (I ==  8 && J ==  8) return true;
        if (I == 16 && J ==  8) return true;
        if (I == 16 && J == 16) return true;
        if (I == 32 && J ==  8) return true;
        return false;
    }
```

We can use that to check whether specific combinations of tile sizes/types are supported and use that to determine which, if any, code to run. It would then be the responsibility of the kernel to check for support using these methods. I'm not sure whether we should keep the static asserts or use […]
This function would return values based on which architecture this is being compiled for, right? If so this makes sense to me.
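A minimal sketch of how a kernel could consume such a method, assuming a hypothetical `tile<I, J>` template carrying the `supported()` method quoted above; the kernel name and the fallback branch are illustrative, not the actual MMF implementation. As the reply above suggests, a real version would presumably also branch on `__CUDA_ARCH__` so that e.g. Volta and Turing+ report different sets of usable tile shapes.

```cpp
// Sketch only: illustrates the proposed supported() check, not the actual
// ggml tile implementation. The tile shapes follow the comment above; the
// kernel and its fallback are hypothetical.
template <int I, int J>
struct tile {
    static constexpr __device__ bool supported() {
        // Assumption: a real implementation could additionally gate this on
        // __CUDA_ARCH__ so the answer depends on the target architecture.
        if (I ==  8 && J ==  4) return true;
        if (I ==  8 && J ==  8) return true;
        if (I == 16 && J ==  8) return true;
        if (I == 16 && J == 16) return true;
        if (I == 32 && J ==  8) return true;
        return false;
    }
};

template <int I, int J>
static __global__ void mmf_example_kernel(float * dst) {
    if constexpr (tile<I, J>::supported()) {
        // Tensor-core path for this tile shape would go here.
        dst[threadIdx.x] = 0.0f;
    } else {
        // Unsupported combination: do nothing here; the host-side dispatch
        // would either never launch this instantiation or pick a fallback.
    }
}
```

With this pattern the compile-time check replaces (or complements) the existing static asserts: unsupported instantiations still compile, but the kernel itself decides which, if any, code to run.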
```cpp
        return threadIdx.x / 4;
    } else if constexpr (I == 16 && J == 4) {
        return l * 8 + threadIdx.x / 4;
```

Suggested change:

```cpp
        return (l * 8) | (threadIdx.x / 4);
```
is this more performant? I find the earlier version easier to read
It's a good thing you're asking that, because I misremembered the table in the CUDA documentation showing instruction throughput. The way I remembered it, integer additions and binary operations had the same throughput, but at the silicon level you would get lower power draw. In actuality, though, the throughput of integer additions is twice that of binary operations.
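An added note, not part of the thread, on why the two forms are interchangeable in the first place: within a warp, `threadIdx.x / 4` lies in [0, 7] and only occupies the low three bits, which `l * 8` leaves clear, so OR and ADD produce identical results. A small host-side check, with the range of `l` chosen purely for illustration:

```cpp
#include <cassert>

int main() {
    for (int l = 0; l < 8; ++l) {           // illustrative range for l
        for (int tid = 0; tid < 32; ++tid) { // threadIdx.x within a warp
            const int lane_part = tid / 4;   // in [0, 7]: low 3 bits only
            // l * 8 has the low 3 bits clear, so the bit patterns don't overlap
            assert(l * 8 + lane_part == ((l * 8) | lane_part));
        }
    }
    return 0;
}
```

Since the results are identical, the addition form is preferable per the throughput argument above, and it is also the easier one to read.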
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This PR adds support for Volta tensor cores to the mul_mat_f kernel. The longer-term goal is to also enable these tensor cores for the MMA FlashAttention kernel.

Performance changes