Skip to content

Commit

Permalink
Add first support for generic address space
Browse files Browse the repository at this point in the history
This is taking over some work from google#994
Fixes google#1077
  • Loading branch information
rjodinchr committed Apr 18, 2023
1 parent 2de2868 commit 88b0a02
Show file tree
Hide file tree
Showing 38 changed files with 892 additions and 22 deletions.
7 changes: 3 additions & 4 deletions include/clspv/FeatureMacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ enum class FeatureMacro {
__opencl_c_subgroups,
// following items are not supported
__opencl_c_device_enqueue,
__opencl_c_generic_address_space,
__opencl_c_pipes,
__opencl_c_program_scope_global_variables,
// following items are always enabled, but no point in complaining if they are
Expand All @@ -44,6 +43,7 @@ enum class FeatureMacro {
__opencl_c_read_write_images,
__opencl_c_atomic_scope_device,
__opencl_c_atomic_scope_all_devices,
__opencl_c_generic_address_space,
__opencl_c_work_group_collective_functions
};

Expand All @@ -53,6 +53,7 @@ constexpr std::array<std::pair<FeatureMacro, const char *>, 15>
FeatureStr(__opencl_c_3d_image_writes),
FeatureStr(__opencl_c_atomic_order_acq_rel),
FeatureStr(__opencl_c_fp64), FeatureStr(__opencl_c_images),
FeatureStr(__opencl_c_generic_address_space),
FeatureStr(__opencl_c_subgroups),
// following items are always enabled by clang
FeatureStr(__opencl_c_int64),
Expand All @@ -62,9 +63,7 @@ constexpr std::array<std::pair<FeatureMacro, const char *>, 15>
FeatureStr(__opencl_c_atomic_scope_all_devices),
FeatureStr(__opencl_c_work_group_collective_functions),
// following items cannot be enabled so are automatically disabled
FeatureStr(__opencl_c_device_enqueue),
FeatureStr(__opencl_c_generic_address_space),
FeatureStr(__opencl_c_pipes),
FeatureStr(__opencl_c_device_enqueue), FeatureStr(__opencl_c_pipes),
FeatureStr(__opencl_c_program_scope_global_variables)};
#undef FeatureStr

Expand Down
3 changes: 2 additions & 1 deletion include/clspv/Option.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ SourceLanguage Language();
// Returns true when the source language makes use of the generic address space.
inline bool LanguageUsesGenericAddressSpace() {
return (Language() == SourceLanguage::OpenCL_CPP) ||
((Language() == SourceLanguage::OpenCL_C_20));
(Language() == SourceLanguage::OpenCL_C_20) ||
(Language() == SourceLanguage::OpenCL_C_30);
}

