Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4999,6 +4999,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool(bool)";
}

// Builtin backing the HLSL WaveActiveBitOr intrinsic. The prototype is the
// variadic "void (...)" form, so the argument count, operand type and result
// type are not fixed here; they are validated and set by the
// BI__builtin_hlsl_wave_active_bit_or case in
// SemaHLSL::CheckBuiltinFunctionCall.
def HLSLWaveActiveBitOr : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_bit_or"];
let Attributes = [NoThrow, Const];
let Prototype = "void (...)";
}

def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
  // Emit the operand and forward it to the wave bit-or intrinsic chosen by
  // the HLSL runtime for the current target.
  Value *Op = EmitScalarExpr(E->getArg(0));
  // Query the AST type of the argument: llvm::Type is signless and offers no
  // has*IntegerRepresentation() predicate, so the check cannot be made on
  // Op->getType().
  assert(E->getArg(0)->getType()->hasIntegerRepresentation() &&
         "Intrinsic WaveActiveBitOr operand must have an integer "
         "representation");

  Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBitOrIntrinsic();
  // The intrinsic is overloaded on its integer type, so the concrete operand
  // type has to be supplied to obtain the correctly mangled declaration
  // (e.g. the .i32 / .i64 variants).
  return EmitRuntimeCall(
      Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID, {Op->getType()}),
      {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBitOr, wave_reduce_or)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
