Skip to content

Commit

Permalink
arm64: implement atomic-store with CAS
Browse files Browse the repository at this point in the history
The CAS instructions are only available with ARMv8.1 or higher, so the new code
is guarded by a feature macro (__ARM_FEATURE_ATOMICS) that an ACLE-compatible
compiler defines depending on -march.

As a side effect, this avoids the races that could result in an infinite loop
on M1.
  • Loading branch information
carenas committed May 2, 2023
1 parent 1a4d5f2 commit 103aa07
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 10 deletions.
91 changes: 88 additions & 3 deletions sljit_src/sljitNativeARM_64.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,21 @@ static const sljit_u8 freg_map[SLJIT_NUMBER_OF_FLOAT_REGISTERS + 3] = {
#define BLR 0xd63f0000
#define BR 0xd61f0000
#define BRK 0xd4200000
#define CAS 0xc8a07c00
#define CASA 0xc8e07c00
#define CASAB 0x08e07c00
#define CASAH 0x48e07c00
#define CASAL 0xc8e0fc00
#define CASALB 0x08e0fc00
#define CASALH 0x48e0fc00
#define CASB 0x08a07c00
#define CASH 0x48a07c00
#define CASL 0xc8a0fc00
#define CASLB 0x08a0fc00
#define CASLH 0x48a0fc00
#define CBZ 0xb4000000
#define CCMPI 0xfa400800
#define CLREX 0xd5033f5f
#define CLZ 0xdac01000
#define CSEL 0x9a800000
#define CSINC 0x9a800400
Expand All @@ -101,6 +114,7 @@ static const sljit_u8 freg_map[SLJIT_NUMBER_OF_FLOAT_REGISTERS + 3] = {
#define FMUL 0x1e600800
#define FNEG 0x1e614000
#define FSUB 0x1e603800
#define LDAXR 0xc85ffc00
#define LDRI 0xf9400000
#define LDRI_F64 0xfd400000
#define LDRI_POST 0xf8400400
Expand Down Expand Up @@ -153,6 +167,13 @@ static const sljit_u8 freg_map[SLJIT_NUMBER_OF_FLOAT_REGISTERS + 3] = {
#define UDIV 0x9ac00800
#define UMULH 0x9bc03c00

/* Derived opcodes: each one is built by combining or flipping bits of another
   opcode macro instead of spelling out a new constant. */

/* CSINC with both source operands set to TMP_ZERO. No condition bits are set
   by this macro itself. */
#define CSET (CSINC | RM(TMP_ZERO) | RN(TMP_ZERO))
/* STRI with bit 22 set (presumably the load form of that store encoding;
   confirm against the A64 load/store encoding tables). */
#define LDR (STRI | (1 << 22))
/* STRBI with bit 22 set (byte-sized counterpart of LDR above). */
#define LDRB (STRBI | (1 << 22))
/* LDRB with bit 30 set (halfword-sized counterpart of LDRB). */
#define LDRH (LDRB | (1 << 30))
/* LDRI with bit 30 and bits 22-23 flipped; with LDRI == 0xf9400000 this
   evaluates to 0xb9800000. */
#define LDRSW ((LDRI ^ (1 << 30)) ^ (0x3 << 22))
/* ORR with its first source operand set to TMP_ZERO, i.e. a register move
   when TMP_ZERO reads as zero. */
#define MOV (ORR | RN(TMP_ZERO))

static sljit_s32 push_inst(struct sljit_compiler *compiler, sljit_ins ins)
{
sljit_ins *ptr = (sljit_ins*)ensure_buf(compiler, sizeof(sljit_ins));
Expand Down Expand Up @@ -2487,11 +2508,30 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_load(struct sljit_compiler
sljit_s32 dst_reg,
sljit_s32 mem_reg)
{
sljit_ins ins = 0;
sljit_ins ins;

CHECK_ERROR();
CHECK(check_sljit_emit_atomic_load(compiler, op, dst_reg, mem_reg));

#ifdef __ARM_FEATURE_ATOMICS
switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
ins = LDR ^ (1 << 30);
break;
case SLJIT_MOV_U32:
ins = LDRSW;
break;
case SLJIT_MOV_U16:
ins = LDRH;
break;
case SLJIT_MOV_U8:
ins = LDRB;
break;
default:
ins = LDR;
break;
}
#else /* !__ARM_FEATURE_ATOMICS */
switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
case SLJIT_MOV_U32:
Expand All @@ -2507,7 +2547,7 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_load(struct sljit_compiler
ins = LDXR;
break;
}

#endif /* __ARM_FEATURE_ATOMICS */
return push_inst(compiler, ins | RN(mem_reg) | RT(dst_reg));
}

Expand All @@ -2517,10 +2557,54 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_store(struct sljit_compiler
sljit_s32 temp_reg)
{
sljit_ins ins;
sljit_s32 tmp = temp_reg;
sljit_ins cmp = 0;
sljit_ins inv_bits = W_OP;

CHECK_ERROR();
CHECK(check_sljit_emit_atomic_store(compiler, op, src_reg, mem_reg, temp_reg));

#ifdef __ARM_FEATURE_ATOMICS
if (op & SLJIT_SET_ATOMIC_STORED)
cmp = (SUBS ^ W_OP) | RD(TMP_ZERO);

switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
case SLJIT_MOV_U32:
ins = CAS ^ (1 << 30);
break;
case SLJIT_MOV_U16:
ins = CASH;
break;
case SLJIT_MOV_U8:
ins = CASB;
break;
default:
ins = CAS;
inv_bits = 0;
if (cmp)
cmp ^= W_OP;
break;
}

if (cmp) {
FAIL_IF(push_inst(compiler, MOV ^ inv_bits | RM(temp_reg) | RD(TMP_REG1)));
tmp = TMP_REG1;
}
FAIL_IF(push_inst(compiler, ins | RM(tmp) | RN(mem_reg) | RD(src_reg)));
if (!cmp)
return SLJIT_SUCCESS;

FAIL_IF(push_inst(compiler, cmp | RM(tmp) | RN(temp_reg)));
FAIL_IF(push_inst(compiler, CSET ^ inv_bits | RD(tmp)));
return push_inst(compiler, cmp | RM(tmp) | RN(TMP_ZERO));
#else /* !__ARM_FEATURE_ATOMICS */
SLJIT_UNUSED_ARG(tmp);
SLJIT_UNUSED_ARG(inv_bits);

if (op & SLJIT_SET_ATOMIC_STORED)
cmp = (SUBI ^ W_OP) | (1 << 29);

switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
case SLJIT_MOV_U32:
Expand All @@ -2538,7 +2622,8 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_store(struct sljit_compiler
}

FAIL_IF(push_inst(compiler, ins | RM(TMP_REG1) | RN(mem_reg) | RT(src_reg)));
return push_inst(compiler, (SUBI ^ W_OP) | (1 << 29) | RD(TMP_ZERO) | RN(TMP_REG1));
return cmp ? push_inst(compiler, cmp | RD(TMP_ZERO) | RN(TMP_REG1)) : SLJIT_SUCCESS;
#endif /* __ARM_FEATURE_ATOMICS */
}

SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_get_local_base(struct sljit_compiler *compiler, sljit_s32 dst, sljit_sw dstw, sljit_sw offset)
Expand Down
46 changes: 39 additions & 7 deletions test_src/sljitTest.c
Original file line number Diff line number Diff line change
Expand Up @@ -11530,19 +11530,19 @@ static void test92(void)
struct sljit_compiler *compiler = sljit_create_compiler(NULL, NULL);
struct sljit_label *label;
struct sljit_jump *jump;
sljit_sw buf[34];
sljit_sw buf[37];
sljit_s32 i;

if (verbose)
printf("Run test92\n");

FAILED(!compiler, "cannot create compiler\n");

for (i = 0; i < 34; i++)
for (i = 0; i < 36; i++)
buf[i] = WCONST(0x5555555555555555, 0x55555555);

buf[0] = 4678;
*(sljit_u8*)(buf + 2) = 178;
*(sljit_u8*)(buf + 2) = 78;
*(sljit_u8*)(buf + 5) = 211;
*(sljit_u16*)(buf + 9) = 17897;
*(sljit_u16*)(buf + 12) = 57812;
Expand All @@ -11552,6 +11552,9 @@ static void test92(void)
((sljit_u8*)(buf + 26))[1] = 105;
((sljit_u8*)(buf + 28))[2] = 13;
((sljit_u16*)(buf + 31))[1] = 14876;
#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
((sljit_s32*)(buf + 33))[1] = -1;
#endif /* SLJIT_64BIT_ARCHITECTURE */

sljit_emit_enter(compiler, 0, SLJIT_ARGS1(VOID, P), 5, 5, 0, 0, 2 * sizeof(sljit_sw));

Expand Down Expand Up @@ -11591,7 +11594,7 @@ static void test92(void)
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R0, 0, SLJIT_S0, 0, SLJIT_IMM, 5 * sizeof(sljit_sw));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_S1, 0, SLJIT_IMM, 97);
/* buf[5] */
sljit_emit_atomic_store(compiler, SLJIT_MOV_U8 | SLJIT_SET_ATOMIC_STORED, SLJIT_S1, SLJIT_R0, SLJIT_S2);
sljit_emit_atomic_store(compiler, SLJIT_MOV32_U8 | SLJIT_SET_ATOMIC_STORED, SLJIT_S1, SLJIT_R0, SLJIT_S2);
sljit_set_label(sljit_emit_jump(compiler, SLJIT_ATOMIC_NOT_STORED), label);
/* buf[6] */
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_MEM1(SLJIT_S0), 6 * sizeof(sljit_sw), SLJIT_R1, 0);
Expand Down Expand Up @@ -11724,8 +11727,29 @@ static void test92(void)
sljit_set_label(sljit_emit_jump(compiler, SLJIT_ATOMIC_NOT_STORED), label);
/* buf[32] */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 32 * sizeof(sljit_sw), SLJIT_S1, 0);

