Skip to content

Commit

Permalink
mysql数据库禁写模块UT
Browse files Browse the repository at this point in the history
Signed-off-by: daizhenyu <[email protected]>
  • Loading branch information
daizhenyu committed Feb 20, 2024
1 parent 05df4e0 commit 128c164
Show file tree
Hide file tree
Showing 8 changed files with 647 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (C) 2024-2024 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.huaweicloud.sermant.database.utils;

import org.junit.Assert;
import org.junit.Test;

/**
* SqlParserUtils单元测试
*
* @author daizhenyu
* @since 2024-02-06
**/
public class SqlParserUtilsTest {
private String sql;

@Test
public void testIsWriteOperation() {
// sql为写操作
sql = "INSERT INTO table (name) VALUES ('test')";
Assert.assertTrue(SqlParserUtils.isWriteOperation(sql));

sql = "CREATE TABLE table (name VARCHAR(255))";
Assert.assertTrue(SqlParserUtils.isWriteOperation(sql));

sql = "DROP INDEX idx_name on table";
Assert.assertTrue(SqlParserUtils.isWriteOperation(sql));

// sql为读操作
sql = "SELECT * FROM table";
Assert.assertFalse(SqlParserUtils.isWriteOperation(sql));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@
<artifactId>database-controller</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright (C) 2024-2024 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.huaweicloud.sermant.mariadbv2.interceptors;

import com.huaweicloud.sermant.core.plugin.agent.entity.ExecuteContext;
import com.huaweicloud.sermant.database.config.DatabaseWriteProhibitionConfig;
import com.huaweicloud.sermant.database.config.DatabaseWriteProhibitionManager;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.internal.protocol.MasterProtocol;
import org.mariadb.jdbc.internal.protocol.Protocol;
import org.mockito.Mockito;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* executeBatchStmt方法拦截器单元测试
*
* @author daizhenyu
* @since 2024-02-06
**/
public class ExecuteBatchStmtInterceptorTest {
private static DatabaseWriteProhibitionConfig globalConfig = new DatabaseWriteProhibitionConfig();

private static ExecuteContext context;

private static Method methodMock;

private static Protocol protocolMock;

private static Object[] argument;

private static List<String> sqlList;

private ExecuteBatchStmtInterceptor interceptor = new ExecuteBatchStmtInterceptor();

@BeforeClass
public static void setUp() {
DatabaseWriteProhibitionManager.updateGlobalConfig(globalConfig);
protocolMock = Mockito.mock(MasterProtocol.class);
methodMock = Mockito.mock(Method.class);
HostAddress serverAddress = new HostAddress("127.0.0.1", 8080);
Mockito.when(protocolMock.getHostAddress()).thenReturn(serverAddress);
Mockito.when(protocolMock.getDatabase()).thenReturn("database-test");
sqlList = new ArrayList<>();
sqlList.add("INSERT INTO table (name) VALUES ('test')");
argument = new Object[]{null, null, sqlList};
}

@AfterClass
public static void tearDown() {
Mockito.clearAllCaches();
DatabaseWriteProhibitionManager.updateGlobalConfig(null);
}

@Test
public void testDoBefore() throws Exception {
// 数据库禁写开关关闭
globalConfig.setEnableMySqlWriteProhibition(false);
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

// 数据库禁写开关关闭,禁写数据库set包含被拦截的数据库
Set<String> databases = new HashSet<>();
databases.add("database-test");
globalConfig.setMySqlDatabases(databases);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

//数据库禁写开关打开,禁写数据库集合包含被拦截的数据库
globalConfig.setEnableMySqlWriteProhibition(true);
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertEquals("Database prohibit to write, database: database-test",
context.getThrowableOut().getMessage());

//数据库禁写开关打开,sql没有写操作,禁写数据库集合包含被拦截的数据库
sqlList = new ArrayList<>();
argument = new Object[]{null, null, sqlList};
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

//数据库禁写开关打开,禁写数据库集合不包含被拦截的数据库
sqlList.add("INSERT INTO table (name) VALUES ('test')");
globalConfig.setMySqlDatabases(new HashSet<>());
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (C) 2024-2024 Huawei Technologies Co., Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.huaweicloud.sermant.mariadbv2.interceptors;

import com.huaweicloud.sermant.core.plugin.agent.entity.ExecuteContext;
import com.huaweicloud.sermant.database.config.DatabaseWriteProhibitionConfig;
import com.huaweicloud.sermant.database.config.DatabaseWriteProhibitionManager;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mariadb.jdbc.HostAddress;
import org.mariadb.jdbc.internal.protocol.MasterProtocol;
import org.mariadb.jdbc.internal.protocol.Protocol;
import org.mariadb.jdbc.internal.util.dao.ClientPrepareResult;
import org.mockito.Mockito;

import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Set;

/**
* execute方法拦截器单元测试
*
* @author daizhenyu
* @since 2024-02-06
**/
public class ExecuteInterceptorTest {
private static final int PARAM_INDEX = 2;

private static DatabaseWriteProhibitionConfig globalConfig = new DatabaseWriteProhibitionConfig();

private static ExecuteContext context;

private static Method methodMock;

private static Protocol protocolMock;

private static Object[] argument;

private static String sql;

private static ClientPrepareResult resultMock;

private ExecuteInterceptor interceptor = new ExecuteInterceptor();

@BeforeClass
public static void setUp() {
DatabaseWriteProhibitionManager.updateGlobalConfig(globalConfig);
sql = "INSERT INTO table (name) VALUES ('test')";
protocolMock = Mockito.mock(MasterProtocol.class);
methodMock = Mockito.mock(Method.class);
resultMock = Mockito.mock(ClientPrepareResult.class);
HostAddress serverAddress = new HostAddress("127.0.0.1", 8080);
Mockito.when(protocolMock.getHostAddress()).thenReturn(serverAddress);
Mockito.when(protocolMock.getDatabase()).thenReturn("database-test");
Mockito.when(resultMock.getSql()).thenReturn(sql);
argument = new Object[]{null, null, "INSERT INTO table (name) VALUES ('test')"};
}

@AfterClass
public static void tearDown() {
Mockito.clearAllCaches();
DatabaseWriteProhibitionManager.updateGlobalConfig(null);
}

@Test
public void testDoBefore() throws Exception {
// 数据库禁写开关关闭
globalConfig.setEnableMySqlWriteProhibition(false);
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

// 数据库禁写开关关闭,禁写数据库set包含被拦截的数据库
Set<String> databases = new HashSet<>();
databases.add("database-test");
globalConfig.setMySqlDatabases(databases);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

//数据库禁写开关打开,禁写数据库集合包含被拦截的数据库, 方法入参为String
globalConfig.setEnableMySqlWriteProhibition(true);
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertEquals("Database prohibit to write, database: database-test",
context.getThrowableOut().getMessage());

//数据库禁写开关打开,禁写数据库集合包含被拦截的数据库,方法入参为ClientPrepareResult
argument[PARAM_INDEX] = resultMock;
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertEquals("Database prohibit to write, database: database-test",
context.getThrowableOut().getMessage());

//数据库禁写开关打开,sql没有写操作,禁写数据库集合包含被拦截的数据库
sql = "SELECT * FROM table";
argument[PARAM_INDEX] = sql;
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

//数据库禁写开关打开,禁写数据库集合不包含被拦截的数据库, 方法入参为String
argument[PARAM_INDEX] = "INSERT INTO table (name) VALUES ('test')";
globalConfig.setMySqlDatabases(new HashSet<>());
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());

//数据库禁写开关打开,禁写数据库集合不包含被拦截的数据库,方法入参为ClientPrepareResult
argument[PARAM_INDEX] = resultMock;
context = ExecuteContext.forMemberMethod(protocolMock, methodMock, argument, null, null);
interceptor.before(context);
Assert.assertNull(context.getThrowableOut());
}
}
Loading

0 comments on commit 128c164

Please sign in to comment.