Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate presence of Stride operand to OpCooperativeMatrix{Load,Store}KHR #5777

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions source/val/validate_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2115,16 +2115,23 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,

const auto layout_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
const auto colmajor = _.FindDef(colmajor_id);
if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
!(spvOpcodeIsConstant(colmajor->opcode()) ||
spvOpcodeIsSpecConstant(colmajor->opcode()))) {
const auto layout_id = inst->GetOperandAs<uint32_t>(layout_index);
const auto layout_inst = _.FindDef(layout_id);
if (!layout_inst || !_.IsIntScalarType(layout_inst->type_id()) ||
!spvOpcodeIsConstant(layout_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
<< "MemoryLayout operand <id> " << _.getIdName(layout_id)
<< " must be a 32-bit integer constant instruction.";
}

bool stride_required = false;
uint64_t layout;
if (_.EvalConstantValUint64(layout_id, &layout)) {
stride_required =
(layout == (uint64_t)spv::CooperativeMatrixLayout::RowMajorKHR) ||
(layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR);
}

const auto stride_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
if (inst->operands().size() > stride_index) {
Expand All @@ -2135,6 +2142,9 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
<< "Stride operand <id> " << _.getIdName(stride_id)
<< " must be a scalar integer type.";
}
} else if (stride_required) {
alan-baker marked this conversation as resolved.
Show resolved Hide resolved
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MemoryLayout " << layout << " requires a Stride.";
}

const auto memory_access_index =
Expand Down
5 changes: 3 additions & 2 deletions test/opt/aggressive_dead_code_elim_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8018,6 +8018,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) {
%uint_3 = OpConstant %uint 3
%uint_16 = OpConstant %uint 16
%uint_4 = OpConstant %uint 4
%coop_stride = OpConstant %int 42
%_runtimearr_int = OpTypeRuntimeArray %int
%_struct_4 = OpTypeStruct %_runtimearr_int
%_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4
Expand Down Expand Up @@ -8047,7 +8048,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) {
%26 = OpVariable %_ptr_Function__ptr_Function_int Function
%27 = OpVariable %_ptr_Function__struct_18 Function
%28 = OpAccessChain %_ptr_StorageBuffer_int %2 %int_0 %uint_0
%29 = OpCooperativeMatrixLoadKHR %17 %28 %int_1
%29 = OpCooperativeMatrixLoadKHR %17 %28 %int_1 %coop_stride
%30 = OpCompositeConstruct %_struct_18 %29
OpStore %27 %30
%31 = OpAccessChain %_ptr_Function_17 %27 %int_0
Expand All @@ -8059,7 +8060,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) {
OpStore %32 %34
%35 = OpAccessChain %_ptr_StorageBuffer_int %2 %int_0 %uint_64
%36 = OpLoad %17 %31
OpCooperativeMatrixStoreKHR %35 %36 %int_0
OpCooperativeMatrixStoreKHR %35 %36 %int_0 %coop_stride
OpReturn
OpFunctionEnd
)";
Expand Down
114 changes: 95 additions & 19 deletions test/val/val_memory_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2351,7 +2351,11 @@ OpFunctionEnd)";
}

