Skip to content

Commit

Permalink
arm64: implement atomic-store with CAS
Browse files Browse the repository at this point in the history
Only available with ARMv8.1 or higher, so behind a flag that would
be enabled by an ACE compatible compiler depending on -march.

As a sideeffect, avoids the races that could result in an infloop
with M1.
  • Loading branch information
carenas committed Apr 27, 2023
1 parent 1a4d5f2 commit c8d1db9
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 39 deletions.
116 changes: 116 additions & 0 deletions doc/tutorial/deadlock.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* $CC [-DDEADLOCK] -Isljit_src -o deadlock deadlock.c sljit_src/sljitLir.c
*
* SPDX: 0BSD
*
* © 2023 Carlo Marcelo Arenas Belón
*/

#include "sljitLir.h"

#include <stdio.h>
#include <stdlib.h>
#include <pthread.h>

struct data {
long value;
long *canary;
pthread_t id;
};
typedef void *(*worker_fn)(void *);

static int gdb;

static worker_fn create_worker(void)
{
worker_fn code;
struct sljit_label *retry;
struct sljit_jump *exit, *skip;
struct sljit_compiler *C = sljit_create_compiler(NULL, NULL);
#if SLJIT_CONFIG_X86
sljit_u8 inst = 0xcc;
#elif SLJIT_CONFIG_S390X
sljit_u8 inst[2] = { 0x0, 0x1 };
#elif SLJIT_CONFIG_ARM_64
sljit_u32 inst = 0xd4200000;
#elif SLJIT_CONFIG_ARM_THUMB2
sljit_u16 inst = 0xde01;
#elif SLJIT_CONFIG_ARM
sljit_u32 inst = 0xe7d001f0;
#else
//#error "Not Supported"
#endif

sljit_emit_enter(C, 0, SLJIT_ARGS1(P, P), 6, 1, 0, 0, 0);
if (gdb)
sljit_emit_op_custom(C, &inst, sizeof(inst));
sljit_emit_op2(C, SLJIT_ADD, SLJIT_R5, 0, SLJIT_S0, 0,
SLJIT_IMM, SLJIT_OFFSETOF(struct data, canary));
sljit_emit_op1(C, SLJIT_MOV, SLJIT_R1, 0, SLJIT_IMM, 0);
retry = sljit_emit_label(C);
sljit_emit_atomic_load(C, SLJIT_MOV, SLJIT_R2, SLJIT_S0);
// skip = sljit_emit_cmp(C, SLJIT_EQUAL, SLJIT_MEM1(SLJIT_R4), 0,
// SLJIT_IMM, 0);
sljit_emit_op1(C, SLJIT_MOV, SLJIT_R4, 0, SLJIT_MEM1(SLJIT_R5), 0);
#ifdef DEADLOCK
sljit_emit_op1(C, SLJIT_MOV, SLJIT_MEM1(SLJIT_R4), 0, SLJIT_IMM, 1000);
#else
sljit_emit_op1(C, SLJIT_MOV, SLJIT_R4, 0, SLJIT_MEM1(SLJIT_R4), 0);
#endif
// sljit_set_label(skip, sljit_emit_label(C));
sljit_emit_op2(C, SLJIT_ADD, SLJIT_R1, 0, SLJIT_R1, 0, SLJIT_IMM, 1);
sljit_emit_atomic_store(C, SLJIT_MOV | SLJIT_SET_ATOMIC_STORED,
SLJIT_R1, SLJIT_S0, SLJIT_R2);
exit = sljit_emit_jump(C, SLJIT_ATOMIC_STORED);
sljit_set_label(sljit_emit_jump(C, SLJIT_JUMP), retry);
sljit_set_label(exit, sljit_emit_label(C));
sljit_emit_op1(C, SLJIT_MOV, SLJIT_RETURN_REG, 0, SLJIT_IMM, 0);
sljit_emit_return(C, SLJIT_MOV_P, SLJIT_RETURN_REG, 0);

code = sljit_generate_code(C);

sljit_free_compiler(C);
return code;
}