#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R1, 0, SLJIT_S0, 0, SLJIT_IMM, 33 * sizeof(sljit_sw) + sizeof(sljit_u32));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_S1, 0, SLJIT_IMM, 0);
label = sljit_emit_label(compiler);
sljit_emit_atomic_load(compiler, SLJIT_MOV32, SLJIT_R0, SLJIT_R1);
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_S1, 0, SLJIT_R0, 0);
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R2, 0, SLJIT_IMM, 0xdeadbeef);
/* buf[33] */
sljit_emit_op_flags(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 33 * sizeof(sljit_sw), SLJIT_ATOMIC_STORED);
sljit_emit_atomic_store(compiler, SLJIT_MOV32 | SLJIT_SET_ATOMIC_STORED, SLJIT_R2, SLJIT_R1, SLJIT_R0);
sljit_set_label(sljit_emit_jump(compiler, SLJIT_ATOMIC_NOT_STORED), label);
/* buf[34] */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 34 * sizeof(sljit_sw), SLJIT_S1, 0);
#endif /* SLJIT_64BIT_ARCHITECTURE */

/* buf[35] */
sljit_emit_op_flags(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 35 * sizeof(sljit_sw), SLJIT_ATOMIC_STORED);

sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R1, 0, SLJIT_S0, 0, SLJIT_IMM, 36 * sizeof(sljit_sw));
sljit_emit_atomic_load(compiler, SLJIT_MOV, SLJIT_R0, SLJIT_R1);
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R2, 0, SLJIT_R1, 0, SLJIT_IMM, 1);
/* buf[36] */
sljit_emit_atomic_store(compiler, SLJIT_MOV, SLJIT_R2, SLJIT_R1, SLJIT_R0);

