diff --git a/sharding-proxy/src/test/java/io/shardingsphere/shardingproxy/backend/jdbc/connection/BackendConnectionTest.java b/sharding-proxy/src/test/java/io/shardingsphere/shardingproxy/backend/jdbc/connection/BackendConnectionTest.java index 6683e04bb8719c..ad23e21d2af5cb 100644 --- a/sharding-proxy/src/test/java/io/shardingsphere/shardingproxy/backend/jdbc/connection/BackendConnectionTest.java +++ b/sharding-proxy/src/test/java/io/shardingsphere/shardingproxy/backend/jdbc/connection/BackendConnectionTest.java @@ -28,11 +28,14 @@ import org.mockito.junit.MockitoJUnitRunner; import java.sql.Connection; +import java.util.ArrayList; import java.util.List; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -42,21 +45,26 @@ public class BackendConnectionTest { @Mock private LogicSchema logicSchema; + @Mock + private JDBCBackendDataSource backendDataSource; + private BackendConnection backendConnection = new BackendConnection(); @Before @SuppressWarnings("unchecked") @SneakyThrows public void setup() { - List newConnection = mock(List.class); - JDBCBackendDataSource backendDataSource = mock(JDBCBackendDataSource.class); - when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), anyInt())).thenReturn(newConnection); when(logicSchema.getBackendDataSource()).thenReturn(backendDataSource); + backendConnection.setLogicSchema(logicSchema); } @Test + @SneakyThrows public void assertGetConnectionCacheIsEmpty() { - + when(backendDataSource.getConnections((ConnectionMode) any(), anyString(), eq(2))).thenReturn(mockNewConnections(2)); + List actualConnections = backendConnection.getConnections(ConnectionMode.MEMORY_STRICTLY, "ds1", 2); + assertThat(actualConnections.size(), is(2)); + assertThat(backendConnection.getConnectionSize(), is(2)); } @Test @@ -78,4 +86,13 @@ public void assertGetConnectionSizeIsOne() { public void assertMultiThreadGetConnection() { } + + private List mockNewConnections(final int connectionSize) { + List result = new ArrayList<>(); + for (int i = 0; i < connectionSize; i++) { + Connection connection = mock(Connection.class); + result.add(connection); + } + return result; + } }