int main(int argc, char *argv[])
{
int winner = 0, looser = 0, other = 0;
int i, num_threads = -1;
struct data *threads;
worker_fn code;

if (argc > 1)
num_threads = atoi(argv[1]);

if (argc > 2)
gdb = 1;

code = create_worker();

if (num_threads <= 0)
num_threads = 1;

threads = calloc(num_threads, sizeof(struct data));
for (i = 0; i < num_threads; i++) {
threads[i].canary = i ? &threads[i - 1].value :
&threads[num_threads - 1].value;
pthread_create(&threads[i].id, NULL, code, &threads[i].value);
}
for (i = 0; i < num_threads; i++)
pthread_join(threads[i].id, NULL);
for (i = 0; i < num_threads; i++) {
switch(threads[i].value) {
case 1: ++winner; break;
case 1000: ++looser; break;
default: ++other; printf("%ld\n", threads[i].value);
}
}

printf("threads: %d won, %d lost, %d hang\n", winner, looser, other);

free(threads);
sljit_free_code((void *)code, NULL);

return 0;
}
62 changes: 52 additions & 10 deletions sljit_src/sljitNativeARM_64.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,18 @@ static const sljit_u8 freg_map[SLJIT_NUMBER_OF_FLOAT_REGISTERS + 3] = {
#define BLR 0xd63f0000
#define BR 0xd61f0000
#define BRK 0xd4200000
#define CASA 0xc8e07c00
#define CASB 0x08a07c00
#define CASH 0x48a07c00
#define CASL 0xc8a0fc00
#define CASLB 0x08a0fc00
#define CASLH 0x48a0fc00
#define CBZ 0xb4000000
#define CCMPI 0xfa400800
#define CLZ 0xdac01000
#define CSEL 0x9a800000
#define CSINC 0x9a800400
#define CLREX 0xd5033f5f
#define EOR 0xca000000
#define EORI 0xd2000000
#define EXTR 0x93c00000
Expand Down Expand Up @@ -2483,14 +2490,9 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_fmem_update(struct sljit_compiler
return push_inst(compiler, inst | VT(freg) | RN(mem & REG_MASK) | (sljit_ins)((memw & 0x1ff) << 12));
}

SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_load(struct sljit_compiler *compiler, sljit_s32 op,
sljit_s32 dst_reg,
sljit_s32 mem_reg)
static SLJIT_INLINE sljit_ins atomic_load_ins(const sljit_s32 op)
{
sljit_ins ins = 0;

CHECK_ERROR();
CHECK(check_sljit_emit_atomic_load(compiler, op, dst_reg, mem_reg));
sljit_ins ins;

switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
Expand All @@ -2508,19 +2510,58 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_load(struct sljit_compiler
break;
}

return push_inst(compiler, ins | RN(mem_reg) | RT(dst_reg));
return ins;
}

SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_load(struct sljit_compiler *compiler, sljit_s32 op,
sljit_s32 dst_reg,
sljit_s32 mem_reg)
{
CHECK_ERROR();
CHECK(check_sljit_emit_atomic_load(compiler, op, dst_reg, mem_reg));

#ifdef __ARM_FEATURE_ATOMICS
return sljit_emit_op1(compiler, op, dst_reg, 0, SLJIT_MEM1(mem_reg), 0);
#else
return push_inst(compiler, atomic_load_ins(op) | RN(mem_reg) | RT(dst_reg));
#endif /* ARM_FEATURE_ATOMICS */
}

SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_store(struct sljit_compiler *compiler, sljit_s32 op,
sljit_s32 src_reg,
sljit_s32 mem_reg,
sljit_s32 temp_reg)
{
sljit_ins ins;
sljit_ins ins, cmp;

CHECK_ERROR();
CHECK(check_sljit_emit_atomic_store(compiler, op, src_reg, mem_reg, temp_reg));

#if __ARM_FEATURE_ATOMICS
cmp = SUBS ^ W_OP;
switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
case SLJIT_MOV_U32:
ins = CASL ^ (1 << 30);
break;
case SLJIT_MOV_U8:
ins = CASLB;
break;
case SLJIT_MOV_U16:
ins = CASLH;
break;
default:
ins = CASL;
cmp = SUBS;
break;
}
FAIL_IF(push_inst(compiler, atomic_load_ins(op) | RN(mem_reg) | RT(TMP_REG2)));
FAIL_IF(push_inst(compiler, ins | RM(temp_reg) | RN(mem_reg) | RD(src_reg)));
if (op & SLJIT_SET_ATOMIC_STORED)
FAIL_IF(push_inst(compiler, cmp | RM(TMP_REG2) | RN(temp_reg) | RD(TMP_REG1)));
return SLJIT_SUCCESS;
#else
cmp = (SUBI ^ W_OP) | (1 << 29);
switch (GET_OPCODE(op)) {
case SLJIT_MOV32:
case SLJIT_MOV_U32:
Expand All @@ -2538,7 +2579,8 @@ SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_emit_atomic_store(struct sljit_compiler
}

FAIL_IF(push_inst(compiler, ins | RM(TMP_REG1) | RN(mem_reg) | RT(src_reg)));
return push_inst(compiler, (SUBI ^ W_OP) | (1 << 29) | RD(TMP_ZERO) | RN(TMP_REG1));
return (op & SLJIT_SET_ATOMIC_STORED) ? push_inst(compiler, cmp | RD(TMP_ZERO) | RN(TMP_REG1)) : SLJIT_SUCCESS;
#endif /* Armv8.1 LSE */
}

SLJIT_API_FUNC_ATTRIBUTE sljit_s32 sljit_get_local_base(struct sljit_compiler *compiler, sljit_s32 dst, sljit_sw dstw, sljit_sw offset)
Expand Down
70 changes: 41 additions & 29 deletions test_src/sljitTest.c
Original file line number Diff line number Diff line change
Expand Up @@ -11530,7 +11530,7 @@ static void test92(void)
struct sljit_compiler *compiler = sljit_create_compiler(NULL, NULL);
struct sljit_label *label;
struct sljit_jump *jump;
sljit_sw buf[34];
sljit_sw buf[35];
sljit_s32 i;

if (verbose)
Expand All @@ -11542,7 +11542,9 @@ static void test92(void)
buf[i] = WCONST(0x5555555555555555, 0x55555555);

buf[0] = 4678;
*(sljit_u8*)(buf + 2) = 178;
buf[1] = -9856;
*(sljit_u8*)(buf + 2) = 78;
buf[3] = 203;
*(sljit_u8*)(buf + 5) = 211;
*(sljit_u16*)(buf + 9) = 17897;
*(sljit_u16*)(buf + 12) = 57812;
Expand All @@ -11555,7 +11557,6 @@ static void test92(void)

sljit_emit_enter(compiler, 0, SLJIT_ARGS1(VOID, P), 5, 5, 0, 0, 2 * sizeof(sljit_sw));

sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), sizeof(sljit_sw), SLJIT_IMM, -9856);
label = sljit_emit_label(compiler);
sljit_emit_atomic_load(compiler, SLJIT_MOV, SLJIT_R1, SLJIT_S0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_R1, 0);
Expand All @@ -11568,8 +11569,8 @@ static void test92(void)
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), sizeof(sljit_sw), SLJIT_R2, 0);

sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R1, 0, SLJIT_S0, 0, SLJIT_IMM, 2 * sizeof(sljit_sw));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 3 * sizeof(sljit_sw), SLJIT_IMM, 203);
label = sljit_emit_label(compiler);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_IMM, -1);
sljit_emit_atomic_load(compiler, SLJIT_MOV_U8, SLJIT_R2, SLJIT_R1);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_S2, 0, SLJIT_R2, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_MEM1(SLJIT_S0), 3 * sizeof(sljit_sw));
Expand Down Expand Up @@ -11602,6 +11603,7 @@ static void test92(void)

label = sljit_emit_label(compiler);
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R1, 0, SLJIT_S0, 0, SLJIT_IMM, 9 * sizeof(sljit_sw));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_S2, 0, SLJIT_IMM, WCONST(0xAAAAAAAAAAAAAAAAu, 0xAAAAAAAAu));
sljit_emit_atomic_load(compiler, SLJIT_MOV_U16, SLJIT_S2, SLJIT_R1);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_S2, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_S2, 0);
Expand Down Expand Up @@ -11631,10 +11633,11 @@ static void test92(void)

sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_S2, 0, SLJIT_S0, 0, SLJIT_IMM, 15 * sizeof(sljit_sw));
label = sljit_emit_label(compiler);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_IMM, WCONST(0xAAAAAAAAAAAAAAAAu, 0xAAAAAAAAu));
sljit_emit_atomic_load(compiler, SLJIT_MOV_U32, SLJIT_R2, SLJIT_S2);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_R2, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_S3, 0, SLJIT_R2, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R1, 0, SLJIT_IMM, 987654321);
sljit_emit_op1(compiler, SLJIT_MOV_U32, SLJIT_R1, 0, SLJIT_IMM, 987654321);
sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_S1, 0, SLJIT_S0, 0, SLJIT_IMM, 15 * sizeof(sljit_sw));
/* buf[15] */
sljit_emit_atomic_store(compiler, SLJIT_MOV_U32 | SLJIT_SET_ATOMIC_STORED, SLJIT_R1, SLJIT_S1, SLJIT_S3);
Expand All @@ -11647,15 +11650,15 @@ static void test92(void)
sljit_emit_atomic_load(compiler, SLJIT_MOV32, SLJIT_R0, SLJIT_R2);
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_S1, 0, SLJIT_R0, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_S2, 0, SLJIT_R0, 0);
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R0, 0, SLJIT_IMM, -573621);
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R0, 0, SLJIT_IMM, -573621);
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_R1, 0, SLJIT_IMM, 678906789);
/* buf[17] */
sljit_emit_atomic_store(compiler, SLJIT_MOV32 | SLJIT_SET_ATOMIC_STORED, SLJIT_R1, SLJIT_R2, SLJIT_S2);
sljit_set_label(sljit_emit_jump(compiler, SLJIT_ATOMIC_NOT_STORED), label);
/* buf[18] */
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_MEM1(SLJIT_S0), 18 * sizeof(sljit_sw), SLJIT_S1, 0);
/* buf[19] */
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 19 * sizeof(sljit_sw), SLJIT_R0, 0);
sljit_emit_op1(compiler, SLJIT_MOV32, SLJIT_MEM1(SLJIT_S0), 19 * sizeof(sljit_sw), SLJIT_R0, 0);

sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R1, 0, SLJIT_S0, 0, SLJIT_IMM, 20 * sizeof(sljit_sw));
label = sljit_emit_label(compiler);
Expand Down Expand Up @@ -11727,6 +11730,11 @@ static void test92(void)
/* buf[33] */
sljit_emit_op_flags(compiler, SLJIT_MOV, SLJIT_MEM1(SLJIT_S0), 33 * sizeof(sljit_sw), SLJIT_ATOMIC_STORED);

sljit_emit_op2(compiler, SLJIT_ADD, SLJIT_R1, 0, SLJIT_S0, 0, SLJIT_IMM, 34 * sizeof(sljit_sw));
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_R2, 0, SLJIT_MEM1(SLJIT_S0), 32 * sizeof(sljit_sw));
sljit_emit_atomic_load(compiler, SLJIT_MOV, SLJIT_R0, SLJIT_R1);
sljit_emit_atomic_store(compiler, SLJIT_MOV, SLJIT_R2, SLJIT_R1, SLJIT_R0);

sljit_emit_return_void(compiler);

code.code = sljit_generate_code(compiler);
Expand All @@ -11739,7 +11747,7 @@ static void test92(void)
FAILED(buf[1] != 4678, "test92 case 2 failed\n");
FAILED(*(sljit_u8*)(buf + 2) != 203, "test92 case 3 failed\n");
FAILED(((sljit_u8*)(buf + 2))[1] != 0x55, "test92 case 4 failed\n");
FAILED(buf[3] != 178, "test92 case 5 failed\n");
FAILED(buf[3] != 78, "test92 case 5 failed\n");
FAILED(buf[4] != 203, "test92 case 6 failed\n");
FAILED(*(sljit_u8*)(buf + 5) != 97, "test92 case 7 failed\n");
FAILED(((sljit_u8*)(buf + 5))[1] != 0x55, "test92 case 8 failed\n");
Expand All @@ -11764,27 +11772,31 @@ static void test92(void)
FAILED(((sljit_u8*)(buf + 17))[4] != 0x55, "test92 case 24 failed\n");
#endif /* SLJIT_64BIT_ARCHITECTURE */
FAILED(*(sljit_u32*)(buf + 18) != 987609876, "test92 case 25 failed\n");
FAILED(buf[19] != -573621, "test92 case 26 failed\n");
FAILED(*(sljit_u8*)(buf + 20) != 240, "test92 case 27 failed\n");
FAILED(((sljit_u8*)(buf + 20))[1] != 0x55, "test92 case 28 failed\n");
FAILED(buf[21] != 192, "test92 case 29 failed\n");
FAILED(buf[22] != -5893, "test92 case 30 failed\n");
FAILED(buf[23] != 4059, "test92 case 31 failed\n");
FAILED(buf[24] != 6359, "test92 case 32 failed\n");
FAILED(buf[25] != (sljit_sw)(buf + 23), "test92 case 33 failed\n");
FAILED(((sljit_u8*)(buf + 26))[0] != 0x55, "test92 case 34 failed\n");
FAILED(((sljit_u8*)(buf + 26))[1] != 204, "test92 case 35 failed\n");
FAILED(((sljit_u8*)(buf + 26))[2] != 0x55, "test92 case 36 failed\n");
FAILED(buf[27] != 105, "test92 case 37 failed\n");
FAILED(((sljit_u8*)(buf + 28))[1] != 0x55, "test92 case 38 failed\n");
FAILED(((sljit_u8*)(buf + 28))[2] != 240, "test92 case 39 failed\n");
FAILED(((sljit_u8*)(buf + 28))[3] != 0x55, "test92 case 40 failed\n");
FAILED(buf[29] != 13, "test92 case 41 failed\n");
FAILED(buf[30] != 0, "test92 case 42 failed\n");
FAILED(((sljit_u16*)(buf + 31))[0] != 0x5555, "test92 case 43 failed\n");
FAILED(((sljit_u16*)(buf + 31))[1] != 51403, "test92 case 44 failed\n");
FAILED(buf[32] != 14876, "test92 case 45 failed\n");
FAILED(buf[33] != 1, "test92 case 46 failed\n");
FAILED(*(sljit_s32*)(buf + 19) != -573621, "test92 case 26 failed\n");
#if (defined SLJIT_64BIT_ARCHITECTURE && SLJIT_64BIT_ARCHITECTURE)
FAILED(((sljit_s32*)(buf + 19))[1] != 0x55555555, "test92 case 27 failed\n");
#endif /* SLJIT_64BIT_ARCHITECTURE */
FAILED(*(sljit_u8*)(buf + 20) != 240, "test92 case 28 failed\n");
FAILED(((sljit_u8*)(buf + 20))[1] != 0x55, "test92 case 29 failed\n");
FAILED(buf[21] != 192, "test92 case 30 failed\n");
FAILED(buf[22] != -5893, "test92 case 31 failed\n");
FAILED(buf[23] != 4059, "test92 case 32 failed\n");
FAILED(buf[24] != 6359, "test92 case 33 failed\n");
FAILED(buf[25] != (sljit_sw)(buf + 23), "test92 case 34 failed\n");
FAILED(((sljit_u8*)(buf + 26))[0] != 0x55, "test92 case 35 failed\n");
FAILED(((sljit_u8*)(buf + 26))[1] != 204, "test92 case 36 failed\n");
FAILED(((sljit_u8*)(buf + 26))[2] != 0x55, "test92 case 37 failed\n");
FAILED(buf[27] != 105, "test92 case 38 failed\n");
FAILED(((sljit_u8*)(buf + 28))[1] != 0x55, "test92 case 39 failed\n");
FAILED(((sljit_u8*)(buf + 28))[2] != 240, "test92 case 40 failed\n");
FAILED(((sljit_u8*)(buf + 28))[3] != 0x55, "test92 case 41 failed\n");
FAILED(buf[29] != 13, "test92 case 42 failed\n");
FAILED(buf[30] != 0, "test92 case 43 failed\n");
FAILED(((sljit_u16*)(buf + 31))[0] != 0x5555, "test92 case 44 failed\n");
FAILED(((sljit_u16*)(buf + 31))[1] != 51403, "test92 case 45 failed\n");
FAILED(buf[32] != 14876, "test92 case 46 failed\n");
FAILED(buf[33] != 1, "test92 case 47 failed\n");
FAILED(buf[34] != buf[32], "test92 case 48 failed\n");

sljit_free_code(code.code, NULL);
#endif
Expand Down

0 comments on commit c8d1db9

Please sign in to comment.