Skip to content
This repository has been archived by the owner on Jun 11, 2024. It is now read-only.

Added LSK check on beforeCrossChainMessageForwarding #8806

Merged
merged 13 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions framework/src/modules/fee/cc_method.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ export class FeeInteroperableMethod extends BaseCCMethod {
}

public async beforeCrossChainCommandExecute(ctx: CrossChainMessageContext): Promise<void> {
const messageTokenID = await this._interopMethod.getMessageFeeTokenID(
ctx,
ctx.ccm.sendingChainID,
);
const messageTokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(ctx, ctx.ccm);
await this._tokenMethod.lock(
ctx.getMethodContext(),
ctx.transaction.senderAddress,
Expand All @@ -59,10 +56,7 @@ export class FeeInteroperableMethod extends BaseCCMethod {
}

public async afterCrossChainCommandExecute(ctx: CrossChainMessageContext): Promise<void> {
const messageTokenID = await this._interopMethod.getMessageFeeTokenID(
ctx,
ctx.ccm.sendingChainID,
);
const messageTokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(ctx, ctx.ccm);
await this._tokenMethod.unlock(
ctx.getMethodContext(),
ctx.transaction.senderAddress,
Expand Down
2 changes: 2 additions & 0 deletions framework/src/modules/fee/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import { MethodContext, ImmutableMethodContext } from '../../state_machine/types';
import { JSONObject } from '../../types';
import { CCMsg } from '../interoperability';

export type FeeTokenID = Buffer;

Expand Down Expand Up @@ -76,4 +77,5 @@ export interface GetMinFeePerByteResponse {

export interface InteroperabilityMethod {
getMessageFeeTokenID(methodContext: ImmutableMethodContext, chainID: Buffer): Promise<Buffer>;
getMessageFeeTokenIDFromCCM(methodContext: ImmutableMethodContext, ccm: CCMsg): Promise<Buffer>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ export abstract class BaseInteroperabilityMethod<
return channel.messageFeeTokenID;
}

public async getMessageFeeTokenIDFromCCM(
context: ImmutableMethodContext,
ccm: CCMsg,
): Promise<Buffer> {
return this.getMessageFeeTokenID(context, ccm.sendingChainID);
}

// https://github.com/LiskHQ/lips/blob/main/proposals/lip-0045.md#getminreturnfeeperbyte
public async getMinReturnFeePerByte(
context: ImmutableMethodContext,
Expand Down
20 changes: 9 additions & 11 deletions framework/src/modules/token/cc_method.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import { RecoverEvent } from './events/recover';
import { EMPTY_BYTES } from '../interoperability/constants';
import { BeforeCCMForwardingEvent } from './events/before_ccm_forwarding';
import { splitTokenID } from './utils';
import { getEncodedCCMAndID } from '../interoperability/utils';
import { getEncodedCCMAndID, getMainchainTokenID } from '../interoperability/utils';
import { InternalMethod } from './internal_method';

export class TokenInteroperableMethod extends BaseCCMethod {
Expand All @@ -45,10 +45,7 @@ export class TokenInteroperableMethod extends BaseCCMethod {
ccm,
} = ctx;
const methodContext = ctx.getMethodContext();
const tokenID = await this._interopMethod.getMessageFeeTokenID(
methodContext,
ccm.sendingChainID,
);
const tokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(methodContext, ccm);
const { ccmID } = getEncodedCCMAndID(ccm);
const [chainID] = splitTokenID(tokenID);
const userStore = this.stores.get(UserStore);
Expand Down Expand Up @@ -88,12 +85,16 @@ export class TokenInteroperableMethod extends BaseCCMethod {
public async beforeCrossChainMessageForwarding(ctx: CrossChainMessageContext): Promise<void> {
const { ccm } = ctx;
const methodContext = ctx.getMethodContext();
const messageFeeTokenID = await this._interopMethod.getMessageFeeTokenID(
const messageFeeTokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(
methodContext,
ccm.receivingChainID,
ccm,
);
const { ccmID } = getEncodedCCMAndID(ccm);

if (!messageFeeTokenID.equals(getMainchainTokenID(ctx.chainID))) {
shuse2 marked this conversation as resolved.
Show resolved Hide resolved
throw new Error('Message fee token should be LSK.');
}

const escrowStore = this.stores.get(EscrowStore);
const escrowKey = escrowStore.getKey(ccm.sendingChainID, messageFeeTokenID);
const escrowAccount = await escrowStore.getOrDefault(methodContext, escrowKey);
Expand Down Expand Up @@ -128,10 +129,7 @@ export class TokenInteroperableMethod extends BaseCCMethod {
public async verifyCrossChainMessage(ctx: CrossChainMessageContext): Promise<void> {
const { ccm } = ctx;
const methodContext = ctx.getMethodContext();
const tokenID = await this._interopMethod.getMessageFeeTokenID(
methodContext,
ccm.sendingChainID,
);
const tokenID = await this._interopMethod.getMessageFeeTokenIDFromCCM(methodContext, ccm);
const [chainID] = splitTokenID(tokenID);
if (chainID.equals(ctx.chainID)) {
const escrowStore = this.stores.get(EscrowStore);
Expand Down
1 change: 1 addition & 0 deletions framework/src/modules/token/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export interface InteroperabilityMethod {
terminateChain(methodContext: MethodContext, chainID: Buffer): Promise<void>;
getChannel(methodContext: MethodContext, chainID: Buffer): Promise<ChannelData>;
getMessageFeeTokenID(methodContext: ImmutableMethodContext, chainID: Buffer): Promise<Buffer>;
getMessageFeeTokenIDFromCCM(methodContext: ImmutableMethodContext, ccm: CCMsg): Promise<Buffer>;
}

export interface FeeMethod {
Expand Down
1 change: 1 addition & 0 deletions framework/test/unit/modules/fee/cc_method.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ describe('FeeInteroperableMethod', () => {
feeMethod.addDependencies(
{
getMessageFeeTokenID: jest.fn().mockResolvedValue(messageFeeTokenID),
getMessageFeeTokenIDFromCCM: jest.fn().mockResolvedValue(messageFeeTokenID),
},
{
burn: jest.fn(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ describe('CrossChain Transfer Command', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
};
const config = {
ownChainID,
Expand Down
67 changes: 52 additions & 15 deletions framework/test/unit/modules/token/cc_method.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ describe('TokenInteroperableMethod', () => {
'hex',
);
const defaultAddress = address.getAddressFromPublicKey(defaultPublicKey);
const ownChainID = Buffer.from([0, 0, 0, 1]);
const ownChainID = Buffer.from([1, 0, 0, 0]);
const defaultTokenID = Buffer.concat([ownChainID, Buffer.alloc(4)]);
const defaultForeignTokenID = Buffer.from([0, 0, 0, 2, 0, 0, 0, 0]);
const defaultForeignTokenID = Buffer.from([2, 0, 0, 0, 0, 0, 0, 0]);
const defaultAccount = {
availableBalance: BigInt(10000000000),
lockedBalances: [
Expand Down Expand Up @@ -119,6 +119,7 @@ describe('TokenInteroperableMethod', () => {
{
send: jest.fn().mockResolvedValue(true),
getMessageFeeTokenID: jest.fn().mockResolvedValue(defaultTokenID),
getMessageFeeTokenIDFromCCM: jest.fn().mockResolvedValue(defaultTokenID),
} as never,
internalMethod,
);
Expand Down Expand Up @@ -153,7 +154,7 @@ describe('TokenInteroperableMethod', () => {
describe('beforeCrossChainCommandExecute', () => {
it('should credit fee to transaction sender if token id is not native', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenID')
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.beforeCrossChainCommandExecute({
Expand All @@ -162,7 +163,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -205,7 +206,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -243,7 +244,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -285,6 +286,42 @@ describe('TokenInteroperableMethod', () => {
});

describe('beforeCrossChainMessageForwarding', () => {
it('should throw if messageFeeTokenID is not LSK', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.beforeCrossChainMessageForwarding({
ccm: {
crossChainCommand: CROSS_CHAIN_COMMAND_NAME_TRANSFER,
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
},
getMethodContext: () => methodContext,
eventQueue: new EventQueue(0),
getStore: (moduleID: Buffer, prefix: Buffer) => stateStore.getStore(moduleID, prefix),
logger: fakeLogger,
chainID: ownChainID,
header: {
timestamp: Date.now(),
height: 10,
},
stateStore,
contextStore,
transaction: {
fee,
senderAddress: defaultAddress,
params: defaultEncodedCCUParams,
},
}),
).rejects.toThrow('Message fee token should be LSK.');
});

it('should throw if escrow balance is not sufficient', async () => {
await expect(
tokenInteropMethod.beforeCrossChainMessageForwarding({
Expand All @@ -293,7 +330,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -323,15 +360,15 @@ describe('TokenInteroperableMethod', () => {
);
});

it('should deduct escrow account for fee and credit to receving chain escrow account if ccm command is not transfer', async () => {
it('should deduct escrow account for fee and credit to receiving chain escrow account if ccm command is not transfer', async () => {
await expect(
tokenInteropMethod.beforeCrossChainMessageForwarding({
ccm: {
crossChainCommand: CROSS_CHAIN_COMMAND_REGISTRATION,
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: codec.encode(crossChainForwardMessageParams, {
Expand Down Expand Up @@ -370,7 +407,7 @@ describe('TokenInteroperableMethod', () => {
expect(amount).toEqual(defaultEscrowAmount - fee);
const { amount: receiver } = await escrowStore.get(
methodContext,
escrowStore.getKey(Buffer.from([0, 0, 0, 1]), defaultTokenID),
escrowStore.getKey(ownChainID, defaultTokenID),
);
expect(receiver).toEqual(fee);
});
Expand All @@ -385,7 +422,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -418,7 +455,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee: fee + defaultEscrowAmount,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand All @@ -445,7 +482,7 @@ describe('TokenInteroperableMethod', () => {

it('should resolve if token id is not native', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenID')
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.verifyCrossChainMessage({
Expand All @@ -454,7 +491,7 @@ describe('TokenInteroperableMethod', () => {
module: tokenModule.name,
nonce: BigInt(1),
sendingChainID,
receivingChainID: Buffer.from([0, 0, 0, 1]),
receivingChainID: ownChainID,
fee,
status: CCM_STATUS_OK,
params: utils.getRandomBytes(30),
Expand Down Expand Up @@ -549,7 +586,7 @@ describe('TokenInteroperableMethod', () => {

it('should reject if token is not native', async () => {
jest
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenID')
.spyOn(tokenInteropMethod['_interopMethod'], 'getMessageFeeTokenIDFromCCM')
.mockResolvedValue(defaultForeignTokenID);
await expect(
tokenInteropMethod.recover({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ describe('Transfer command', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
};
internalMethod.addDependencies({ payFee: jest.fn() });
method.addDependencies(interopMethod, internalMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ describe('CCTransfer command', () => {
terminateChain: jest.Mock;
getChannel: jest.Mock;
getMessageFeeTokenID: jest.Mock;
getMessageFeeTokenIDFromCCM: jest.Mock;
};

beforeEach(() => {
Expand All @@ -129,6 +130,7 @@ describe('CCTransfer command', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
};
internalMethod.addDependencies({
payFee: jest.fn(),
Expand Down
1 change: 1 addition & 0 deletions framework/test/unit/modules/token/endpoint.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ describe('token endpoint', () => {
terminateChain: jest.fn(),
getChannel: jest.fn(),
getMessageFeeTokenID: jest.fn(),
getMessageFeeTokenIDFromCCM: jest.fn(),
},
internalMethod,
);
Expand Down