Expand Down
30 changes: 30 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2498,6 +2498,36 @@ __attribute__((convergent)) double3 WaveReadLaneAt(double3, uint32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double4 WaveReadLaneAt(double4, uint32_t);

//===----------------------------------------------------------------------===//
// WaveActiveBitOr builtins
//
// Overloads for unsigned integer scalars and 2-, 3- and 4-component vectors
// (uint and uint64_t element types). Every overload aliases
// __builtin_hlsl_wave_active_bit_or, returns the same type as its operand, is
// marked convergent, and is available from shader model 6.0.
//===----------------------------------------------------------------------===//

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint WaveActiveBitOr(uint);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint2 WaveActiveBitOr(uint2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint3 WaveActiveBitOr(uint3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint4 WaveActiveBitOr(uint4);

// 64-bit element overloads.
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t WaveActiveBitOr(uint64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t2 WaveActiveBitOr(uint64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t3 WaveActiveBitOr(uint64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t4 WaveActiveBitOr(uint64_t4);

//===----------------------------------------------------------------------===//
// WaveActiveMax builtins
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3211,6 +3211,29 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyExpr);
break;
}
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
  // Exactly one operand is expected.
  if (SemaRef.checkArgCount(TheCall, 1))
    return true;

  // The operand must be a scalar or a vector, and must pass the shared
  // wave-active operand check.
  if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
    return true;
  if (CheckWaveActive(&SemaRef, TheCall))
    return true;

  // Accept integer scalars as well as vectors of integers.
  // Type::isIntegerType() alone is false for vector types, which would
  // wrongly reject vector operands; hasIntegerRepresentation() covers both.
  ExprResult Expr = TheCall->getArg(0);
  QualType ArgTyExpr = Expr.get()->getType();
  if (!ArgTyExpr->hasIntegerRepresentation()) {
    SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
                 diag::err_typecheck_convert_incompatible)
        << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
    return true;
  }

  // The call expression has the same type as its operand.
  TheCall->setType(ArgTyExpr);
  break;
}
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
Expand Down
30 changes: 30 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_uint
uint test_uint(uint expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i32([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i32([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveBitOr(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i32([[TY]]) #[[#attr:]]

// CHECK-LABEL: test_uint64_t
uint64_t test_uint64_t(uint64_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i64([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i64([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveBitOr(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i64([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i64([[TY]]) #[[#attr:]]
38 changes: 38 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

uint test_too_few_arg() {
return __builtin_hlsl_wave_active_bit_or();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

uint2 test_too_many_arg(uint2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

bool test_expr_bool_type_check(bool p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'bool'}}
}

float test_expr_float_type_check(float p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'float'}}
}

bool2 test_expr_bool_vec_type_check(bool2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
}

float2 test_expr_float_type_check(float2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
}

struct S { float f; };

S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
// Bitwise-OR wave reduction. The operand is tied to the result type so the
// intrinsic has a single overload type (mangled as e.g. ".i32").
def int_dx_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
// Bitwise-OR wave reduction. The operand is tied to the result type so the
// intrinsic has a single overload type (mangled as e.g. ".i32").
def int_spv_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Expand All @@ -136,7 +137,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_sclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_spv_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

// Create resource handle given the binding information. Returns a
// type appropriate for the kind of resource given the set id, binding id,
// array size of the binding, as well as an index and an indicator
// whether that index may be non-uniform.
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ defvar WaveOpKind_Product = 1;
defvar WaveOpKind_Min = 2;
defvar WaveOpKind_Max = 3;

// Operation selectors for the bitwise wave reduction op. WaveBitOpKind_Or is
// passed as the i8 operand of the WaveActiveBit DXIL op.
defvar WaveBitOpKind_And = 0;
defvar WaveBitOpKind_Or = 1;
defvar WaveBitOpKind_Xor = 2;

defvar SignedOpKind_Signed = 0;
defvar SignedOpKind_Unsigned = 1;

Expand Down Expand Up @@ -1069,6 +1073,24 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}

// DXIL op 120: bitwise reduction across the wave. The second (i8) operand
// selects the bitwise operation to perform.
def WaveActiveBit : DXILOp<120, waveActiveBit> {
  let Doc = "returns the result of the operation across waves";

  // llvm.dx.wave.reduce.or forwards its operand and appends the Or selector.
  let intrinsics = [
    IntrinSelect<int_dx_wave_reduce_or,
                 [IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>]>
  ];

  // Signature: OverloadTy (OverloadTy value, i8 opKind).
  let arguments = [OverloadTy, Int8Ty];
  let result = OverloadTy;
  let overloads = [Overloads<DXIL1_0, [Int32Ty, Int64Ty]>];

  let stages = [Stages<DXIL1_0, [all_stages]>];
  let attributes = [Attributes<DXIL1_0, []>];
}

def WaveAllBitCount : DXILOp<135, waveAllOp> {
let Doc = "returns the count of bits set to 1 across the wave";
let intrinsics = [IntrinSelect<int_dx_wave_active_countbits>];
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_active_countbits:
// Wave Active Op Variants
case Intrinsic::dx_wave_reduce_or:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_max:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_or:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_umax:
Expand Down
31 changes: 31 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReduceOr(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsUnsigned) const;

Expand Down Expand Up @@ -2427,6 +2430,32 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}

/// Select the spv_wave_reduce_or intrinsic as an OpGroupNonUniformBitwiseOr
/// with Subgroup scope and the Reduce group operation.
///
/// \param ResVReg  register that receives the result.
/// \param ResType  SPIR-V type of the result.
/// \param I        the intrinsic instruction; operand 2 holds the value being
///                 reduced.
/// \returns the result of constraining the operands of the new instruction.
bool SPIRVInstructionSelector::selectWaveReduceOr(
    Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
  assert(I.getNumOperands() == 3);
  assert(I.getOperand(2).isReg());

  // The value to reduce must already have a known SPIR-V type.
  Register SrcReg = I.getOperand(2).getReg();
  if (!GR.getSPIRVTypeForVReg(SrcReg))
    report_fatal_error("Input Type could not be determined.");

  // 32-bit integer type used for the scope constant operand.
  SPIRVType *ScopeTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);

  MachineBasicBlock &MBB = *I.getParent();
  // The scope constant is deliberately materialised inside the chain so that
  // it is created after the new instruction has been built, as before.
  return BuildMI(MBB, I, I.getDebugLoc(),
                 TII.get(SPIRV::OpGroupNonUniformBitwiseOr))
      .addDef(ResVReg)
      .addUse(GR.getSPIRVTypeID(ResType))
      .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, ScopeTy, TII,
                                     !STI.isShader()))
      .addImm(SPIRV::GroupOperation::Reduce)
      .addUse(SrcReg)
      .constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
Expand Down Expand Up @@ -3427,6 +3456,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
case Intrinsic::spv_wave_is_first_lane:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
case Intrinsic::spv_wave_reduce_or:
return selectWaveReduceOr(ResVReg, ResType, I);
case Intrinsic::spv_wave_reduce_umax:
return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/ true);
case Intrinsic::spv_wave_reduce_max:
Expand Down
7 changes: 7 additions & 0 deletions llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ entry:
ret i1 %ret
}

define noundef i32 @wave_bit_or(i32 %x) {
entry:
; CHECK: Function wave_bit_or : [[WAVE_FLAG]]
%ret = call i32 @llvm.dx.wave.reduce.or(i32 %x)
ret i32 %ret
}

define noundef i1 @wave_readlane(i1 %x, i32 %idx) {
entry:
; CHECK: Function wave_readlane : [[WAVE_FLAG]]
Expand Down
19 changes: 19 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define noundef i32 @wave_bitor_simple(i32 noundef %p1) {
entry:
; CHECK: call i32 @dx.op.waveActiveBit.i32(i32 120, i32 %p1, i8 1){{$}}
%ret = call i32 @llvm.dx.wave.reduce.or.i32(i32 %p1)
ret i32 %ret
}

declare i32 @llvm.dx.wave.reduce.or.i32(i32)

define noundef i64 @wave_bitor_simple64(i64 noundef %p1) {
entry:
; CHECK: call i64 @dx.op.waveActiveBit.i64(i32 120, i64 %p1, i8 1){{$}}
%ret = call i64 @llvm.dx.wave.reduce.or.i64(i64 %p1)
ret i64 %ret
}

declare i64 @llvm.dx.wave.reduce.or.i64(i64)
30 changes: 30 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}

; Test lowering to spir-v backend for various types and scalar/vector

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#uint64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3

; CHECK-LABEL: Begin function test_uint
; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
define i32 @test_uint(i32 %iexpr) {
entry:
; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint]] %[[#scope]] Reduce %[[#iexpr]]
%0 = call i32 @llvm.spv.wave.reduce.or.i32(i32 %iexpr)
ret i32 %0
}

declare i32 @llvm.spv.wave.reduce.or.i32(i32)

; CHECK-LABEL: Begin function test_uint64
; CHECK: %[[#iexpr64:]] = OpFunctionParameter %[[#uint64]]
define i64 @test_uint64(i64 %iexpr64) {
entry:
; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint64]] %[[#scope]] Reduce %[[#iexpr64]]
%0 = call i64 @llvm.spv.wave.reduce.or.i64(i64 %iexpr64)
ret i64 %0
}

declare i64 @llvm.spv.wave.reduce.or.i64(i64)