Skip to content

Commit

Permalink
Add more test cases on BroadcastUnicastRoutingEngine (#33500)
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu authored Nov 1, 2024
1 parent 0beb0dc commit 9b1c315
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ private RouteMapper getDataSourceRouteMapper(final Collection<String> dataSource
}

private String getDataSourceName(final Collection<String> dataSourceNames) {
return sqlStatementContext instanceof CursorAvailable || isViewStatementContext(sqlStatementContext) ? dataSourceNames.iterator().next() : getRandomDataSourceName(dataSourceNames);
return isRouteToFirstDataSource() ? dataSourceNames.iterator().next() : getRandomDataSourceName(dataSourceNames);
}

private boolean isRouteToFirstDataSource() {
return sqlStatementContext instanceof CursorAvailable || isViewStatementContext(sqlStatementContext);
}

private boolean isViewStatementContext(final SQLStatementContext sqlStatementContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ class BroadcastDatabaseBroadcastRoutingEngineTest {

@Test
void assertRoute() {
BroadcastRule broadcastRule = mock(BroadcastRule.class);
when(broadcastRule.getDataSourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
BroadcastRule rule = mock(BroadcastRule.class);
when(rule.getDataSourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
BroadcastDatabaseBroadcastRoutingEngine engine = new BroadcastDatabaseBroadcastRoutingEngine();
RouteContext routeContext = engine.route(new RouteContext(), broadcastRule);
RouteContext routeContext = engine.route(new RouteContext(), rule);
assertThat(routeContext.getRouteUnits().size(), is(2));
Iterator<RouteUnit> iterator = routeContext.getRouteUnits().iterator();
assertDataSourceRouteMapper(iterator.next(), "ds_0");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ void assertRoute() {
ResourceMetaData resourceMetaData = mock(ResourceMetaData.class);
when(resourceMetaData.getAllInstanceDataSourceNames()).thenReturn(Collections.singleton("ds_0"));
BroadcastInstanceBroadcastRoutingEngine engine = new BroadcastInstanceBroadcastRoutingEngine(resourceMetaData);
BroadcastRule broadcastRule = mock(BroadcastRule.class);
when(broadcastRule.getDataSourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
RouteContext routeContext = engine.route(new RouteContext(), broadcastRule);
BroadcastRule rule = mock(BroadcastRule.class);
when(rule.getDataSourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
RouteContext routeContext = engine.route(new RouteContext(), rule);
assertThat(routeContext.getRouteUnits().size(), is(1));
assertDataSourceRouteMapper(routeContext.getRouteUnits().iterator().next(), "ds_0");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,74 +19,101 @@

import org.apache.shardingsphere.broadcast.rule.BroadcastRule;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.ddl.AlterViewStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.ddl.CreateViewStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.ddl.DropViewStatementContext;
import org.apache.shardingsphere.infra.binder.context.type.CursorAvailable;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.hamcrest.Matcher;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;

import static org.hamcrest.CoreMatchers.anyOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.withSettings;

@ExtendWith(MockitoExtension.class)
class BroadcastUnicastRoutingEngineTest {

private BroadcastRule broadcastRule;
@Mock
private BroadcastRule rule;

@Mock
private ConnectionContext connectionContext;

@BeforeEach
void setUp() {
broadcastRule = mock(BroadcastRule.class);
when(broadcastRule.getDataSourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
when(rule.getDataSourceNames()).thenReturn(Arrays.asList("ds_0", "ds_1"));
}

@Test
void assertRouteToFirstDataSourceWithCursorStatement() {
assertRoute(mock(SQLStatementContext.class, withSettings().extraInterfaces(CursorAvailable.class)), is("ds_0"));
}

@Test
void assertRoute() {
SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class);
Collection<String> logicTables = Collections.singleton("t_address");
ConnectionContext connectionContext = mock(ConnectionContext.class);
BroadcastUnicastRoutingEngine engine = new BroadcastUnicastRoutingEngine(sqlStatementContext, logicTables, connectionContext);
RouteContext routeContext = engine.route(new RouteContext(), broadcastRule);
assertThat(routeContext.getRouteUnits().size(), is(1));
assertTableRouteMapper(routeContext);
void assertRouteToFirstDataSourceWithCreateViewStatementContext() {
assertRoute(mock(CreateViewStatementContext.class), is("ds_0"));
}

@Test
void assertRouteWithCreateViewStatementContext() {
CreateViewStatementContext sqlStatementContext = mock(CreateViewStatementContext.class);
Collection<String> logicTables = Collections.singleton("t_address");
ConnectionContext connectionContext = mock(ConnectionContext.class);
BroadcastUnicastRoutingEngine engine = new BroadcastUnicastRoutingEngine(sqlStatementContext, logicTables, connectionContext);
RouteContext routeContext = engine.route(new RouteContext(), broadcastRule);
assertThat(routeContext.getRouteUnits().size(), is(1));
RouteMapper dataSourceRouteMapper = routeContext.getRouteUnits().iterator().next().getDataSourceMapper();
assertThat(dataSourceRouteMapper.getLogicName(), is("ds_0"));
assertTableRouteMapper(routeContext);
void assertRouteToFirstDataSourceWithAlterViewStatementContext() {
assertRoute(mock(AlterViewStatementContext.class), is("ds_0"));
}

@Test
void assertRouteWithCursorStatement() {
CreateViewStatementContext sqlStatementContext = mock(CreateViewStatementContext.class);
Collection<String> logicTables = Collections.singleton("t_address");
ConnectionContext connectionContext = mock(ConnectionContext.class);
BroadcastUnicastRoutingEngine engine = new BroadcastUnicastRoutingEngine(sqlStatementContext, logicTables, connectionContext);
RouteContext routeContext = engine.route(new RouteContext(), broadcastRule);
assertThat(routeContext.getRouteUnits().size(), is(1));
RouteMapper dataSourceRouteMapper = routeContext.getRouteUnits().iterator().next().getDataSourceMapper();
assertThat(dataSourceRouteMapper.getLogicName(), is("ds_0"));
assertTableRouteMapper(routeContext);
void assertRouteToFirstDataSourceWithDropViewStatementContext() {
assertRoute(mock(DropViewStatementContext.class), is("ds_0"));
}

private void assertTableRouteMapper(final RouteContext routeContext) {
Collection<RouteMapper> tableRouteMappers = routeContext.getRouteUnits().iterator().next().getTableMappers();
assertThat(tableRouteMappers.size(), is(1));
RouteMapper tableRouteMapper = tableRouteMappers.iterator().next();
assertThat(tableRouteMapper.getLogicName(), is("t_address"));
assertThat(tableRouteMapper.getActualName(), is("t_address"));
@Test
void assertRouteToRandomDataSourceWithUnusedDataSources() {
assertRoute(mock(SQLStatementContext.class), is("ds_0"), is("ds_1"));
}

@Test
void assertRouteToRandomDataSourceWithUsedDataSources() {
when(connectionContext.getUsedDataSourceNames()).thenReturn(Collections.singletonList("ds_2"));
assertRoute(mock(SQLStatementContext.class), is("ds_2"));
}

@SafeVarargs
private final void assertRoute(final SQLStatementContext sqlStatementContext, final Matcher<String>... matchers) {
BroadcastUnicastRoutingEngine engine = new BroadcastUnicastRoutingEngine(sqlStatementContext, Collections.singleton("foo_tbl"), connectionContext);
RouteContext actual = engine.route(new RouteContext(), rule);
assertThat(actual.getRouteUnits().size(), is(1));
RouteMapper actualDataSourceRouteMapper = actual.getRouteUnits().iterator().next().getDataSourceMapper();
assertThat(actualDataSourceRouteMapper.getLogicName(), anyOf(matchers));
Collection<RouteMapper> actualTableRouteMappers = actual.getRouteUnits().iterator().next().getTableMappers();
assertTableRouteMapper(actualTableRouteMappers);
}

private void assertTableRouteMapper(final Collection<RouteMapper> actual) {
assertThat(actual.size(), is(1));
RouteMapper tableRouteMapper = actual.iterator().next();
assertThat(tableRouteMapper.getLogicName(), is("foo_tbl"));
assertThat(tableRouteMapper.getActualName(), is("foo_tbl"));
}

@Test
void assertRouteWithEmptyTables() {
BroadcastUnicastRoutingEngine engine = new BroadcastUnicastRoutingEngine(mock(SQLStatementContext.class), Collections.emptyList(), connectionContext);
RouteContext actual = engine.route(new RouteContext(), rule);
assertThat(actual.getRouteUnits().size(), is(1));
Collection<RouteMapper> actualTableRouteMappers = actual.getRouteUnits().iterator().next().getTableMappers();
assertTrue(actualTableRouteMappers.isEmpty());
}
}

0 comments on commit 9b1c315

Please sign in to comment.