diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index beaa79c28e..9bfa3c2158 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -2115,16 +2115,23 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _, const auto layout_index = (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 3u : 2u; - const auto colmajor_id = inst->GetOperandAs(layout_index); - const auto colmajor = _.FindDef(colmajor_id); - if (!colmajor || !_.IsIntScalarType(colmajor->type_id()) || - !(spvOpcodeIsConstant(colmajor->opcode()) || - spvOpcodeIsSpecConstant(colmajor->opcode()))) { + const auto layout_id = inst->GetOperandAs(layout_index); + const auto layout_inst = _.FindDef(layout_id); + if (!layout_inst || !_.IsIntScalarType(layout_inst->type_id()) || + !spvOpcodeIsConstant(layout_inst->opcode())) { return _.diag(SPV_ERROR_INVALID_ID, inst) - << "MemoryLayout operand " << _.getIdName(colmajor_id) + << "MemoryLayout operand " << _.getIdName(layout_id) << " must be a 32-bit integer constant instruction."; } + bool stride_required = false; + uint64_t layout; + if (_.EvalConstantValUint64(layout_id, &layout)) { + stride_required = + (layout == (uint64_t)spv::CooperativeMatrixLayout::RowMajorKHR) || + (layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR); + } + const auto stride_index = (inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u; if (inst->operands().size() > stride_index) { @@ -2135,6 +2142,9 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _, << "Stride operand " << _.getIdName(stride_id) << " must be a scalar integer type."; } + } else if (stride_required) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "MemoryLayout " << layout << " requires a Stride."; } const auto memory_access_index = diff --git a/test/opt/aggressive_dead_code_elim_test.cpp b/test/opt/aggressive_dead_code_elim_test.cpp index dcce4f5789..d837099fe9 100644 --- a/test/opt/aggressive_dead_code_elim_test.cpp +++ b/test/opt/aggressive_dead_code_elim_test.cpp @@ -8018,6 +8018,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) { %uint_3 = OpConstant %uint 3 %uint_16 = OpConstant %uint 16 %uint_4 = OpConstant %uint 4 +%coop_stride = OpConstant %int 42 %_runtimearr_int = OpTypeRuntimeArray %int %_struct_4 = OpTypeStruct %_runtimearr_int %_ptr_StorageBuffer__struct_4 = OpTypePointer StorageBuffer %_struct_4 @@ -8047,7 +8048,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) { %26 = OpVariable %_ptr_Function__ptr_Function_int Function %27 = OpVariable %_ptr_Function__struct_18 Function %28 = OpAccessChain %_ptr_StorageBuffer_int %2 %int_0 %uint_0 - %29 = OpCooperativeMatrixLoadKHR %17 %28 %int_1 + %29 = OpCooperativeMatrixLoadKHR %17 %28 %int_1 %coop_stride %30 = OpCompositeConstruct %_struct_18 %29 OpStore %27 %30 %31 = OpAccessChain %_ptr_Function_17 %27 %int_0 @@ -8059,7 +8060,7 @@ TEST_F(AggressiveDCETest, StoringAPointer) { OpStore %32 %34 %35 = OpAccessChain %_ptr_StorageBuffer_int %2 %int_0 %uint_64 %36 = OpLoad %17 %31 - OpCooperativeMatrixStoreKHR %35 %36 %int_0 + OpCooperativeMatrixStoreKHR %35 %36 %int_0 %coop_stride OpReturn OpFunctionEnd )"; diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index b4689f2e9d..df92fff4c0 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -2351,7 +2351,11 @@ OpFunctionEnd)"; } std::string GenCoopMatLoadStoreShaderKHR(const std::string& storeMemoryAccess, - const std::string& loadMemoryAccess) { + const std::string& loadMemoryAccess, + unsigned layout = 0, + bool useSpecConstantLayout = false, + bool useStoreStride = true, + bool useLoadStride = true) { std::string s = R"( OpCapability Shader OpCapability GroupNonUniform @@ -2408,11 +2412,18 @@ OpDecorate %129 BuiltIn WorkgroupSize %33 = OpConstant %6 1024 %34 = OpConstant %6 1 %38 = OpConstant %6 8 -%39 = OpConstant %6 0 +%uint_0 = OpConstant %6 0 +)"; + if (useSpecConstantLayout) { + s += "%layout = OpSpecConstant %6 " + std::to_string(layout); + } else { + s += "%layout = OpConstant %6 " + std::to_string(layout); + } + s += R"( %68 = OpTypeFloat 32 %69 = OpConstant %6 16 %70 = OpConstant %6 3 -%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %39 +%71 = OpTypeCooperativeMatrixKHR %68 %70 %69 %38 %uint_0 %72 = OpTypePointer Function %71 %74 = OpTypeRuntimeArray %68 %75 = OpTypeStruct %74 @@ -2422,7 +2433,7 @@ OpDecorate %129 BuiltIn WorkgroupSize %79 = OpConstant %78 0 %81 = OpConstant %6 5 %82 = OpTypePointer StorageBuffer %68 -%84 = OpConstant %6 64 +%stride = OpConstant %6 64 %88 = OpTypePointer Private %71 %89 = OpVariable %88 Private %92 = OpTypeRuntimeArray %68 @@ -2478,7 +2489,7 @@ OpStore %18 %30 %35 = OpAccessChain %31 %18 %34 %36 = OpLoad %6 %35 %37 = OpIMul %6 %33 %36 -%40 = OpAccessChain %31 %18 %39 +%40 = OpAccessChain %31 %18 %uint_0 %41 = OpLoad %6 %40 %42 = OpIMul %6 %38 %41 %43 = OpIAdd %6 %37 %42 @@ -2486,7 +2497,7 @@ OpStore %32 %43 %45 = OpAccessChain %31 %18 %34 %46 = OpLoad %6 %45 %47 = OpIMul %6 %33 %46 -%48 = OpAccessChain %31 %18 %39 +%48 = OpAccessChain %31 %18 %uint_0 %49 = OpLoad %6 %48 %50 = OpIMul %6 %38 %49 %51 = OpIAdd %6 %47 %50 @@ -2494,7 +2505,7 @@ OpStore %44 %51 %53 = OpAccessChain %31 %18 %34 %54 = OpLoad %6 %53 %55 = OpIMul %6 %33 %54 -%56 = OpAccessChain %31 %18 %39 +%56 = OpAccessChain %31 %18 %uint_0 %57 = OpLoad %6 %56 %58 = OpIMul %6 %38 %57 %59 = OpIAdd %6 %55 %58 @@ -2502,27 +2513,33 @@ OpStore %52 %59 %61 = OpAccessChain %31 %18 %34 %62 = OpLoad %6 %61 %63 = OpIMul %6 %33 %62 -%64 = OpAccessChain %31 %18 %39 +%64 = OpAccessChain %31 %18 %uint_0 %65 = OpLoad %6 %64 %66 = OpIMul %6 %38 %65 %67 = OpIAdd %6 %63 %66 OpStore %60 %67 %80 = OpLoad %6 %32 %83 = OpAccessChain %82 %77 %79 %80 -%87 = OpCooperativeMatrixLoadKHR %71 %83 %39 %84 )" + - loadMemoryAccess + R"( %81 +)"; + if (useLoadStride) { + s += "%87 = OpCooperativeMatrixLoadKHR %71 %83 %layout %stride " + + loadMemoryAccess + " %81"; + } else { + s += "%87 = OpCooperativeMatrixLoadKHR %71 %83 %layout"; + } + s += R"( OpStore %73 %87 %90 = OpLoad %71 %73 OpStore %89 %90 %96 = OpLoad %6 %44 %97 = OpAccessChain %82 %95 %79 %96 -%98 = OpCooperativeMatrixLoadKHR %71 %97 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81 +%98 = OpCooperativeMatrixLoadKHR %71 %97 %layout %stride MakePointerVisibleKHR|NonPrivatePointerKHR %81 OpStore %91 %98 %100 = OpLoad %71 %91 OpStore %99 %100 %106 = OpLoad %6 %52 %107 = OpAccessChain %82 %105 %79 %106 -%108 = OpCooperativeMatrixLoadKHR %71 %107 %39 %84 MakePointerVisibleKHR|NonPrivatePointerKHR %81 +%108 = OpCooperativeMatrixLoadKHR %71 %107 %layout %stride MakePointerVisibleKHR|NonPrivatePointerKHR %81 OpStore %101 %108 %110 = OpLoad %71 %101 OpStore %109 %110 @@ -2532,7 +2549,14 @@ OpStore %111 %115 %116 = OpLoad %71 %111 %121 = OpLoad %6 %60 %122 = OpAccessChain %82 %120 %79 %121 -OpCooperativeMatrixStoreKHR %122 %116 %39 %84 )" + storeMemoryAccess + R"( %81 +)"; + if (useStoreStride) { + s += "OpCooperativeMatrixStoreKHR %122 %116 %layout %stride " + + storeMemoryAccess + " %81"; + } else { + s += "OpCooperativeMatrixStoreKHR %122 %116 %layout"; + } + s += R"( OpReturn OpFunctionEnd )"; @@ -2549,6 +2573,54 @@ TEST_F(ValidateMemory, CoopMatKHRLoadStoreSuccess) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); } +struct StrideMissingCase { + unsigned layout; + bool useLoadStride; + bool useStoreStride; +}; + +using ValidateCoopMatrixStrideMissing = + spvtest::ValidateBase; + +INSTANTIATE_TEST_SUITE_P( + CoopMatrixStrideMissing, ValidateCoopMatrixStrideMissing, + Values( + StrideMissingCase{(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR, + false, true}, + StrideMissingCase{(unsigned)spv::CooperativeMatrixLayout::RowMajorKHR, + true, false}, + StrideMissingCase{ + (unsigned)spv::CooperativeMatrixLayout::ColumnMajorKHR, false, + true}, + StrideMissingCase{ + (unsigned)spv::CooperativeMatrixLayout::ColumnMajorKHR, true, + false})); + +TEST_P(ValidateCoopMatrixStrideMissing, CoopMatKHRLoadStrideMissingFail) { + const StrideMissingCase& param = GetParam(); + std::string spirv = GenCoopMatLoadStoreShaderKHR( + "MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR", param.layout, + false /*useSpecConstantLayout*/, param.useStoreStride, + param.useLoadStride); + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_1)); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("MemoryLayout " + std::to_string(param.layout) + + " requires a Stride")); +} + +TEST_F(ValidateMemory, CoopMatKHRMemoryLayoutFromSpecConstantSuccess) { + std::string spirv = GenCoopMatLoadStoreShaderKHR( + "MakePointerAvailableKHR|NonPrivatePointerKHR", + "MakePointerVisibleKHR|NonPrivatePointerKHR", + (unsigned)spv::CooperativeMatrixLayout::RowMajorKHR, + true /*useSpecConstantLayout*/); + + CompileSuccessfully(spirv.c_str(), SPV_ENV_VULKAN_1_1); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_1)); +} + TEST_F(ValidateMemory, CoopMatKHRStoreMemoryAccessFail) { std::string spirv = GenCoopMatLoadStoreShaderKHR( "MakePointerVisibleKHR|NonPrivatePointerKHR", @@ -6791,11 +6863,12 @@ OpDecorate %array ArrayStride 4 %rows = OpSpecConstant %int 1 %cols = OpSpecConstant %int 1 %matrix_a = OpConstant %int 1 +%stride = OpConstant %int 42 %matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a %var = OpUntypedVariableKHR %untyped StorageBuffer %block %main = OpFunction %void None %void_fn %entry = OpLabel -%ld = OpCooperativeMatrixLoadKHR %matrix %var %int_0 +%ld = OpCooperativeMatrixLoadKHR %matrix %var %int_0 %stride OpReturn OpFunctionEnd )"; @@ -6832,12 +6905,13 @@ OpDecorate %array ArrayStride 4 %rows = OpSpecConstant %int 1 %cols = OpSpecConstant %int 1 %matrix_a = OpConstant %int 1 +%stride = OpConstant %int 42 %matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a %var = OpUntypedVariableKHR %untyped StorageBuffer %block %main = OpFunction %void None %void_fn %entry = OpLabel %gep = OpUntypedAccessChainKHR %untyped %block %var %int_0 %int_0 -%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 +%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride OpReturn OpFunctionEnd )"; @@ -6878,14 +6952,15 @@ OpDecorate %array ArrayStride 4 %rows = OpSpecConstant %int 1 %cols = OpSpecConstant %int 1 %matrix_a = OpConstant %int 1 +%stride = OpConstant %int 42 %matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a %var1 = OpVariable %ptr StorageBuffer %var2 = OpUntypedVariableKHR %untyped StorageBuffer %block %main = OpFunction %void None %void_fn %entry = OpLabel %gep = OpAccessChain %ptr_float %var1 %int_0 %int_0 -%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 -OpCooperativeMatrixStoreKHR %var2 %ld %int_0 +%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride +OpCooperativeMatrixStoreKHR %var2 %ld %int_0 %stride OpReturn OpFunctionEnd )"; @@ -6926,15 +7001,16 @@ OpDecorate %array ArrayStride 4 %rows = OpSpecConstant %int 1 %cols = OpSpecConstant %int 1 %matrix_a = OpConstant %int 1 +%stride = OpConstant %int 42 %matrix = OpTypeCooperativeMatrixKHR %float %subgroup %rows %cols %matrix_a %var1 = OpVariable %ptr StorageBuffer %var2 = OpUntypedVariableKHR %untyped StorageBuffer %block %main = OpFunction %void None %void_fn %entry = OpLabel %gep = OpAccessChain %ptr_float %var1 %int_0 %int_0 -%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 +%ld = OpCooperativeMatrixLoadKHR %matrix %gep %int_0 %stride %gep2 = OpUntypedAccessChainKHR %untyped %block %var2 %int_0 %int_0 -OpCooperativeMatrixStoreKHR %gep2 %ld %int_0 +OpCooperativeMatrixStoreKHR %gep2 %ld %int_0 %stride OpReturn OpFunctionEnd )";