The general idea is to introduce a custom nn.Linear layer that maintains 2:4 semi-structured sparsity dynamically throughout training. The layer is initialized with random weights and a corresponding mask that enforces 2:4 semi-structured sparsity. When an active weight crosses 0 (i.e., flips sign), it is automatically pruned (i.e., set to 0), freeing up a nonzero slot in its 2:4 group. On a subsequent training step, one of the zero weights within the same 2:4 group is chosen for regrowth, selected by the magnitude of the densely calculated gradient.
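As a rough sketch of what I have in mind (the class and method names here are placeholders, not an existing API, and the exact prune/regrow schedule would be worked out in the implementation plan): the layer keeps a 2:4 mask buffer alongside the dense weight, runs the forward pass against the weight directly so gradients stay dense, and exposes an `update_mask()` hook to be called after each `optimizer.step()` that prunes sign-flipped weights and regrows one masked slot per deficient group by dense-gradient magnitude.

```python
import torch
import torch.nn as nn

class DynamicSemiSparseLinear(nn.Linear):
    """Illustrative sketch: a Linear layer whose weight keeps 2:4 sparsity
    (at most 2 nonzeros in every group of 4 along the input dimension)."""

    def __init__(self, in_features, out_features, bias=True):
        assert in_features % 4 == 0, "2:4 groups require in_features % 4 == 0"
        super().__init__(in_features, out_features, bias)
        with torch.no_grad():
            # Initial mask: keep the 2 largest-magnitude weights per group of 4.
            groups = self.weight.view(out_features, -1, 4)
            topk = groups.abs().topk(2, dim=-1).indices
            mask = torch.zeros_like(groups).scatter_(-1, topk, 1.0)
            self.register_buffer("mask", mask.view_as(self.weight))
            self.weight.mul_(self.mask)
            # Track signs of the currently active weights for flip detection.
            self.register_buffer("prev_sign", self.weight.sign())

    def forward(self, x):
        # Use the dense weight directly: masked entries are held at zero by
        # update_mask(), and the backward pass produces a dense gradient
        # that the regrowth criterion can read.
        return nn.functional.linear(x, self.weight, self.bias)

    @torch.no_grad()
    def update_mask(self):
        """Call after optimizer.step(): prune sign flips, then regrow."""
        # Prune: any active weight whose sign flipped since the last update.
        sign = self.weight.sign()
        flipped = (sign != self.prev_sign) & (self.mask > 0) & (self.prev_sign != 0)
        self.mask[flipped] = 0.0
        # Regrow: in each 2:4 group with a free slot, reactivate the masked
        # position with the largest dense-gradient magnitude.
        if self.weight.grad is not None:
            g = self.weight.grad.view(self.out_features, -1, 4).abs()
            m = self.mask.view(self.out_features, -1, 4)
            deficient = (m.sum(-1, keepdim=True) < 2).float()
            cand = torch.where(m > 0, torch.full_like(g, float("-inf")), g)
            grow = torch.zeros_like(m).scatter_(
                -1, cand.argmax(-1, keepdim=True), 1.0)
            m.add_(grow * deficient)
        # Enforce the mask so masked weights stay exactly zero.
        self.weight.mul_(self.mask)
        self.prev_sign = self.weight.sign()
```

A training loop would then interleave `loss.backward()`, `optimizer.step()`, and `layer.update_mask()`; the actual implementation would presumably also need to interact with `torch.sparse` / `to_sparse_semi_structured` for the accelerated kernels, which this sketch does not cover.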
I'd be happy to work on implementing this feature and provide a detailed implementation plan, assuming the repository is open to landing it.