Skip to content

Commit

Permalink
Validate presence of Stride operand to OpCooperativeMatrix{Load,Store…
Browse files Browse the repository at this point in the history
…}KHR (KhronosGroup#5777)

* Validate Stride operand to OpCooperativeMatrix{Load,Store}KHR

The specification requires the Stride operand for the RowMajorKHR and
ColumnMajorKHR layouts.

Signed-off-by: Kevin Petit <kevin.petit@arm.com>
Change-Id: I51084b9b8dedebf9cab7ae25334ee56b75ef0126

* Update source/val/validate_memory.cpp

Co-authored-by: alan-baker <alanbaker@google.com>

* add test to exercise memory layout from spec constant and fix validation

Change-Id: I06d7308c4a2b62d26d69e88e03bfa009a7f8fff3

* format fixes

Change-Id: I9cbabec0ed2172dcd228cc385551cb7a5b79df1a

---------

Signed-off-by: Kevin Petit <kevin.petit@arm.com>
Co-authored-by: alan-baker <alanbaker@google.com>
  • Loading branch information
2 people authored and Keenuts committed Nov 12, 2024
1 parent 0e0c17c commit d3eac11
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 27 deletions.
22 changes: 16 additions & 6 deletions source/val/validate_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2115,16 +2115,23 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,

const auto layout_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u;
const auto colmajor_id = inst->GetOperandAs<uint32_t>(layout_index);
const auto colmajor = _.FindDef(colmajor_id);
if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) ||
!(spvOpcodeIsConstant(colmajor->opcode()) ||
spvOpcodeIsSpecConstant(colmajor->opcode()))) {
const auto layout_id = inst->GetOperandAs<uint32_t>(layout_index);
const auto layout_inst = _.FindDef(layout_id);
if (!layout_inst || !_.IsIntScalarType(layout_inst->type_id()) ||
!spvOpcodeIsConstant(layout_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MemoryLayout operand <id> " << _.getIdName(colmajor_id)
<< "MemoryLayout operand <id> " << _.getIdName(layout_id)
<< " must be a 32-bit integer constant instruction.";
}

bool stride_required = false;
uint64_t layout;
if (_.EvalConstantValUint64(layout_id, &layout)) {
stride_required =
(layout == (uint64_t)spv::CooperativeMatrixLayout::RowMajorKHR) ||
(layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR);
}

const auto stride_index =
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
if (inst->operands().size() > stride_index) {
Expand All @@ -2135,6 +2142,9 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
<< "Stride operand <id> " << _.getIdName(stride_id)
<< " must be a scalar integer type.";
}
} else if (stride_required) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "MemoryLayout " << layout << " requires a Stride.";
}

const auto memory_access_index =
Expand Down
5 changes: 3 additions & 2 deletions test/opt/aggressive_dead_code_elim_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8018,6 +8018,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) {
%uint_3 = OpConstant %uint 3
%uint_16 = OpConstant %uint 16
%uint_4 = OpConstant %uint 4
%coop_stride = OpConstant %int 42
%_runtimearr_int = OpTypeRuntimeArray %int
%_struct_4 = OpTypeStruct %_runtimearr_int
%_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4
Expand Down Expand Up @@ -8047,7 +8048,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) {
%26 = OpVariable %_ptr_Function__ptr_Function_int Function
%27 = OpVariable %_ptr_Function__struct_18 Function
%28 = OpAccessChain %_ptr_StorageBuffer_int %2 %int_0 %uint_0
%29 = OpCooperativeMatrixLoadKHR %17 %28 %int_1
%29 = OpCooperativeMatrixLoadKHR %17 %28 %int_1 %coop_stride
%30 = OpCompositeConstruct %_struct_18 %29
OpStore %27 %30
%31 = OpAccessChain %_ptr_Function_17 %27 %int_0
Expand All @@ -8059,7 +8060,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) {
OpStore %32 %34
%35 = OpAccessChain %_ptr_StorageBuffer_int %2 %int_0 %uint_64
%36 = OpLoad %17 %31
OpCooperativeMatrixStoreKHR %35 %36 %int_0
OpCooperativeMatrixStoreKHR %35 %36 %int_0 %coop_stride
OpReturn
OpFunctionEnd
)";
Expand Down
114 changes: 95 additions & 19 deletions test/val/val_memory_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2351,7 +2351,11 @@ OpFunctionEnd)";
}