// Return the SPIR-V binary version
Expand Down
3 changes: 3 additions & 0 deletions lib/BuiltinsEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ enum BuiltinType : unsigned int {
kMemFence,
kReadMemFence,
kWriteMemFence,
kToGlobal,
kToLocal,
kToPrivate,
kType_MemoryFence_End,

kType_Geometric_Start,
Expand Down
3 changes: 3 additions & 0 deletions lib/BuiltinsMap.inc
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,9 @@ static std::unordered_map<const char *, Builtins::BuiltinType, cstr_hash,
{"mem_fence", Builtins::kMemFence},
{"read_mem_fence", Builtins::kReadMemFence},
{"write_mem_fence", Builtins::kWriteMemFence},
{"__to_global", Builtins::kToGlobal},
{"__to_local", Builtins::kToLocal},
{"__to_private", Builtins::kToPrivate},

// Geometric
{"cross", Builtins::kCross},
Expand Down
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ add_library(clspv_passes OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/LongVectorLoweringPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/SetImageChannelMetadataPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ThreeElementVectorLoweringPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/LowerAddrSpaceCastPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/MultiVersionUBOFunctionsPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/NativeMathPass.cpp
${CMAKE_CURRENT_SOURCE_DIR}/NormalizeGlobalVariable.cpp
Expand Down
7 changes: 7 additions & 0 deletions lib/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,13 @@ int RunPassPipeline(llvm::Module &M, llvm::raw_svector_ostream *binaryStream) {
pm.addPass(clspv::InlineFuncWithPointerToFunctionArgPass());
pm.addPass(clspv::InlineFuncWithSingleCallSitePass());

// This pass needs to be after every inlining to make sure we are capable of
// removing every addrspacecast. It only needs to run if generic addrspace
// is used.
if (clspv::Option::LanguageUsesGenericAddressSpace()) {
pm.addPass(clspv::LowerAddrSpaceCastPass());
}

// Mem2Reg pass should be run early because O0 level optimization leaves
// redundant alloca, load and store instructions from function arguments.
// clspv needs to remove them ahead of transformation.
Expand Down
1 change: 0 additions & 1 deletion lib/FeatureMacro.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ namespace clspv {
FeatureMacro FeatureMacroLookup(const std::string &name) {
constexpr std::array<FeatureMacro, 4> NotSuppported{
FeatureMacro::__opencl_c_pipes,
FeatureMacro::__opencl_c_generic_address_space,
FeatureMacro::__opencl_c_device_enqueue,
FeatureMacro::__opencl_c_program_scope_global_variables};

Expand Down
287 changes: 287 additions & 0 deletions lib/LowerAddrSpaceCastPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
// Copyright 2023 The Clspv Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "LowerAddrSpaceCastPass.h"
#include "BitcastUtils.h"
#include "clspv/AddressSpace.h"
#include "Types.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "LowerAddrSpaceCast"

namespace {

using PartitionCallback = std::function<void(Instruction *)>;

/// Partition the @p Instructions based on their liveness.
void partitionInstructions(ArrayRef<WeakTrackingVH> Instructions,
PartitionCallback OnDead,
PartitionCallback OnAlive) {
for (auto OldValueHandle : Instructions) {
// Handle situations when the weak handle is no longer valid.
if (!OldValueHandle.pointsToAliveValue()) {
continue; // Nothing else to do for this handle.
}

auto *OldInstruction = cast<Instruction>(OldValueHandle);
bool Dead = OldInstruction->use_empty();
if (Dead) {
OnDead(OldInstruction);
} else {
OnAlive(OldInstruction);
}
}
}

bool isGenericPTy(Type *Ty) {
return Ty && Ty->isPointerTy() &&
Ty->getPointerAddressSpace() == clspv::AddressSpace::Generic;
}
} // namespace

// Entry point of the pass: rewrite every function in @p M so that no
// generic-address-space pointer usage remains.
PreservedAnalyses clspv::LowerAddrSpaceCastPass::run(Module &M,
                                                     ModuleAnalysisManager &) {
  for (Function &F : M.functions()) {
    // Fold constant expressions into instructions first so the visitor only
    // ever has to deal with instructions, then rewrite the function body.
    BitcastUtils::RemoveCstExprFromFunction(&F);
    runOnFunction(F);
  }

  // Default-constructed: nothing is preserved.
  return PreservedAnalyses();
}

// Value-level dispatcher: returns the (possibly rewritten) replacement for
// @p V, creating it on demand.
//
// Behavior visible in this block:
// - Cached replacements in ValueMap are returned directly, so each value is
//   rewritten at most once.
// - Non-instruction values (arguments, constants, globals) are returned
//   unchanged.
// - An instruction is handed to the InstVisitor overloads (the calls
//   `visit(alloca)` / `visit(I)` below take an Instruction* and therefore
//   resolve to the InstVisitor dispatch, not back to this Value overload)
//   when it allocates a non-private pointer, produces a generic pointer, or
//   consumes one.
// - Anything else is left untouched.
Value *clspv::LowerAddrSpaceCastPass::visit(Value *V) {
  // Fast path: this value has already been rewritten.
  auto it = ValueMap.find(V);
  if (it != ValueMap.end()) {
    return it->second;
  }
  // Only instructions can be rewritten; other values pass through as-is.
  auto *I = dyn_cast<Instruction>(V);
  if (I == nullptr) {
    return V;
  }

  // Allocas holding a pointer to a non-private address space must be
  // rewritten even though the alloca itself is not a generic pointer.
  if (auto *alloca = dyn_cast<AllocaInst>(I)) {
    if (alloca->getAllocatedType()->isPointerTy() &&
        alloca->getAllocatedType()->getPointerAddressSpace() !=
            clspv::AddressSpace::Private) {
      return visit(alloca);
    }
  }

  // Instructions producing a generic pointer need rewriting.
  if (isGenericPTy(I->getType())) {
    return visit(I);
  }

  // So do instructions consuming a generic pointer through any operand.
  for (auto &Operand : I->operands()) {
    if (isGenericPTy(Operand->getType())) {
      return visit(I);
    }
  }

  // No generic address space involvement: keep the instruction unchanged.
  return V;
}

// Rebuilds an alloca that holds a pointer so the stored pointer lives in the
// private address space instead of the generic one. The old alloca is
// recorded as replaced; callers pick up the new one through the cache.
llvm::Value *
clspv::LowerAddrSpaceCastPass::visitAllocaInst(llvm::AllocaInst &I) {
  IRBuilder<> Builder(&I);
  auto *PrivatePtrTy =
      PointerType::get(I.getContext(), clspv::AddressSpace::Private);
  auto *NewAlloca =
      Builder.CreateAlloca(PrivatePtrTy, I.getArraySize(), I.getName());
  registerReplacement(&I, NewAlloca);
  return NewAlloca;
}

// Recreates a load on top of the rewritten pointer operand. When the load
// used to produce a generic pointer, the loaded type is re-inferred from the
// new pointer so the result lands in a concrete address space.
llvm::Value *clspv::LowerAddrSpaceCastPass::visitLoadInst(llvm::LoadInst &I) {
  IRBuilder<> Builder(&I);
  Type *LoadedTy = I.getType();
  Value *NewPtr = visit(I.getPointerOperand());
  if (isGenericPTy(LoadedTy)) {
    LoadedTy = clspv::InferType(NewPtr, I.getContext(), &TypeCache);
  }
  auto *NewLoad = Builder.CreateLoad(LoadedTy, NewPtr, I.getName());
  registerReplacement(&I, NewLoad);
  // Generic-typed results change type, so their users must be rewritten
  // through the cache; only same-typed results can be RAUW'd directly.
  if (!isGenericPTy(I.getType())) {
    I.replaceAllUsesWith(NewLoad);
  }
  return NewLoad;
}

// Recreates a store with both operands rewritten. A stored null pointer is
// rebuilt in the address space the destination actually points to, since its
// original (generic) type no longer matches.
llvm::Value *clspv::LowerAddrSpaceCastPass::visitStoreInst(llvm::StoreInst &I) {
  IRBuilder<> Builder(&I);
  Value *NewVal = visit(I.getValueOperand());
  Value *NewPtr = visit(I.getPointerOperand());
  if (isa<ConstantPointerNull>(NewVal)) {
    unsigned DestAS = clspv::InferType(NewPtr, I.getContext(), &TypeCache)
                          ->getPointerAddressSpace();
    NewVal =
        ConstantPointerNull::get(PointerType::get(I.getContext(), DestAS));
  }
  auto *NewStore = Builder.CreateStore(NewVal, NewPtr);
  registerReplacement(&I, NewStore);
  return NewStore;
}

// Recreates a GEP on top of the rewritten base pointer, preserving the
// source element type, the indices, the name and the inbounds flag.
llvm::Value *clspv::LowerAddrSpaceCastPass::visitGetElementPtrInst(
    llvm::GetElementPtrInst &I) {
  IRBuilder<> Builder(&I);
  SmallVector<Value *> Indices{I.indices()};
  Value *NewBase = visit(I.getPointerOperand());
  auto *NewGEP = Builder.CreateGEP(I.getSourceElementType(), NewBase, Indices,
                                   I.getName(), I.isInBounds());
  registerReplacement(&I, NewGEP);
  return NewGEP;
}

// Eliminates an addrspacecast by forwarding the rewritten source pointer.
//
// After rewriting, the source pointer already carries a concrete address
// space, so the cast itself becomes unnecessary: either the rewritten
// pointer matches one side of the cast and is used directly, or the cast is
// unsatisfiable and is folded to null. The null case implements the OpenCL
// to_global/to_local/to_private contract quoted below.
llvm::Value *clspv::LowerAddrSpaceCastPass::visitAddrSpaceCastInst(
    llvm::AddrSpaceCastInst &I) {
  auto ptr = visit(I.getPointerOperand());
  // Returns a pointer that points to a region in the address space if
  // "to_addrspace" can cast ptr to the address space. Otherwise it returns
  // NULL.
  if (ptr->getType() != I.getSrcTy() && ptr->getType() != I.getDestTy()) {
    // Unsatisfiable cast: replace every use with a null pointer of the
    // cast's destination type.
    ptr = ConstantPointerNull::get(cast<PointerType>(I.getType()));
    I.replaceAllUsesWith(ptr);
  }
  registerReplacement(&I, ptr);
  return ptr;
}

// Recreates a pointer comparison over the rewritten operands. Rewriting can
// leave the two sides in different address spaces; that is only recoverable
// when one side is a null constant, which is recreated in the other side's
// type. Any other mismatch is unsupported.
llvm::Value *clspv::LowerAddrSpaceCastPass::visitICmpInst(llvm::ICmpInst &I) {
  IRBuilder<> Builder(&I);
  Value *LHS = visit(I.getOperand(0));
  Value *RHS = visit(I.getOperand(1));
  if (LHS->getType() != RHS->getType()) {
    if (isa<ConstantPointerNull>(LHS)) {
      LHS = ConstantPointerNull::get(cast<PointerType>(RHS->getType()));
    } else if (isa<ConstantPointerNull>(RHS)) {
      RHS = ConstantPointerNull::get(cast<PointerType>(LHS->getType()));
    } else {
      llvm_unreachable("unsupported operand of icmp in loweraddrspacecast");
    }
  }

  auto *NewCmp = Builder.CreateICmp(I.getPredicate(), LHS, RHS, I.getName());
  registerReplacement(&I, NewCmp);
  I.replaceAllUsesWith(NewCmp);
  return NewCmp;
}

// Fallback for any instruction without a dedicated handler. The dispatcher
// only routes instructions that touch a generic pointer here, so reaching
// this function means the pass is missing support for that instruction —
// treat it as a fatal internal error (print it first in debug builds).
Value *clspv::LowerAddrSpaceCastPass::visitInstruction(Instruction &I) {
#ifndef NDEBUG
  dbgs() << "Instruction not handled: " << I << '\n';
#endif
  llvm_unreachable("Missing support for instruction");
}

// Records that @p U has been superseded by @p V. Each value may be
// registered at most once; the mapping feeds both the visit() cache and the
// dead-instruction sweep.
void clspv::LowerAddrSpaceCastPass::registerReplacement(Value *U, Value *V) {
  LLVM_DEBUG(dbgs() << "Replacement for " << *U << ": " << *V << '\n');
  assert(ValueMap.find(U) == ValueMap.end() && "Value already registered");
  ValueMap[U] = V;
}

// Rewrites a single function: visits every instruction through the cached
// Value overload, then sweeps the instructions that were replaced.
void clspv::LowerAddrSpaceCastPass::runOnFunction(Function &F) {
  LLVM_DEBUG(dbgs() << "Processing " << F.getName() << '\n');

  // Declarations have no body to rewrite.
  if (F.isDeclaration()) {
    return;
  }

  for (auto &Inst : instructions(&F)) {
    // Use the Value overload of visit to ensure cache is used.
    visit(static_cast<Value *>(&Inst));
  }

  cleanDeadInstructions();

  LLVM_DEBUG(dbgs() << "Final version for " << F.getName() << '\n');
  LLVM_DEBUG(dbgs() << F << '\n');
}

// Removes the original instructions that were superseded during rewriting.
void clspv::LowerAddrSpaceCastPass::cleanDeadInstructions() {
  // Collect all instructions that have been replaced by another one, and remove
  // them from the function. To address dependencies, use a fixed-point
  // algorithm:
  // 1. Collect the instructions that have been replaced.
  // 2. Collect among these instructions the ones which have no uses and remove
  // them.
  // 3. Repeat step 2 until no progress is made.

  // Select instructions that were replaced by another one.
  // Ignore constants as they are not owned by the module and therefore don't
  // need to be removed.
  using WeakInstructions = SmallVector<WeakTrackingVH, 32>;
  WeakInstructions OldInstructions;
  for (const auto &Mapping : ValueMap) {
    if (Mapping.getSecond() != nullptr) {
      if (auto *OldInstruction = dyn_cast<Instruction>(Mapping.getFirst())) {
        OldInstructions.push_back(OldInstruction);
      } else {
        assert(isa<Constant>(Mapping.getFirst()) &&
               "Only Instruction and Constant are expected in ValueMap");
      }
    }
  }

  // Erase any mapping, as they won't be valid anymore.
  ValueMap.clear();

  // Fixed-point loop: each pass deletes the instructions that have become
  // use-free, which may in turn free up their operands for the next pass.
  for (bool Progress = true; Progress;) {
    std::size_t PreviousSize = OldInstructions.size();

    // Identify instructions that are actually dead and can be removed using
    // RecursivelyDeleteTriviallyDeadInstructions.
    // Use a third buffer to capture the instructions that are still alive to
    // avoid mutating OldInstructions while iterating over it.
    WeakInstructions NextBatch;
    WeakInstructions TriviallyDeads;
    partitionInstructions(
        OldInstructions,
        [&TriviallyDeads](Instruction *DeadInstruction) {
          // Additionally, manually remove from the parent instructions with
          // possible side-effect, generally speaking, such as call or alloca
          // instructions. Those are not trivially dead.
          if (isInstructionTriviallyDead(DeadInstruction)) {
            TriviallyDeads.push_back(DeadInstruction);
          } else {
            DeadInstruction->eraseFromParent();
          }
        },
        [&NextBatch](Instruction *AliveInstruction) {
          NextBatch.push_back(AliveInstruction);
        });

    RecursivelyDeleteTriviallyDeadInstructions(TriviallyDeads);

    // Update OldInstructions for the next iteration of the fixed-point.
    // Terminate once a full pass removes nothing (no size change).
    OldInstructions = std::move(NextBatch);
    Progress = (OldInstructions.size() < PreviousSize);
  }

#ifndef NDEBUG
  // In debug builds, every replaced instruction is expected to be gone by
  // now; anything left behind indicates a use we failed to rewrite.
  if (!OldInstructions.empty()) {
    dbgs() << "These values were expected to be removed:\n";
    for (auto ValueHandle : OldInstructions) {
      dbgs() << '\t' << *ValueHandle << '\n';
    }
    llvm_unreachable("Not all supposedly-dead instruction were removed!");
  }
#endif
}
Loading

0 comments on commit 88b0a02

Please sign in to comment.