sljit_emit_return_void(compiler);

Expand All @@ -11739,7 +11763,7 @@ static void test92(void)
FAILED(buf[1] != 4678, "test92 case 2 failed\n");
FAILED(*(sljit_u8*)(buf + 2) != 203, "test92 case 3 failed\n");
FAILED(((sljit_u8*)(buf + 2))[1] != 0x55, "test92 case 4 failed\n");
FAILED(buf[3] != 178, "test92 case 5 failed\n");
FAILED(buf[3] != 78, "test92 case 5 failed\n");
FAILED(buf[4] != 203, "test92 case 6 failed\n");
FAILED(*(sljit_u8*)(buf + 5) != 97, "test92 case 7 failed\n");
FAILED(((sljit_u8*)(buf + 5))[1] != 0x55, "test92 case 8 failed\n");
Expand All @@ -11764,6 +11788,9 @@ static void test92(void)
FAILED(((sljit_u8*)(buf + 17))[4] != 0x55, "test92 case 24 failed\n");
#endif /* SLJIT_64BIT_ARCHITECTURE */
FAILED(*(sljit_u32*)(buf + 18) != 987609876, "test92 case 25 failed\n");
#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
FAILED(((sljit_u8*)(buf + 18))[4] != 0x55, "test92 case 25 (overflow) failed\n");
#endif /* SLJIT_64BIT_ARCHITECTURE */
FAILED(buf[19] != -573621, "test92 case 26 failed\n");
FAILED(*(sljit_u8*)(buf + 20) != 240, "test92 case 27 failed\n");
FAILED(((sljit_u8*)(buf + 20))[1] != 0x55, "test92 case 28 failed\n");
Expand All @@ -11784,7 +11811,12 @@ static void test92(void)
FAILED(((sljit_u16*)(buf + 31))[0] != 0x5555, "test92 case 43 failed\n");
FAILED(((sljit_u16*)(buf + 31))[1] != 51403, "test92 case 44 failed\n");
FAILED(buf[32] != 14876, "test92 case 45 failed\n");
FAILED(buf[33] != 1, "test92 case 46 failed\n");
#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
FAILED(((sljit_u32*)(buf + 33))[0] != 0x55555555, "test92 case 46 failed\n");
FAILED(((sljit_u32*)(buf + 33))[1] != 0xdeadbeef, "test92 case 47 failed\n");
FAILED(buf[34] != 0xffffffff, "test92 case 48 failed\n");
#endif /* SLJIT_64BIT_ARCHITECTURE */
FAILED(buf[35] != 1, "test92 case 49 failed\n");

sljit_free_code(code.code, NULL);
#endif
Expand Down

0 comments on commit 103aa07

Please sign in to comment.