std::string GenCoopMatLoadStoreShaderKHR(const std::string& storeMemoryAccess,
const std::string& loadMemoryAccess) {
const std::string& loadMemoryAccess,
unsigned layout = 0,
bool useSpecConstantLayout = false,
bool useStoreStride = true,
bool useLoadStride = true) {
std::string s = R"(
OpCapability Shader
OpCapability GroupNonUniform
Expand Down Expand Up @@ -2408,11 +2412,18 @@ OpDecorate %129 BuiltIn WorkgroupSize
%33 = OpConstant %6 1024
%34 = OpConstant %6 1
%38 = OpConstant %6 8
%39 = OpConstant %6 0
%uint_0 = OpConstant %6 0
)";
if (useSpecConstantLayout) {
s += "%layout = OpSpecConstant %6 " + std::to_string(layout);
} else {
s += "%layout = OpConstant %6 " + std::to_string(layout);
}
s += R"(
%68 = OpTypeFloat 32
%69 = OpConstant %6 16
%70 = OpConstant %6 3
%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %39
%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %uint_0
%72 = OpTypePointer Function %71
%74 = OpTypeRuntimeArray %68
%75 = OpTypeStruct %74
Expand All @@ -2422,7 +2433,7 @@ OpDecorate %129 BuiltIn WorkgroupSize
%79 = OpConstant %78 0
%81 = OpConstant %6 5
%82 = OpTypePointer StorageBuffer %68
%84 = OpConstant %6 64
%stride = OpConstant %6 64
%88 = OpTypePointer Private %71
%89 = OpVariable %88 Private
%92 = OpTypeRuntimeArray %68
Expand Down Expand Up @@ -2478,51 +2489,57 @@ OpStore %18 %30
%35 = OpAccessChain %31 %18 %34
%36 = OpLoad %6 %35
%37 = OpIMul %6 %33 %36
%40 = OpAccessChain %31 %18 %39
%40 = OpAccessChain %31 %18 %uint_0
%41 = OpLoad %6 %40
%42 = OpIMul %6 %38 %41
%43 = OpIAdd %6 %37 %42
OpStore %32 %43
%45 = OpAccessChain %31 %18 %34
%46 = OpLoad %6 %45
%47 = OpIMul %6 %33 %46
%48 = OpAccessChain %31 %18 %39
%48 = OpAccessChain %31 %18 %uint_0
%49 = OpLoad %6 %48
%50 = OpIMul %6 %38 %49
%51 = OpIAdd %6 %47 %50
OpStore %44 %51
%53 = OpAccessChain %31 %18 %34
%54 = OpLoad %6 %53
%55 = OpIMul %6 %33 %54
%56 = OpAccessChain %31 %18 %39
%56 = OpAccessChain %31 %18 %uint_0
%57 = OpLoad %6 %56
%58 = OpIMul %6 %38 %57
%59 = OpIAdd %6 %55 %58
OpStore %52 %59
%61 = OpAccessChain %31 %18 %34
%62 = OpLoad %6 %61
%63 = OpIMul %6 %33 %62
%64 = OpAccessChain %31 %18 %39
%64 = OpAccessChain %31 %18 %uint_0
%65 = OpLoad %6 %64
%66 = OpIMul %6 %38 %65
%67 = OpIAdd %6 %63 %66
OpStore %60 %67
%80 = OpLoad %6 %32
%83 = OpAccessChain %82 %77 %79 %80
%87 = OpCooperativeMatrixLoadKHR %71 %83 %39 %84 )" +
loadMemoryAccess + R"( %81
)";
if (useLoadStride) {
s += "%87 = OpCooperativeMatrixLoadKHR %71 %83 %layout %stride " +
loadMemoryAccess + " %81";
} else {
s += "%87 = OpCooperativeMatrixLoadKHR %71 %83 %layout";
}
s += R"(
OpStore %73 %87
%90 = OpLoad %71 %73
OpStore %89 %90
%96 = OpLoad %6 %44
%97 = OpAccessChain %82 %95 %79 %96
%98 = OpCooperativeMatrixLoadKHR %71 %97 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81
%98 = OpCooperativeMatrixLoadKHR %71 %97 %layout %stride MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %91 %98
%100 = OpLoad %71 %91
OpStore %99 %100
%106 = OpLoad %6 %52
%107 = OpAccessChain %82 %105 %79 %106
%108 = OpCooperativeMatrixLoadKHR %71 %107 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81
%108 = OpCooperativeMatrixLoadKHR %71 %107 %layout %stride MakePointerVisibleKHR|NonPrivatePointerKHR %81
OpStore %101 %108
%110 = OpLoad %71 %101
OpStore %109 %110
Expand All @@ -2532,7 +2549,14 @@ OpStore %111 %115
%116 = OpLoad %71 %111
%121 = OpLoad %6 %60
%122 = OpAccessChain %82 %120 %79 %121
OpCooperativeMatrixStoreKHR %122 %116 %39 %84 )" + storeMemoryAccess + R"( %81
)";
if (useStoreStride) {
s += "OpCooperativeMatrixStoreKHR %122 %116 %layout %stride " +
storeMemoryAccess + " %81";
} else {
s += "OpCooperativeMatrixStoreKHR %122 %116 %layout";
}
s += R"(
OpReturn
OpFunctionEnd
)";
Expand All @@ -2549,6 +2573,54 @@ TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) {
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}

