Skip to content

Commit

Permalink
remove cluster.rpc and cluster.msg
Browse files Browse the repository at this point in the history
add cluster.lua for more flexible rpc
  • Loading branch information
findstr committed Nov 18, 2024
1 parent 92384d7 commit 204f2d8
Show file tree
Hide file tree
Showing 14 changed files with 547 additions and 776 deletions.
69 changes: 49 additions & 20 deletions examples/rpc.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
local core = require "core"
local crypto = require "core.crypto"
local rpc = require "cluster.rpc"
local cluster = require "core.cluster"
local zproto = require "zproto"

local proto = zproto:parse [[
Expand All @@ -12,41 +12,70 @@ pong 0x2 {
}
]]

assert(proto)
local function unmarshal(cmd, buf, size)
local dat, size = proto:unpack(buf, size, true)
local body = proto:decode(cmd, dat, size)
return body
end

local server = rpc.listen {
addr = "127.0.0.1:9999",
proto = proto,
local function marshal(cmd, body)
if type(cmd) == "string" then
cmd = proto:tag(cmd)
end
print("marshal", cmd, body)
local dat, size = proto:encode(cmd, body, true)
local buf, size = proto:pack(dat, size, true)
return cmd, buf, size
end

local callret = {
["ping"] = "pong",
[0x01] = "pong",
}

local server = cluster.new {
marshal = marshal,
unmarshal = unmarshal,
callret = callret,
accept = function(fd, addr)
print("accept", fd, addr)
end,

call = function(msg, cmd, fd)
print("callee", msg.txt, fd)
return msg
end,
close = function(fd, errno)
print("close", fd, errno)
end,
}

server.listen("127.0.0.1:9999")

local client = cluster.new {
marshal = marshal,
unmarshal = unmarshal,
callret = callret,
call = function(msg, cmd, fd)
print("callee", msg.txt, cmd, fd)
return "pong", msg
end
print("callee", msg.txt, fd)
return msg
end,
close = function(fd, errno)
print("close", fd, errno)
end,
}


core.start(function()
for i = 1, 3 do
core.fork(function()
local conn = rpc.connect {
addr = "127.0.0.1:9999",
proto = proto,
timeout = 5000,
close = function(fd, errno)
end,
}
while true do
local fd, err = client.connect("127.0.0.1:9999")
print("connect", fd, err)
for j = 1, 10000 do
local txt = crypto.randomkey(5)
local ack, cmd = conn:call("ping", {txt = txt})
print("caller", conn, txt, ack.txt)
local ack, cmd = client.ping(fd, {txt = txt})
print("caller", fd, cmd, txt, ack.txt)
assert(ack.txt == txt)
assert(cmd == proto:tag("pong"))
assert(cmd == "pong")
core.sleep(1000)
end
end)
Expand Down
12 changes: 6 additions & 6 deletions lualib-src/lualib-core.c
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,16 @@ lsendsize(lua_State *L)
static int
ltracespan(lua_State *L)
{
silly_trace_span_t span;
span = (silly_trace_span_t)luaL_checkinteger(L, 1);
silly_tracespan_t span;
span = (silly_tracespan_t)luaL_checkinteger(L, 1);
silly_trace_span(span);
return 0;
}

static int
ltracenew(lua_State *L)
{
silly_trace_id_t traceid;
silly_traceid_t traceid;
traceid = silly_trace_new();
lua_pushinteger(L, (lua_Integer)traceid);
return 1;
Expand All @@ -543,13 +543,13 @@ ltracenew(lua_State *L)
static int
ltraceset(lua_State *L)
{
silly_trace_id_t traceid;
silly_traceid_t traceid;
lua_State *co = lua_tothread(L, 1);
silly_worker_resume(co);
if lua_isnoneornil(L, 2) {
traceid = TRACE_WORKER_ID;
} else {
traceid = (silly_trace_id_t)luaL_checkinteger(L, 2);
traceid = (silly_traceid_t)luaL_checkinteger(L, 2);
}
traceid = silly_trace_set(traceid);
lua_pushinteger(L, (lua_Integer)traceid);
Expand All @@ -559,7 +559,7 @@ ltraceset(lua_State *L)
static int
ltraceget(lua_State *L)
{
silly_trace_id_t traceid;
silly_traceid_t traceid;
traceid = silly_trace_get();
lua_pushinteger(L, (lua_Integer)traceid);
return 1;
Expand Down
172 changes: 69 additions & 103 deletions lualib-src/lualib-netpacket.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "silly_trace.h"
#include "silly_malloc.h"

#define ACK_BIT (1UL << 31)
#define DEFAULT_QUEUE_SIZE 2048
#define HASH_SIZE 2048
#define HASH(a) (a % HASH_SIZE)
Expand Down Expand Up @@ -43,6 +44,8 @@ struct netpacket {
struct incomplete *hash[HASH_SIZE];
};

static session_t session_idx = 0;

static int
lcreate(lua_State *L)
{
Expand Down Expand Up @@ -225,14 +228,11 @@ clear_incomplete(lua_State *L, int sid)
}

static inline const char *
getbuffer(lua_State *L, int *stk, size_t *sz)
getbuffer(lua_State *L, int n, size_t *sz)
{
int n = *stk;
if (lua_type(L, n) == LUA_TSTRING) {
*stk = n + 1;
return lua_tolstring(L, n, sz);
} else {
*stk = n + 2;
*sz = luaL_checkinteger(L, n + 1);
return lua_touserdata(L, n);
}
Expand All @@ -255,143 +255,112 @@ pop_packet(lua_State *L)
}
}

static int
lpop(lua_State *L)
{
struct packet *pk = pop_packet(L);
if (pk == NULL)
return 0;
lua_pushinteger(L, pk->fd);
lua_pushlightuserdata(L, pk->buff);
lua_pushinteger(L, pk->size);
return 3;
}
//rpc_cookie {traceid(uint64),cmd(uint32),session(uint32)}

static int
lpack(lua_State *L)
{
uint8_t *p;
int stk = 1;
size_t size;
const char *str;
str = getbuffer(L, &stk, &size);
if (size > USHRT_MAX)
luaL_error(L, "netpacket.pack data large then:%d\n", USHRT_MAX);
p = silly_malloc(2 + size);
p[0] = (size >> 8) & 0xff;
p[1] = size & 0xff;
memcpy(p + 2, str, size);
lua_pushlightuserdata(L, p);
lua_pushinteger(L, 2 + size);
return 2;
}
#define req_cookie_size (sizeof(silly_traceid_t)+sizeof(cmd_t)+sizeof(session_t))
#define req_traceid_ref(ptr) (*(silly_traceid_t*)(ptr))
#define req_cmd_ref(ptr) (*(cmd_t *)(ptr+sizeof(silly_traceid_t)))
#define req_session_ref(ptr) (*(session_t*)(ptr+sizeof(silly_traceid_t)+sizeof(cmd_t)))

#define ack_cookie_size (sizeof(session_t))
#define ack_session_ref(ptr) (*(session_t*)(ptr))

static int
lmsgpop(lua_State *L)
lpop(lua_State *L)
{
int size;
char *buf;
cmd_t cmd;
session_t session;
struct packet *pk = pop_packet(L);
if (pk == NULL)
return 0;
size = pk->size - sizeof(cmd_t);
size = pk->size - ack_cookie_size;
buf = pk->buff;
if (size < 0)
return 0;
//WARN: pointer cast may not align, can't cross platform
cmd = *(cmd_t *)(buf + size);
lua_pushinteger(L, pk->fd);
lua_pushlightuserdata(L, buf);
lua_pushinteger(L, size);
lua_pushinteger(L, cmd);
return 4;
session = ack_session_ref(buf+size);
if ((session & ACK_BIT) == ACK_BIT) { //rpc ack
lua_pushinteger(L, pk->fd);
lua_pushlightuserdata(L, buf);
lua_pushinteger(L, size);
lua_pushinteger(L, (lua_Integer)(session & ~ACK_BIT));
lua_pushnil(L); //cmd
lua_pushinteger(L, 0); //traceid
} else {
void *cookie;
size = pk->size - req_cookie_size;
cookie = (void *)(buf + size);
lua_pushinteger(L, pk->fd);
lua_pushlightuserdata(L, buf);
lua_pushinteger(L, size);
lua_pushinteger(L, session);
lua_pushinteger(L, req_cmd_ref(cookie));
lua_pushinteger(L, (lua_Integer)req_traceid_ref(cookie));
}
return 6;
}

static int
lmsgpack(lua_State *L)
lrequest(lua_State *L)
{
cmd_t cmd;
uint8_t *p;
const char *str;
void *cookie;
size_t size, body;
int cmd, stk = 1;
str = getbuffer(L, &stk, &size);
if (size > (USHRT_MAX - sizeof(cmd_t))) {
session_t session;
silly_traceid_t traceid;
cmd = luaL_checkinteger(L, 1);
traceid = luaL_checkinteger(L, 2);
str = getbuffer(L, 3, &size);
if (size > (USHRT_MAX - req_cookie_size)) {
luaL_error(L, "netpacket.pack data large then:%d\n",
USHRT_MAX - sizeof(cmd_t));
USHRT_MAX - req_cookie_size);
}
session = session_idx++;
if (session >= ACK_BIT) {
session_idx = 0;
session = 0;
}
cmd = luaL_checkinteger(L, stk);
body = size + sizeof(cmd_t);
body = size + req_cookie_size;
p = silly_malloc(2 + body);
p[0] = (body >> 8) & 0xff;
p[1] = body & 0xff;
memcpy(p + 2, str, size);
//WARN: pointer cast may not align, can't cross platform
*(cmd_t *)&p[size+2] = cmd;
cookie = (void *)&p[2 + size];
req_cmd_ref(cookie) = cmd;
req_session_ref(cookie) = session;
req_traceid_ref(cookie) = traceid;
lua_pushinteger(L, session);
lua_pushlightuserdata(L, p);
lua_pushinteger(L, 2 + body);
return 2;
}

struct rpc_cookie {
cmd_t cmd;
session_t session;
silly_trace_id_t traceid;
};

static int
lrpcpop(lua_State *L)
{
int size;
char *buf;
struct rpc_cookie *rpc;
struct packet *pk = pop_packet(L);
if (pk == NULL)
return 0;
size = pk->size - sizeof(struct rpc_cookie);
buf = pk->buff;
if (size < 0)
return 0;
//WARN: pointer cast may not align, can't cross platform
rpc = (struct rpc_cookie *)(buf + size);
lua_pushinteger(L, pk->fd);
lua_pushlightuserdata(L, buf);
lua_pushinteger(L, size);
lua_pushinteger(L, rpc->cmd);
lua_pushinteger(L, rpc->session);
lua_pushinteger(L, (lua_Integer)rpc->traceid);
return 6;
return 3;
}

static int
lrpcpack(lua_State *L)
lresponse(lua_State *L)
{
cmd_t cmd;
uint8_t *p;
const char *str;
void *cookie;
size_t size, body;
struct rpc_cookie *rpc;
int stk = 1;
session_t session;
silly_trace_id_t traceid;
str = getbuffer(L, &stk, &size);
if (size > (USHRT_MAX - sizeof(struct rpc_cookie))) {
session = luaL_checkinteger(L, 1) | ACK_BIT;
str = getbuffer(L, 2, &size);
if (size > (USHRT_MAX - ack_cookie_size)) {
luaL_error(L, "netpacket.pack data large then:%d\n",
USHRT_MAX - sizeof(struct rpc_cookie));
USHRT_MAX - ack_cookie_size);
}
cmd = luaL_checkinteger(L, stk);
session = luaL_checkinteger(L, stk+1);
traceid = luaL_checkinteger(L, stk+2);
body = size + sizeof(struct rpc_cookie);
body = size + ack_cookie_size;
p = silly_malloc(2 + body);
p[0] = (body >> 8) & 0xff;
p[1] = body & 0xff;
memcpy(p + 2, str, size);
//WARN: pointer cast may not align, can't cross platform
rpc = (struct rpc_cookie *)&p[2 + size];
rpc->cmd = cmd;
rpc->session = session;
rpc->traceid = traceid;
cookie = (void *)&p[2 + size];
ack_session_ref(cookie) = session;
lua_pushlightuserdata(L, p);
lua_pushinteger(L, 2 + body);
return 2;
Expand Down Expand Up @@ -480,14 +449,11 @@ int luaopen_core_netpacket(lua_State *L)
luaL_Reg tbl[] = {
{"create", lcreate},
{"pop", lpop},
{"pack", lpack},
{"msgpop", lmsgpop},
{"msgpack", lmsgpack},
{"rpcpop", lrpcpop},
{"rpcpack", lrpcpack},
{"request", lrequest},
{"response", lresponse},
{"clear", lclear},
{"tostring", ltostring},
{"drop", ldrop},
{"tostring", ltostring},
{"message", lmessage},
{NULL, NULL},
};
Expand Down
Loading

0 comments on commit 204f2d8

Please sign in to comment.