std::string GenCoopMatLoadStoreShaderKHR(const std::string& storeMemoryAccess,
const std::string& loadMemoryAccess) {
const std::string& loadMemoryAccess,
unsigned layout = 0,
bool useSpecConstantLayout = false,
bool useStoreStride = true,
bool useLoadStride = true) {
std::string s = R"(
OpCapability Shader
OpCapability GroupNonUniform
Expand Down Expand Up @@ -2408,11 +2412,18 @@ OpDecorate %129 BuiltIn WorkgroupSize
%33 = OpConstant %6 1024
%34 = OpConstant %6 1
%38 = OpConstant %6 8
%39 = OpConstant %6 0
%uint_0 = OpConstant %6 0
)";
if (useSpecConstantLayout) {
s += "%layout = OpSpecConstant %6 " + std::to_string(layout);
} else {
s += "%layout = OpConstant %6 " + std::to_string(layout);
}
s += R"(
%68 = OpTypeFloat 32
%69 = OpConstant %6 16
%70 = OpConstant %6 3
%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %39
%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %uint_0
%72 = OpTypePointer Function %71
%74 = OpTypeRuntimeArray %68
%75 = OpTypeStruct %74
Expand All @@ -2422,7 +2433,7 @@ OpDecorate %129 BuiltIn WorkgroupSize
%79 = OpConstant %78 0
%81 = OpConstant %6 5
%82 = OpTypePointer StorageBuffer %68
%84 = OpConstant %6 64
%stride = OpConstant %6 64
%88 = OpTypePointer Private %71
%89 = OpVariable %88 Private
%92 = OpTypeRuntimeArray %68
Expand Down Expand Up @@ -2478,51 +2489,57 @@ OpStore %18 %30
%35 = OpAccessChain %31 %18 %34
%36 = OpLoad %6 %35
%37 = OpIMul %6 %33 %36
%40 = OpAccessChain %31 %18 %39
%40 = OpAccessChain %31 %18 %uint_0
%41 = OpLoad %6 %40
%42 = OpIMul %6 %38 %41
%43 = OpIAdd %6 %37 %42
OpStore %32 %43
%45 = OpAccessChain %31 %18 %34
%46 = OpLoad %6 %45
%47 = OpIMul %6 %33 %46
%48 = OpAccessChain %31 %18 %39
%48 = OpAccessChain %31 %18 %uint_0
%49 = OpLoad %6 %48
%50 = OpIMul %6 %38 %49
%51 = OpIAdd %6 %47 %50
OpStore %44 %51
%53 = OpAccessChain %31 %18 %34
%54 = OpLoad %6 %53
%55 = OpIMul %6 %33 %54
%56 = OpAccessChain %31 %18 %39
%56 = OpAccessChain %31 %18 %uint_0
%57 = OpLoad %6 %56
%58 = OpIMul %6 %38 %57
%59 = OpIAdd %6 %55 %58
OpStore %52 %59
%61 = OpAccessChain %31 %18 %34
%62 = OpLoad %6 %61
%63 = OpIMul %6 %33 %62
%64 = OpAccessChain %31 %18 %39
%64 = OpAccessChain %31 %18 %uint_0
%65 = OpLoad %6 %64
%66 = OpIMul %6 %38 %65
%67 = OpIAdd %6 %63 %66
OpStore %60 %67
%80 = OpLoad %6 %32
%83 = OpAccessChain %82 %77 %79 %80
%87 = OpCooperativeMatrixLoadKHR %71 %83 %39 %84 )" +
loadMemoryAccess + R"( %81
)";
if (useLoadStride) {
s += "%87 = OpCooperativeMatrixLoadKHR %71 %83 %layout %stride " +
loadMemoryAccess + " %81";
} else {
s += "%87 = OpCooperativeMatrixLoadKHR %71 %83 %layout";
}
s += R"(
OpStore %73 %87
%90 = OpLoad %71 %73
OpStore %89 %90
%96 = OpLoad %6 %44
%97 = OpAccessChain %82 %95 %79 %96
%98 = OpCooperativeMatrixLoadKHR %71 %97 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81
%98 = OpCooperativeMatrixLoadKHR %71 %97 %layout %stride MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %91 %98
%100 = OpLoad %71 %91
OpStore %99 %100
%106 = OpLoad %6 %52
%107 = OpAccessChain %82 %105 %79 %106
%108 = OpCooperativeMatrixLoadKHR %71 %107 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81
%108 = OpCooperativeMatrixLoadKHR %71 %107 %layout %stride MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %101 %108
%110 = OpLoad %71 %101
OpStore %109 %110
Expand All @@ -2532,7 +2549,14 @@ OpStore %111 %115
%116 = OpLoad %71 %111
%121 = OpLoad %6 %60
%122 = OpAccessChain %82 %120 %79 %121
OpCooperativeMatrixStoreKHR %122 %116 %39 %84 )" + storeMemoryAccess + R"( %81
)";
if (useStoreStride) {
s += "OpCooperativeMatrixStoreKHR %122 %116 %layout %stride " +
storeMemoryAccess + " %81";
} else {
s += "OpCooperativeMatrixStoreKHR %122 %116 %layout";
}
s += R"(
OpReturn
OpFunctionEnd
)";
Expand All @@ -2549,6 +2573,54 @@ TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}

struct StrideMissingCase {
unsigned layout;
bool useLoadStride;
bool useStoreStride;
};

using ValidateCoopMatrixStrideMissing =
spvtest::ValidateBase<StrideMissingCase>;

INSTANTIATE_TEST_SUITE_P(
CoopMatrixStrideMissing, ValidateCoopMatrixStrideMissing,
Values(
StrideMissingCase{(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR,
false, true},
StrideMissingCase{(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR,
true, false},
StrideMissingCase{
(unsigned)spv::CooperativeMatrixLayout::ColumnMajorKHR, false,
true},
StrideMissingCase{
(unsigned)spv::CooperativeMatrixLayout::ColumnMajorKHR, true,
false}));

TEST_P(ValidateCoopMatrixStrideMissing, CoopMatKHRLoadStrideMissingFail) {
const StrideMissingCase& param = GetParam();
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR", param.layout,
false /*useSpecConstantLayout*/, param.useStoreStride,
param.useLoadStride);
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("MemoryLayout " + std::to_string(param.layout) +
" requires a Stride"));
}

TEST_F(ValidateMemory, CoopMatKHRMemoryLayoutFromSpecConstantSuccess) {
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR",
(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR,
true /*useSpecConstantLayout*/);

CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}

TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerVisibleKHR|NonPrivatePointerKHR",
Expand Down Expand Up @@ -6791,11 +6863,12 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%ld = OpCooperativeMatrixLoadKHR %matrix %var %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %var %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down Expand Up @@ -6832,12 +6905,13 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%gep = OpUntypedAccessChainKHR %untyped %block %var %int_0 %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down Expand Up @@ -6878,14 +6952,15 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var1 = OpVariable %ptr StorageBuffer
%var2 = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%gep = OpAccessChain %ptr_float %var1 %int_0 %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0
OpCooperativeMatrixStoreKHR %var2 %ld %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride
OpCooperativeMatrixStoreKHR %var2 %ld %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down Expand Up @@ -6926,15 +7001,16 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var1 = OpVariable %ptr StorageBuffer
%var2 = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%gep = OpAccessChain %ptr_float %var1 %int_0 %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride
%gep2 = OpUntypedAccessChainKHR %untyped %block %var2 %int_0 %int_0
OpCooperativeMatrixStoreKHR %gep2 %ld %int_0
OpCooperativeMatrixStoreKHR %gep2 %ld %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down