diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index acd4bb7992be9..6a60525637179 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -200,6 +200,8 @@ MODULE_PASS("sycl-add-opt-level-attribute", SYCLAddOptLevelAttributePass()) MODULE_PASS("compile-time-properties", CompileTimePropertiesPass()) MODULE_PASS("sycl-remangle-libspirv", SYCLRemangleLibspirvPass()) MODULE_PASS("cleanup-sycl-metadata", CleanupSYCLMetadataPass()) +MODULE_PASS("cleanup-sycl-metadata-from-llvm-used", CleanupSYCLMetadataFromLLVMUsed()) +MODULE_PASS("remove-device-global-from-llvm-compiler-used", RemoveDeviceGlobalFromLLVMCompilerUsed()) MODULE_PASS("sycl-create-nvvm-annotations", SYCLCreateNVVMAnnotationsPass()) MODULE_PASS("lower-slm-reservation-calls", ESIMDLowerSLMReservationCalls()) MODULE_PASS("record-sycl-aspect-names", RecordSYCLAspectNamesPass()) diff --git a/llvm/lib/SYCLLowerIR/CleanupSYCLMetadata.cpp b/llvm/lib/SYCLLowerIR/CleanupSYCLMetadata.cpp index c70d4eeb1b7fb..928b80d7606aa 100644 --- a/llvm/lib/SYCLLowerIR/CleanupSYCLMetadata.cpp +++ b/llvm/lib/SYCLLowerIR/CleanupSYCLMetadata.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm/SYCLLowerIR/CleanupSYCLMetadata.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/Constants.h" @@ -34,18 +35,18 @@ void cleanupSYCLCompilerModuleMetadata(const Module &M, llvm::StringRef MD) { } // GV is supposed to be either llvm.compiler.used or llvm.used. -SmallVector -eraseGlobalVariableAndReturnOperands(GlobalVariable *GV) { +SmallPtrSet +eraseGlobalVariableAndReturnUniqueOperands(GlobalVariable *GV) { assert(GV->user_empty() && "Users aren't expected"); Constant *Initializer = GV->getInitializer(); GV->setInitializer(nullptr); GV->eraseFromParent(); // Destroy the initializer and save operands. - SmallVector Operands; - Operands.resize(0); + SmallPtrSet Operands; + Operands.reserve(Initializer->getNumOperands()); for (auto &Op : Initializer->operands()) - Operands.push_back(cast(Op)); + Operands.insert(cast(Op)); assert(isSafeToDestroyConstant(Initializer) && "Cannot remove initializer of the given GV"); @@ -81,8 +82,8 @@ CleanupSYCLMetadataFromLLVMUsed::run(Module &M, ModuleAnalysisManager &) { if (!GV) return PreservedAnalyses::all(); - SmallVector IOperands = - eraseGlobalVariableAndReturnOperands(GV); + SmallPtrSet IOperands = + eraseGlobalVariableAndReturnUniqueOperands(GV); // Erase all operands. for (auto *Op : IOperands) { auto StrippedOp = Op->stripPointerCasts(); @@ -113,7 +114,8 @@ RemoveDeviceGlobalFromLLVMCompilerUsed::run(Module &M, const auto *VAT = cast(GV->getValueType()); // Destroy the initializer. Keep the operands so we keep the ones we need. - SmallVector IOperands = eraseGlobalVariableAndReturnOperands(GV); + SmallPtrSet IOperands = + eraseGlobalVariableAndReturnUniqueOperands(GV); // Iterate through all operands. If they are device_global then we drop them // and erase them if they have no uses afterwards. All other values are kept. diff --git a/llvm/test/SYCLLowerIR/CleanupSYCLCompilerInternalMetadata/repetative_values_in_llvm_used.ll b/llvm/test/SYCLLowerIR/CleanupSYCLCompilerInternalMetadata/repetative_values_in_llvm_used.ll new file mode 100644 index 0000000000000..9d8fe0f5e1836 --- /dev/null +++ b/llvm/test/SYCLLowerIR/CleanupSYCLCompilerInternalMetadata/repetative_values_in_llvm_used.ll @@ -0,0 +1,10 @@ +; RUN: opt --passes=cleanup-sycl-metadata-from-llvm-used %s + +; Check that CleanupSYCLMetadataFromLLVMUsed pass considers the case +; when llvm.used has one value more than once. It used to perform a double free +; due to llvm::isSafeToDestroyConstant(C) for the provided @C returns true. + +@C = linkonce_odr hidden addrspace(1) constant <{i64}> <{i64 0}> + +@llvm.used = appending global [2 x ptr addrspace(2)] [ptr addrspace(2) addrspacecast (ptr addrspace(1) @C to ptr addrspace(2)), ptr addrspace(2) addrspacecast (ptr addrspace(1) @C to ptr addrspace(2))], section "llvm.metadata" +