Skip to content

Commit

Permalink
[SOL] add support for (pseudo) atomics to SBF (#23)
Browse files Browse the repository at this point in the history
Lower atomic operations to their regular non-atomic equivalents. Lowering for
all operations except atomic fence is done at DAG legalization time. Fences are
removed at instruction emission time.
  • Loading branch information
alessandrod authored and LucasSte committed Feb 16, 2024
1 parent 470273a commit 665ba1f
Show file tree
Hide file tree
Showing 4 changed files with 470 additions and 12 deletions.
193 changes: 181 additions & 12 deletions llvm/lib/Target/BPF/BPFISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,30 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
setOperationAction(ISD::STACKRESTORE, MVT::Other, Expand);

// Set unsupported atomic operations as Custom so
// we can emit better error messages than fatal error
// from selectiondag.
for (auto VT : {MVT::i8, MVT::i16, MVT::i32}) {
for (auto VT : {MVT::i8, MVT::i16, MVT::i32, MVT::i32, MVT::i64}) {
if (Subtarget->isSolana()) {
// Implement custom lowering for all atomic operations
setOperationAction(ISD::ATOMIC_SWAP, VT, Custom);
setOperationAction(ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_ADD, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_AND, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_MAX, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_MIN, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_NAND, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_OR, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_SUB, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_UMAX, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_UMIN, VT, Custom);
setOperationAction(ISD::ATOMIC_LOAD_XOR, VT, Custom);
continue;
}

if (VT == MVT::i64) {
continue;
}

// Set unsupported atomic operations as Custom so we can emit better error
// messages than fatal error from selectiondag.
if (VT == MVT::i32) {
if (STI.getHasAlu32())
continue;
Expand Down Expand Up @@ -211,7 +231,17 @@ bool BPFTargetLowering::allowsMisalignedMemoryAccesses(
return isSolana;
}

bool BPFTargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const {
bool BPFTargetLowering::lowerAtomicStoreAsStoreSDNode(
const StoreInst &SI) const {
return Subtarget->isSolana();
}

bool BPFTargetLowering::lowerAtomicLoadAsLoadSDNode(const LoadInst &LI) const {
return Subtarget->isSolana();
}

bool BPFTargetLowering::isOffsetFoldingLegal(
const GlobalAddressSDNode *GA) const {
return false;
}

Expand Down Expand Up @@ -281,19 +311,31 @@ BPFTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
}

void BPFTargetLowering::ReplaceNodeResults(
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
void BPFTargetLowering::ReplaceNodeResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const {
const char *err_msg;
uint32_t Opcode = N->getOpcode();
switch (Opcode) {
default:
report_fatal_error("Unhandled custom legalization");
case ISD::ATOMIC_SWAP:
case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
case ISD::ATOMIC_LOAD_ADD:
case ISD::ATOMIC_LOAD_AND:
case ISD::ATOMIC_LOAD_MAX:
case ISD::ATOMIC_LOAD_MIN:
case ISD::ATOMIC_LOAD_NAND:
case ISD::ATOMIC_LOAD_OR:
case ISD::ATOMIC_LOAD_SUB:
case ISD::ATOMIC_LOAD_UMAX:
case ISD::ATOMIC_LOAD_UMIN:
case ISD::ATOMIC_LOAD_XOR:
case ISD::ATOMIC_SWAP:
case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
if (Subtarget->isSolana()) {
// We do lowering during legalization, see LowerOperation()
return;
}

if (HasAlu32 || Opcode == ISD::ATOMIC_LOAD_ADD)
err_msg = "Unsupported atomic operations, please use 32/64 bit version";
else
Expand All @@ -313,10 +355,23 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerGlobalAddress(Op, DAG);
case ISD::SELECT_CC:
return LowerSELECT_CC(Op, DAG);
case ISD::ATOMIC_SWAP:
case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
case ISD::ATOMIC_LOAD_ADD:
case ISD::ATOMIC_LOAD_AND:
case ISD::ATOMIC_LOAD_MAX:
case ISD::ATOMIC_LOAD_MIN:
case ISD::ATOMIC_LOAD_NAND:
case ISD::ATOMIC_LOAD_OR:
case ISD::ATOMIC_LOAD_SUB:
case ISD::ATOMIC_LOAD_UMAX:
case ISD::ATOMIC_LOAD_UMIN:
case ISD::ATOMIC_LOAD_XOR:
return LowerATOMICRMW(Op, DAG);
case ISD::DYNAMIC_STACKALLOC:
report_fatal_error("Unsupported dynamic stack allocation");
default:
llvm_unreachable("unimplemented operand");
llvm_unreachable("unimplemented atomic operand");
}
}

Expand Down Expand Up @@ -412,7 +467,6 @@ SDValue BPFTargetLowering::LowerFormalArguments(
fail(DL, DAG, "functions with VarArgs or StructRet are not supported");
}


return Chain;
}

Expand Down Expand Up @@ -740,6 +794,114 @@ SDValue BPFTargetLowering::LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(BPFISD::SELECT_CC, DL, VTs, Ops);
}

SDValue BPFTargetLowering::LowerATOMICRMW(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op);
AtomicSDNode *AN = cast<AtomicSDNode>(Op);
assert(AN && "Expected custom lowering of an atomic load node");

SDValue Chain = AN->getChain();
SDValue Ptr = AN->getBasePtr();
EVT PtrVT = AN->getMemoryVT();
EVT RetVT = Op.getValueType();

// Load the current value
SDValue Load =
DAG.getExtLoad(ISD::EXTLOAD, DL, RetVT, Chain, Ptr, MachinePointerInfo(),
PtrVT, AN->getAlignment());
Chain = Load.getValue(1);

// Most ops return the current value, except CMP_SWAP_WITH_SUCCESS see below
SDValue Ret = Load;
SDValue RetFlag;

// Val contains the new value we want to set. For CMP_SWAP, Cmp contains the
// expected current value.
SDValue Cmp, Val;
if (AN->isCompareAndSwap()) {
Cmp = Op.getOperand(2);
Val = Op.getOperand(3);

// The Cmp value must match the pointer type
EVT CmpVT = Cmp->getValueType(0);
if (CmpVT != RetVT) {
Cmp = RetVT.bitsGT(CmpVT) ? DAG.getNode(ISD::SIGN_EXTEND, DL, RetVT, Cmp)
: DAG.getNode(ISD::TRUNCATE, DL, RetVT, Cmp);
}
} else {
Val = AN->getVal();
}

// The new value type must match the pointer type
EVT ValVT = Val->getValueType(0);
if (ValVT != RetVT) {
Val = RetVT.bitsGT(ValVT) ? DAG.getNode(ISD::SIGN_EXTEND, DL, RetVT, Val)
: DAG.getNode(ISD::TRUNCATE, DL, RetVT, Val);
ValVT = Val->getValueType(0);
}

SDValue NewVal;
switch (Op.getOpcode()) {
case ISD::ATOMIC_SWAP:
NewVal = Val;
break;
case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS: {
EVT RetFlagVT = AN->getValueType(1);
NewVal = DAG.getSelectCC(DL, Load, Cmp, Val, Load, ISD::SETEQ);
RetFlag = DAG.getSelectCC(
DL, Load, Cmp, DAG.getBoolConstant(true, DL, RetFlagVT, RetFlagVT),
DAG.getBoolConstant(false, DL, RetFlagVT, RetFlagVT), ISD::SETEQ);
break;
}
case ISD::ATOMIC_LOAD_ADD:
NewVal = DAG.getNode(ISD::ADD, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_SUB:
NewVal = DAG.getNode(ISD::SUB, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_AND:
NewVal = DAG.getNode(ISD::AND, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_NAND: {
NewVal =
DAG.getNOT(DL, DAG.getNode(ISD::AND, DL, ValVT, Load, Val), ValVT);
break;
}
case ISD::ATOMIC_LOAD_OR:
NewVal = DAG.getNode(ISD::OR, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_XOR:
NewVal = DAG.getNode(ISD::XOR, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_MIN:
NewVal = DAG.getNode(ISD::SMIN, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_UMIN:
NewVal = DAG.getNode(ISD::UMIN, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_MAX:
NewVal = DAG.getNode(ISD::SMAX, DL, ValVT, Load, Val);
break;
case ISD::ATOMIC_LOAD_UMAX:
NewVal = DAG.getNode(ISD::UMAX, DL, ValVT, Load, Val);
break;
default:
llvm_unreachable("unknown atomicrmw op");
}

Chain =
DAG.getTruncStore(Chain, DL, NewVal, Ptr, MachinePointerInfo(), PtrVT);

if (RetFlag) {
// CMP_SWAP_WITH_SUCCESS returns {value, success, chain}
Ret = DAG.getMergeValues({Ret, RetFlag, Chain}, DL);
} else {
// All the other ops return {value, chain}
Ret = DAG.getMergeValues({Ret, Chain}, DL);
}

return Ret;
}

const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const {
switch ((BPFISD::NodeType)Opcode) {
case BPFISD::FIRST_NUMBER:
Expand Down Expand Up @@ -843,6 +1005,7 @@ BPFTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
Opc == BPF::Select_32_64);

bool isMemcpyOp = Opc == BPF::MEMCPY;
bool isAtomicFence = Opc == BPF::ATOMIC_FENCE;

#ifndef NDEBUG
bool isSelectRIOp = (Opc == BPF::Select_Ri ||
Expand All @@ -851,13 +1014,19 @@ BPFTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
Opc == BPF::Select_Ri_32_64);


assert((isSelectRROp || isSelectRIOp || isMemcpyOp) &&
assert((isSelectRROp || isSelectRIOp || isMemcpyOp || isAtomicFence) &&
"Unexpected instr type to insert");
#endif

if (isMemcpyOp)
return EmitInstrWithCustomInserterMemcpy(MI, BB);

if (isAtomicFence) {
// this is currently a nop
MI.eraseFromParent();
return BB;
}

bool is32BitCmp = (Opc == BPF::Select_32 ||
Opc == BPF::Select_32_64 ||
Opc == BPF::Select_Ri_32 ||
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/BPF/BPFISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class BPFTargetLowering : public TargetLowering {

MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override;

bool lowerAtomicStoreAsStoreSDNode(const StoreInst &SI) const override;
bool lowerAtomicLoadAsLoadSDNode(const LoadInst &LI) const override;

private:
// Control Instruction Selection Features
bool HasAlu32;
Expand All @@ -80,6 +83,7 @@ class BPFTargetLowering : public TargetLowering {
SDValue LowerBR_CC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSELECT_CC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerATOMICRMW(SDValue Op, SelectionDAG &DAG) const;

// Lower the result values of a call, copying them out of physregs into vregs
SDValue LowerCallResult(SDValue Chain, SDValue InGlue,
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/BPF/BPFInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def BPFIsLittleEndian : Predicate<"CurDAG->getDataLayout().isLittleEndian()">;
def BPFIsBigEndian : Predicate<"!CurDAG->getDataLayout().isLittleEndian()">;
def BPFHasALU32 : Predicate<"Subtarget->getHasAlu32()">;
def BPFNoALU32 : Predicate<"!Subtarget->getHasAlu32()">;
def BPFSubtargetSolana : Predicate<"Subtarget->isSolana()">;

def brtarget : Operand<OtherVT> {
let PrintMethod = "printBrTargetOperand";
Expand Down Expand Up @@ -747,6 +748,14 @@ def : Pat<(atomic_load_sub_32 ADDRri:$addr, GPR32:$val),
def : Pat<(atomic_load_sub_64 ADDRri:$addr, GPR:$val),
(XFADDD ADDRri:$addr, (NEG_64 GPR:$val))>;

let Predicates = [BPFSubtargetSolana], usesCustomInserter = 1, isCodeGenOnly = 1 in {
def ATOMIC_FENCE : Pseudo<
(outs),
(ins),
"#atomic_fence",
[(atomic_fence timm, timm)]>;
}

// Atomic Exchange
class XCHG<BPFWidthModifer SizeOp, string OpcodeStr, PatFrag OpNode>
: TYPE_LD_ST<BPF_ATOMIC.Value, SizeOp.Value,
Expand Down
Loading

0 comments on commit 665ba1f

Please sign in to comment.