Skip to content

Commit

Permalink
fix: websocket exception (#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
huaxiabuluo authored Jan 16, 2023
1 parent 5b4f345 commit a360c9f
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 91 deletions.
1 change: 1 addition & 0 deletions app/interfaces/tools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export type Recordable<T = unknown> = Record<string, T>;
198 changes: 115 additions & 83 deletions app/utils/websocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@ import { getRootStore } from '@app/stores';
import { message } from 'antd';
import { v4 as uuidv4 } from 'uuid';
import JSONBigint from 'json-bigint';
import { Recordable } from '@app/interfaces/tools';
import { safeParse } from './function';
import { HttpResCode } from './http';

export interface MessageReceive<T extends unknown = Record<string, unknown>> {
export interface MessageSend<T extends unknown = Recordable> {
header: {
msgId: string;
version: string;
};
body: {
product: string;
msgType: string;
content: T;
};
}

export interface MessageReceive<T extends unknown = Recordable> {
header: {
msgId: string;
sendTime: number;
Expand All @@ -25,6 +38,36 @@ export interface NgqlRes<T = any> {
export const WsHeartbeatReq = '1';
export const WsHeartbeatRes = '2';

type MessageReceiverProps<R = unknown> = {
resolve: (res: R) => void;
reject: () => void;
content: MessageSend['body']['content'];
product: string;
msgType: string;
config?: Recordable;
};

export class MessageReceiver<R extends NgqlRes = NgqlRes> {
resolve: (res: R) => void;
reject: () => void;
onError?: (e: Error) => void;
messageSend: MessageSend;
config?: Recordable;

sendTime = Date.now();

constructor(props: MessageReceiverProps<R>) {
const { resolve, reject, content, product, msgType, config = {} } = props;
this.resolve = resolve;
this.reject = reject;
this.config = config;
this.messageSend = {
header: { msgId: uuidv4(), version: '1.0' },
body: { product: product, msgType, content },
};
}
}

export class NgqlRunner {
socket: WebSocket | undefined = undefined;

Expand All @@ -34,6 +77,7 @@ export class NgqlRunner {
product = 'Studio';

socketMessageListeners: Array<(e: MessageEvent) => void> = [];
messageReceiverMap = new Map<string, MessageReceiver>();

socketConnectingPromise: Promise<boolean> | undefined;
socketPingTimeInterval: number | undefined;
Expand All @@ -51,16 +95,19 @@ export class NgqlRunner {
this.socketMessageListeners.push(listener);
};

clearSocketMessageListener = () => {
this.socketMessageListeners.forEach((l) => this.socket?.removeEventListener('message', l));
this.socketMessageListeners = [];
};

rmSocketMessageListener = (listener: (e: MessageEvent) => void) => {
this.socket?.removeEventListener('message', listener);
this.socketMessageListeners = this.socketMessageListeners.filter((l) => l !== listener);
};

clearSocketMessageListener = () => {
this.socketMessageListeners.forEach((l) => {
this.socket?.removeEventListener('message', l);
});
this.socketMessageListeners = [];
clearMessageReceiver = () => {
this.messageReceiverMap.forEach((receiver) => receiver.resolve({ code: -1, message: 'WebSocket closed' }));
this.messageReceiverMap.clear();
};

connect = (url: string | URL, protocols?: string | string[]) => {
Expand Down Expand Up @@ -95,9 +142,9 @@ export class NgqlRunner {
socket.onerror = undefined;
socket.onclose = undefined;

// reconnect
this.socket.addEventListener('close', this.onDisconnect);
this.socket.addEventListener('error', this.onError);
this.socket.addEventListener('message', this.onMessage);

resolve(true);
};
Expand All @@ -118,36 +165,60 @@ export class NgqlRunner {
return this.connect(this.socketUrl, this.socketProtocols);
};

onMessage = (e: MessageEvent<string>) => {
if (e.data === WsHeartbeatRes) {
return;
}

const msgReceive = safeParse<MessageReceive<NgqlRes>>(e.data, { paser: JSONBigint.parse });
if (msgReceive?.body?.content?.code === HttpResCode.ErrSession) {
getRootStore().global.logout();
return;
}

const messageReceiver = this.messageReceiverMap.get(msgReceive?.header?.msgId);
if (messageReceiver && msgReceive.header.msgId === messageReceiver.messageSend.header.msgId) {
const content = msgReceive.body.content;
if (messageReceiver.config.noTip !== true && content.code !== 0) {
message.error(content.message);
}
messageReceiver.resolve(content);
this.messageReceiverMap.delete(msgReceive.header.msgId);
}
};

onError = (e: Event) => {
console.error('=====ngqlSocket error', e);
message.error('WebSocket error, try to reconnect...');
this.onDisconnect();
};

onDisconnect = () => {
console.log('=====onDisconnect');
this.socket?.removeEventListener('close', this.onDisconnect);
this.socket?.removeEventListener('error', this.onError);
this.clearSocketMessageListener();
this.socket?.close();

this.stopSocketPing();
this.socket = undefined;
this.clearMessageReceiver();
this.closeSocket();

// try reconnect
this.socketUrl && setTimeout(this.reConnect, 1000);
};

stopSocketPing = () => {
closeSocket = () => {
this.socket?.close();
this.socket = undefined;

clearTimeout(this.socketPingTimeInterval);
this.socketPingTimeInterval = undefined;
};

desctory = () => {
this.stopSocketPing();
this.clearSocketMessageListener();
this.socket?.close();
this.socket = undefined;
// disable reconnect
this.socketUrl = undefined;
this.socketProtocols = undefined;

this.onDisconnect();
};

ping = () => {
Expand All @@ -156,89 +227,50 @@ export class NgqlRunner {

runNgql = async (
{ gql, paramList, space }: { gql: string; paramList?: string[]; space?: string },
config: Record<string, unknown> = {},
config: Recordable = {},
): Promise<NgqlRes> => {
const reqMsg = {
header: {
msgId: uuidv4(),
version: '1.0',
},
body: {
product: this.product,
msgType: 'ngql',
content: { gql, paramList, space },
},
};

if (!this.socket || this.socket.readyState !== WebSocket.OPEN) {
await this.reConnect();
const flag = await this.reConnect();
if (!flag) {
return Promise.resolve({ code: -1, message: 'WebSocket reconnect failed' });
}
}

return new Promise((resolve) => {
const receiveMsg = (e: MessageEvent<string>) => {
if (e.data === WsHeartbeatRes) {
return;
}
const msgReceive = safeParse<MessageReceive<NgqlRes>>(e.data, { paser: JSONBigint.parse });
if (msgReceive?.body?.content?.code === HttpResCode.ErrSession) {
getRootStore().global.logout();
return;
}
if (msgReceive?.header?.msgId === reqMsg.header.msgId) {
const content = msgReceive.body.content;
if (config.hideErrMsg !== false && content.code !== 0) {
message.error(content.message);
}
resolve(msgReceive.body.content);
this.rmSocketMessageListener(receiveMsg);
}
};
return new Promise((resolve, reject) => {
const messageReceiver = new MessageReceiver({
resolve,
reject,
product: this.product,
content: { gql, paramList, space },
config,
msgType: 'ngql',
});

this.socket?.send(JSON.stringify(reqMsg));
this.addSocketMessageListener(receiveMsg);
this.socket.send(JSON.stringify(messageReceiver.messageSend));
this.messageReceiverMap.set(messageReceiver.messageSend.header.msgId, messageReceiver);
});
};

runBatchNgql = async (
{ gqls, paramList, space }: { gqls: string[]; paramList?: string[]; space?: string },
_config: Record<string, unknown> = {},
config: Recordable = {},
): Promise<NgqlRes> => {
const message = {
header: {
msgId: uuidv4(),
version: '1.0',
},
body: {
product: this.product,
msgType: 'batch_ngql',
content: { gqls, paramList, space },
},
};

if (!this.socket || this.socket.readyState !== WebSocket.OPEN) {
await this.reConnect();
}

return new Promise((resolve) => {
const receiveMsg = (e: MessageEvent<string>) => {
if (e.data === WsHeartbeatRes) {
return;
}
const msgReceive = safeParse<MessageReceive<NgqlRes>>(e.data, { paser: JSONBigint.parse });
if (msgReceive?.body?.content?.code === HttpResCode.ErrSession) {
this.desctory();
getRootStore().global.logout();
return;
}
if (msgReceive?.header?.msgId === message.header.msgId) {
resolve(msgReceive.body.content);
this.rmSocketMessageListener(receiveMsg);
}
};
receiveMsg.sendTime = Date.now();
return new Promise((resolve, reject) => {
const messageReceiver = new MessageReceiver({
resolve,
reject,
product: this.product,
content: { gqls, paramList, space },
config,
msgType: 'batch_ngql',
});

this.socket?.send(JSON.stringify(message));
this.addSocketMessageListener(receiveMsg);
this.socket.send(JSON.stringify(messageReceiver.messageSend));
this.messageReceiverMap.set(messageReceiver.messageSend.header.msgId, messageReceiver);
});
};
}
Expand Down
31 changes: 23 additions & 8 deletions server/api/studio/pkg/ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (c *Client) switchSpace(msgReceived *MessageReceive) *map[string]any {
reqSpace = strings.Replace(reqSpace, "`", "\\`", -1)
_, _, err := dao.Execute(c.clientInfo.NSID, fmt.Sprintf("USE `%s`", reqSpace), nil)
if err != nil {
logx.ErrorStackf("[WebSocket ngql query]: msgReceived.Body.Content(%v); error(%v)", &msgReceived.Body.Content, err)
logx.Errorf("[WebSocket switchSpace]: msgReceived.Body.Content(%+v); error(%+v)", &msgReceived.Body.Content, err)
content := map[string]any{
"code": base.Error,
"message": err.Error(),
Expand All @@ -82,7 +82,14 @@ func (c *Client) switchSpace(msgReceived *MessageReceive) *map[string]any {
return nil
}

func (c *Client) runNgql(msgReceived *MessageReceive) {
func (c *Client) runNgql(msgReceived *MessageReceive) (closed bool) {
defer func() {
if err := recover(); err != nil {
logx.Errorf("[WebSocket runNgql panic]: %+v", err)
closed = true
}
}()

msgPost := MessagePost{
Header: MessagePostHeader{
MsgId: msgReceived.Header.MsgId,
Expand Down Expand Up @@ -117,7 +124,7 @@ func (c *Client) runNgql(msgReceived *MessageReceive) {

execute, _, err := dao.Execute(c.clientInfo.NSID, gql, paramList)
if err != nil {
logx.ErrorStackf("[WebSocket ngql query]: msgReceived.Body.Content(%v); error(%v)", &msgReceived.Body.Content, err)
logx.Errorf("[WebSocket runNgql]: msgReceived.Body.Content(%v); error(%v)", &msgReceived.Body.Content, err)
content := map[string]any{
"code": base.Error,
"message": err.Error(),
Expand All @@ -133,12 +140,19 @@ func (c *Client) runNgql(msgReceived *MessageReceive) {
"message": "Success",
}
}

msgSend, _ := json.Marshal(msgPost)
c.send <- msgSend
return false
}

func (c *Client) runBatchNgql(msgReceived *MessageReceive) {
func (c *Client) runBatchNgql(msgReceived *MessageReceive) (closed bool) {
defer func() {
if err := recover(); err != nil {
logx.Errorf("[WebSocket runBatchNgql panic]: %+v", err)
closed = true
}
}()

msgPost := MessagePost{
Header: MessagePostHeader{
MsgId: msgReceived.Header.MsgId,
Expand Down Expand Up @@ -212,6 +226,7 @@ func (c *Client) runBatchNgql(msgReceived *MessageReceive) {
}
msgSend, _ := json.Marshal(msgPost)
c.send <- msgSend
return false
}

// readPump pumps messages from the websocket connection to the hub.
Expand All @@ -231,9 +246,9 @@ func (c *Client) readPump() {
_, message, err := c.conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
logx.ErrorStackf("[WebSocket UnexpectedClose]: %v", err)
logx.Errorf("[WebSocket UnexpectedClose]: %v", err)
} else {
logx.ErrorStackf("[WebSocket ReadMessage]: %v", err)
logx.Errorf("[WebSocket ReadMessage]: %v", err)
}
break
}
Expand Down Expand Up @@ -283,7 +298,7 @@ func (c *Client) writePump() {

w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
logx.ErrorStackf("[WebSocket WriteMessage]: %v", err)
logx.Errorf("[WebSocket WriteMessage]: %v", err)
return
}
w.Write(message)
Expand Down

0 comments on commit a360c9f

Please sign in to comment.