@@ -845,6 +845,35 @@ void BIImport::GetCalledFunctions(const Function* pFunc, TFunctionsVec& calledFu
845
845
}
846
846
}
847
847
848
+ // Returns true when any pointer operand/return type in the call does not match
849
+ // the address space of the same position in the callee prototype.
850
+ static bool needsPointerASFix (const CallInst *inst, const Function *callee)
851
+ {
852
+ const FunctionType *callFTy = inst->getFunctionType ();
853
+ const FunctionType *fProtoTy = callee->getFunctionType ();
854
+
855
+ if (callFTy->getReturnType ()->isPointerTy () &&
856
+ fProtoTy ->getReturnType ()->isPointerTy () &&
857
+ callFTy->getReturnType ()->getPointerAddressSpace () !=
858
+ fProtoTy ->getReturnType ()->getPointerAddressSpace ())
859
+ return true ;
860
+
861
+ if (callFTy->getNumParams () != fProtoTy ->getNumParams ())
862
+ return false ;
863
+
864
+ for (unsigned i = 0 , e = callFTy->getNumParams (); i != e; ++i)
865
+ {
866
+ Type *callTy = callFTy->getParamType (i);
867
+ Type *protoTy = fProtoTy ->getParamType (i);
868
+
869
+ if (callTy->isPointerTy () && protoTy->isPointerTy () &&
870
+ callTy->getPointerAddressSpace () !=
871
+ protoTy->getPointerAddressSpace ())
872
+ return true ;
873
+ }
874
+ return false ;
875
+ }
876
+
848
877
void BIImport::removeFunctionBitcasts (Module& M)
849
878
{
850
879
std::vector<Instruction*> list_delete;
@@ -857,123 +886,136 @@ void BIImport::removeFunctionBitcasts(Module& M)
857
886
for (auto I = BB.begin (), E = BB.end (); I != E; I++)
858
887
{
859
888
CallInst* pInstCall = dyn_cast<CallInst>(I);
860
- if (!pInstCall || pInstCall->getCalledFunction ()) continue ;
861
- if (auto constExpr = dyn_cast<llvm::ConstantExpr>(IGCLLVM::getCalledValue (pInstCall)))
889
+ if (!pInstCall || pInstCall->getCalledFunction ())
890
+ continue ;
891
+
892
+ Function *funcToBeChanged = nullptr ;
893
+
894
+ // The call instruction might either use bitcast const expression or the function directly.
895
+ Value *calledVal = IGCLLVM::getCalledValue (pInstCall);
896
+ ConstantExpr *constExpr = dyn_cast<ConstantExpr>(calledVal);
897
+ if (constExpr)
862
898
{
863
- if (auto funcTobeChanged = dyn_cast<llvm::Function>(constExpr->stripPointerCasts ()))
864
- {
865
- if (funcTobeChanged->isDeclaration ()) continue ;
866
- // Map between values (functions) in source of bitcast
867
- // to their counterpart values in destination
868
- llvm::ValueToValueMapTy operandMap;
869
- Function* pDstFunc = nullptr ;
870
- auto BCFMI = bitcastFunctionMap.find (funcTobeChanged);
871
- bool notExists = BCFMI == bitcastFunctionMap.end ();
872
- if (!notExists)
899
+ funcToBeChanged = dyn_cast<Function>(constExpr->stripPointerCasts ());
900
+ }
901
+ else if (Function *directFunc = dyn_cast<Function>(calledVal))
902
+ {
903
+ if (needsPointerASFix (pInstCall, directFunc))
904
+ funcToBeChanged = directFunc;
905
+ }
906
+
907
+ if (!funcToBeChanged || funcToBeChanged->isDeclaration ())
908
+ continue ;
909
+
910
+ // Map between values (functions) in source of bitcast
911
+ // to their counterpart values in destination
912
+ llvm::ValueToValueMapTy operandMap;
913
+ Function* pDstFunc = nullptr ;
914
+ auto BCFMI = bitcastFunctionMap.find (funcToBeChanged);
915
+ bool notExists = BCFMI == bitcastFunctionMap.end ();
916
+ if (!notExists)
917
+ {
918
+ auto funcVec = bitcastFunctionMap[funcToBeChanged];
919
+ notExists = true ;
920
+ for (Function* F : funcVec) {
921
+ if (pInstCall->getFunctionType () == F->getFunctionType ())
873
922
{
874
- auto funcVec = bitcastFunctionMap[funcTobeChanged];
875
- notExists = true ;
876
- for (Function* F : funcVec) {
877
- if (pInstCall->getFunctionType () == F->getFunctionType ())
878
- {
879
- notExists = false ;
880
- pDstFunc = F;
881
- break ;
882
- }
883
- }
923
+ notExists = false ;
924
+ pDstFunc = F;
925
+ break ;
884
926
}
927
+ }
928
+ }
929
+
930
+ if (notExists)
931
+ {
932
+ pDstFunc = Function::Create (pInstCall->getFunctionType (), funcToBeChanged->getLinkage (), funcToBeChanged->getName (), &M);
933
+ if (pDstFunc->arg_size () != funcToBeChanged->arg_size ()) continue ;
934
+ // Go through and convert function arguments over, remembering the mapping.
935
+ Function::arg_iterator itSrcFunc = funcToBeChanged->arg_begin ();
936
+ Function::arg_iterator eSrcFunc = funcToBeChanged->arg_end ();
937
+ llvm::Function::arg_iterator itDest = pDstFunc->arg_begin ();
938
+
939
+ // Fix incorrect address space or incorrect pointer type caused by CloneFunctionInto later
940
+ // 1. AddressSpaceCast example: CloneFunctionInto causes incorrect LLVM IR, like below
941
+ // %arrayidx.le.i = getelementptr inbounds i8, i8 addrspace(1)* %8, i64 %conv.le.i
942
+ // %9 = load i8, i8 addrspace(4)* %arrayidx.le.i, align 1, !tbaa !309
943
+ // Address space should match for %arrayidx.le.i, so we insert necessary
944
+ // address space casts, which should be eliminated later by other passes
945
+ // 2. incorrect type example:
946
+ // %0 = load i16, %"class.sycl::_V1::ext::oneapi::bfloat16" addrspace(4)* %x, align 2
947
+ // Load value type should match pointer type for %x, so we insert necessary bitcast:
948
+ // %x.bcast = bitcast %"class.sycl::_V1::ext::oneapi::bfloat16" addrspace(4)* %x to i16 addrspace(4)*
949
+ // %0 = load i16, i16 addrspace(4)* %x.bcast, align 2
950
+ SmallVector<Instruction *, 5 > castInsts;
951
+
952
+ for (; itSrcFunc != eSrcFunc; ++itSrcFunc, ++itDest)
953
+ {
954
+ itDest->setName (itSrcFunc->getName ());
885
955
886
- if (notExists)
956
+ Type *srcType = (*itSrcFunc).getType ();
957
+ Value *destVal = &(*itDest);
958
+ Type *destType = destVal->getType ();
959
+ if (srcType->isPointerTy () && destType->isPointerTy ())
887
960
{
888
- pDstFunc = Function::Create (pInstCall->getFunctionType (), funcTobeChanged->getLinkage (), funcTobeChanged->getName (), &M);
889
- if (pDstFunc->arg_size () != funcTobeChanged->arg_size ()) continue ;
890
- // Go through and convert function arguments over, remembering the mapping.
891
- Function::arg_iterator itSrcFunc = funcTobeChanged->arg_begin ();
892
- Function::arg_iterator eSrcFunc = funcTobeChanged->arg_end ();
893
- llvm::Function::arg_iterator itDest = pDstFunc->arg_begin ();
894
-
895
- // Fix incorrect address space or incorrect pointer type caused by CloneFunctionInto later
896
- // 1. AddressSpaceCast example: CloneFunctionInto causes incorrect LLVM IR, like below
897
- // %arrayidx.le.i = getelementptr inbounds i8, i8 addrspace(1)* %8, i64 %conv.le.i
898
- // %9 = load i8, i8 addrspace(4)* %arrayidx.le.i, align 1, !tbaa !309
899
- // Address space should match for %arrayidx.le.i, so we insert necessary
900
- // address space casts, which should be eliminated later by other passes
901
- // 2. incorrect type example:
902
- // %0 = load i16, %"class.sycl::_V1::ext::oneapi::bfloat16" addrspace(4)* %x, align 2
903
- // Load value type should match pointer type for %x, so we insert necessary bitcast:
904
- // %x.bcast = bitcast %"class.sycl::_V1::ext::oneapi::bfloat16" addrspace(4)* %x to i16 addrspace(4)*
905
- // %0 = load i16, i16 addrspace(4)* %x.bcast, align 2
906
- SmallVector<Instruction *, 5 > castInsts;
907
-
908
- for (; itSrcFunc != eSrcFunc; ++itSrcFunc, ++itDest)
961
+ if (srcType->getPointerAddressSpace () != destType->getPointerAddressSpace ())
909
962
{
910
- itDest->setName (itSrcFunc->getName ());
911
-
912
- Type *srcType = (*itSrcFunc).getType ();
913
- Value *destVal = &(*itDest);
914
- Type *destType = destVal->getType ();
915
- if (srcType->isPointerTy () && destType->isPointerTy ())
916
- {
917
- if (srcType->getPointerAddressSpace () != destType->getPointerAddressSpace ())
918
- {
919
- AddrSpaceCastInst *newASC = new AddrSpaceCastInst (destVal, srcType, destVal->getName () + " .ascast" );
920
- castInsts.push_back (newASC);
921
- destVal = newASC;
922
- }
923
- PointerType* pSrcType = cast<PointerType>(srcType);
924
- if (!pSrcType->isOpaqueOrPointeeTypeMatches (destType))
925
- {
926
- BitCastInst *newBT = new BitCastInst (destVal, srcType, destVal->getName () + " .bcast" );
927
- castInsts.push_back (newBT);
928
- destVal = newBT;
929
- }
930
- }
931
-
932
- operandMap[&(*itSrcFunc)] = destVal;
963
+ AddrSpaceCastInst *newASC = new AddrSpaceCastInst (destVal, srcType, destVal->getName () + " .ascast" );
964
+ castInsts.push_back (newASC);
965
+ destVal = newASC;
966
+ }
967
+ PointerType* pSrcType = cast<PointerType>(srcType);
968
+ if (!pSrcType->isOpaqueOrPointeeTypeMatches (destType))
969
+ {
970
+ BitCastInst *newBT = new BitCastInst (destVal, srcType, destVal->getName () + " .bcast" );
971
+ castInsts.push_back (newBT);
972
+ destVal = newBT;
933
973
}
934
-
935
- // Clone the body of the function into the dest function.
936
- SmallVector<ReturnInst*, 8 > Returns; // Ignore returns.
937
- IGCLLVM::CloneFunctionInto (
938
- pDstFunc,
939
- funcTobeChanged,
940
- operandMap,
941
- IGCLLVM::CloneFunctionChangeType::LocalChangesOnly,
942
- Returns,
943
- " " );
944
-
945
- // Need to copy the attributes over too.
946
- AttributeList FuncAttrs = funcTobeChanged->getAttributes ();
947
- pDstFunc->setAttributes (FuncAttrs);
948
-
949
- // get first instruction in function and insert addressspacecast before it
950
- Instruction *firstInst = &(*pDstFunc->begin ()->getFirstInsertionPt ());
951
- for (Instruction *valToInsert : castInsts)
952
- valToInsert->insertBefore (firstInst);
953
-
954
- pDstFunc->setCallingConv (funcTobeChanged->getCallingConv ());
955
- bitcastFunctionMap[funcTobeChanged].push_back (pDstFunc);
956
974
}
957
975
958
- std::vector<Value*> Args;
959
- for (unsigned I = 0 , E = IGCLLVM::getNumArgOperands (pInstCall); I != E; ++I) {
960
- Args.push_back (pInstCall->getArgOperand (I));
961
- }
962
- auto newCI = CallInst::Create (pDstFunc, Args, " " , pInstCall);
963
- newCI->takeName (pInstCall);
964
- newCI->setCallingConv (pInstCall->getCallingConv ());
965
- newCI->setAttributes (pInstCall->getAttributes ());
966
- newCI->setDebugLoc (pInstCall->getDebugLoc ());
967
- pInstCall->replaceAllUsesWith (newCI);
968
- pInstCall->dropAllReferences ();
969
- if (constExpr->use_empty ())
970
- constExpr->dropAllReferences ();
971
- if (funcTobeChanged->use_empty ())
972
- funcTobeChanged->eraseFromParent ();
973
-
974
- list_delete.push_back (pInstCall);
976
+ operandMap[&(*itSrcFunc)] = destVal;
975
977
}
978
+
979
+ // Clone the body of the function into the dest function.
980
+ SmallVector<ReturnInst*, 8 > Returns; // Ignore returns.
981
+ IGCLLVM::CloneFunctionInto (
982
+ pDstFunc,
983
+ funcToBeChanged,
984
+ operandMap,
985
+ IGCLLVM::CloneFunctionChangeType::LocalChangesOnly,
986
+ Returns,
987
+ " " );
988
+
989
+ // Need to copy the attributes over too.
990
+ AttributeList FuncAttrs = funcToBeChanged->getAttributes ();
991
+ pDstFunc->setAttributes (FuncAttrs);
992
+
993
+ // get first instruction in function and insert addressspacecast before it
994
+ Instruction *firstInst = &(*pDstFunc->begin ()->getFirstInsertionPt ());
995
+ for (Instruction *valToInsert : castInsts)
996
+ valToInsert->insertBefore (firstInst);
997
+
998
+ pDstFunc->setCallingConv (funcToBeChanged->getCallingConv ());
999
+ bitcastFunctionMap[funcToBeChanged].push_back (pDstFunc);
1000
+ }
1001
+
1002
+ std::vector<Value*> Args;
1003
+ for (unsigned I = 0 , E = IGCLLVM::getNumArgOperands (pInstCall); I != E; ++I) {
1004
+ Args.push_back (pInstCall->getArgOperand (I));
976
1005
}
1006
+ auto newCI = CallInst::Create (pDstFunc, Args, " " , pInstCall);
1007
+ newCI->takeName (pInstCall);
1008
+ newCI->setCallingConv (pInstCall->getCallingConv ());
1009
+ newCI->setAttributes (pInstCall->getAttributes ());
1010
+ newCI->setDebugLoc (pInstCall->getDebugLoc ());
1011
+ pInstCall->replaceAllUsesWith (newCI);
1012
+ pInstCall->dropAllReferences ();
1013
+ if (constExpr && constExpr->use_empty ())
1014
+ constExpr->dropAllReferences ();
1015
+ if (funcToBeChanged->use_empty ())
1016
+ funcToBeChanged->eraseFromParent ();
1017
+
1018
+ list_delete.push_back (pInstCall);
977
1019
}
978
1020
}
979
1021
}
0 commit comments