struct StrideMissingCase {
unsigned layout;
bool useLoadStride;
bool useStoreStride;
};

using ValidateCoopMatrixStrideMissing =
spvtest::ValidateBase<StrideMissingCase>;

INSTANTIATE_TEST_SUITE_P(
CoopMatrixStrideMissing, ValidateCoopMatrixStrideMissing,
Values(
StrideMissingCase{(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR,
false, true},
StrideMissingCase{(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR,
true, false},
StrideMissingCase{
(unsigned)spv::CooperativeMatrixLayout::ColumnMajorKHR, false,
true},
StrideMissingCase{
(unsigned)spv::CooperativeMatrixLayout::ColumnMajorKHR, true,
false}));

TEST_P(ValidateCoopMatrixStrideMissing, CoopMatKHRLoadStrideMissingFail) {
const StrideMissingCase& param = GetParam();
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR", param.layout,
false /*useSpecConstantLayout*/, param.useStoreStride,
param.useLoadStride);
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("MemoryLayout " + std::to_string(param.layout) +
" requires a Stride"));
}

TEST_F(ValidateMemory, CoopMatKHRMemoryLayoutFromSpecConstantSuccess) {
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerAvailableKHR|NonPrivatePointerKHR",
"MakePointerVisibleKHR|NonPrivatePointerKHR",
(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR,
true /*useSpecConstantLayout*/);

CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1));
}

TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) {
std::string spirv = GenCoopMatLoadStoreShaderKHR(
"MakePointerVisibleKHR|NonPrivatePointerKHR",
Expand Down Expand Up @@ -6791,11 +6863,12 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%ld = OpCooperativeMatrixLoadKHR %matrix %var %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %var %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down Expand Up @@ -6832,12 +6905,13 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%gep = OpUntypedAccessChainKHR %untyped %block %var %int_0 %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down Expand Up @@ -6878,14 +6952,15 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var1 = OpVariable %ptr StorageBuffer
%var2 = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%gep = OpAccessChain %ptr_float %var1 %int_0 %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0
OpCooperativeMatrixStoreKHR %var2 %ld %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride
OpCooperativeMatrixStoreKHR %var2 %ld %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down Expand Up @@ -6926,15 +7001,16 @@ OpDecorate %array ArrayStride 4
%rows = OpSpecConstant %int 1
%cols = OpSpecConstant %int 1
%matrix_a = OpConstant %int 1
%stride = OpConstant %int 42
%matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a
%var1 = OpVariable %ptr StorageBuffer
%var2 = OpUntypedVariableKHR %untyped StorageBuffer %block
%main = OpFunction %void None %void_fn
%entry = OpLabel
%gep = OpAccessChain %ptr_float %var1 %int_0 %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0
%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride
%gep2 = OpUntypedAccessChainKHR %untyped %block %var2 %int_0 %int_0
OpCooperativeMatrixStoreKHR %gep2 %ld %int_0
OpCooperativeMatrixStoreKHR %gep2 %ld %int_0 %stride
OpReturn
OpFunctionEnd
)";
Expand Down

0 comments on commit d3eac11

Please sign in to comment.