Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check privilege when register or alter storage unit #32172

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ void assertParseRQL() {

@Test
void assertParseRDL() {
assertParse(new RegisterStorageUnitStatement(false, Collections.emptyList()), "RDL=1");
assertParse(new RegisterStorageUnitStatement(false, Collections.emptyList(), Collections.emptySet()), "RDL=1");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
import java.util.stream.Collectors;

/**
* Storage units connect exception.
* Storage units validate exception.
*/
public final class StorageUnitsConnectException extends ResourceDefinitionException {
public final class StorageUnitsValidateException extends ResourceDefinitionException {

private static final long serialVersionUID = 1824912697040264268L;

public StorageUnitsConnectException(final Map<String, Exception> causes) {
super(XOpenSQLState.CONNECTION_EXCEPTION, 10, "Storage units can not connect, error messages are: %s.", causes.entrySet().stream().map(entry -> String.format(
public StorageUnitsValidateException(final Map<String, Exception> causes) {
super(XOpenSQLState.CONNECTION_EXCEPTION, 10, "Storage units validate error, messages are: %s.", causes.entrySet().stream().map(entry -> String.format(
"Storage unit name: '%s', error message is: %s", entry.getKey(), entry.getValue().getMessage())).collect(Collectors.joining(System.lineSeparator())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.database.core.checker.DialectDatabaseEnvironmentChecker;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.database.core.spi.DatabaseTypedSPILoader;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeFactory;
import org.apache.shardingsphere.infra.datasource.pool.creator.DataSourcePoolCreator;
import org.apache.shardingsphere.infra.datasource.pool.destroyer.DataSourcePoolDestroyer;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
Expand All @@ -27,9 +32,11 @@
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

/**
* Data source pool properties validator.
Expand All @@ -41,14 +48,15 @@ public final class DataSourcePoolPropertiesValidator {
* Validate data source pool properties map.
*
* @param propsMap data source pool properties map
* @param expectedPrivileges excepted privileges
* @return data source name and exception map
*/
public static Map<String, Exception> validate(final Map<String, DataSourcePoolProperties> propsMap) {
public static Map<String, Exception> validate(final Map<String, DataSourcePoolProperties> propsMap, final Collection<PrivilegeCheckType> expectedPrivileges) {
Map<String, Exception> result = new LinkedHashMap<>(propsMap.size(), 1F);
for (Entry<String, DataSourcePoolProperties> entry : propsMap.entrySet()) {
try {
validateProperties(entry.getKey(), entry.getValue());
validateConnection(entry.getKey(), entry.getValue());
validateConnection(entry.getKey(), entry.getValue(), expectedPrivileges);
} catch (final InvalidDataSourcePoolPropertiesException ex) {
result.put(entry.getKey(), ex);
}
Expand All @@ -64,11 +72,16 @@ private static void validateProperties(final String dataSourceName, final DataSo
}
}

private static void validateConnection(final String dataSourceName, final DataSourcePoolProperties props) throws InvalidDataSourcePoolPropertiesException {
private static void validateConnection(final String dataSourceName, final DataSourcePoolProperties props,
final Collection<PrivilegeCheckType> expectedPrivileges) throws InvalidDataSourcePoolPropertiesException {
DataSource dataSource = null;
try {
dataSource = DataSourcePoolCreator.create(props);
checkFailFast(dataSource);
if (expectedPrivileges.isEmpty() || expectedPrivileges.contains(PrivilegeCheckType.NONE)) {
checkFailFast(dataSource);
return;
}
checkPrivileges(dataSource, props, expectedPrivileges);
// CHECKSTYLE:OFF
} catch (final SQLException | RuntimeException ex) {
// CHECKSTYLE:ON
Expand All @@ -87,4 +100,14 @@ private static void checkFailFast(final DataSource dataSource) throws SQLExcepti
// CHECKSTYLE:ON
}
}

private static void checkPrivileges(final DataSource dataSource, final DataSourcePoolProperties props, final Collection<PrivilegeCheckType> expectedPrivileges) {
DatabaseType databaseType = DatabaseTypeFactory.get((String) props.getConnectionPropertySynonyms().getStandardProperties().get("url"));
Optional<DialectDatabaseEnvironmentChecker> checker = DatabaseTypedSPILoader.findService(DialectDatabaseEnvironmentChecker.class, databaseType);
if (checker.isPresent()) {
for (PrivilegeCheckType each : expectedPrivileges) {
checker.get().checkPrivilege(dataSource, each);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ static void setUp() throws ClassNotFoundException {
@Test
void assertValidate() {
assertTrue(DataSourcePoolPropertiesValidator.validate(
Collections.singletonMap("name", new DataSourcePoolProperties(HikariDataSource.class.getName(), Collections.singletonMap("jdbcUrl", "jdbc:mock")))).isEmpty());
Collections.singletonMap("name", new DataSourcePoolProperties(HikariDataSource.class.getName(), Collections.singletonMap("jdbcUrl", "jdbc:mock"))), Collections.emptySet()).isEmpty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
*/
public enum PrivilegeCheckType {

PIPELINE, SELECT, XA
NONE, PIPELINE, SELECT, XA
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.shardingsphere.distsql.segment.URLBasedDataSourceSegment;
import org.apache.shardingsphere.distsql.segment.converter.DataSourceSegmentsConverter;
import org.apache.shardingsphere.distsql.statement.rdl.resource.unit.type.AlterStorageUnitStatement;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.database.core.connector.ConnectionProperties;
import org.apache.shardingsphere.infra.database.core.connector.url.JdbcUrl;
import org.apache.shardingsphere.infra.database.core.connector.url.StandardJdbcUrlParser;
Expand All @@ -35,8 +36,8 @@
import org.apache.shardingsphere.infra.exception.core.external.ShardingSphereExternalException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.AlterStorageUnitConnectionInfoException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.DuplicateStorageUnitException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsOperateException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.MissingRequiredStorageUnitsException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsOperateException;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.mode.manager.ContextManager;
Expand Down Expand Up @@ -64,7 +65,7 @@ public final class AlterStorageUnitExecutor implements DistSQLUpdateExecutor<Alt
public void executeUpdate(final AlterStorageUnitStatement sqlStatement, final ContextManager contextManager) {
checkBefore(sqlStatement);
Map<String, DataSourcePoolProperties> propsMap = DataSourceSegmentsConverter.convert(database.getProtocolType(), sqlStatement.getStorageUnits());
validateHandler.validate(propsMap);
validateHandler.validate(propsMap, sqlStatement.getExpectedPrivileges().stream().map(each -> PrivilegeCheckType.valueOf(each.toUpperCase())).collect(Collectors.toSet()));
try {
MetaDataContexts originalMetaDataContexts = contextManager.getMetaDataContexts();
contextManager.getPersistServiceFacade().getMetaDataManagerPersistService().alterStorageUnits(database.getName(), propsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.shardingsphere.distsql.segment.DataSourceSegment;
import org.apache.shardingsphere.distsql.segment.converter.DataSourceSegmentsConverter;
import org.apache.shardingsphere.distsql.statement.rdl.resource.unit.type.RegisterStorageUnitStatement;
import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.core.external.ShardingSphereExternalException;
Expand Down Expand Up @@ -66,7 +67,7 @@ public void executeUpdate(final RegisterStorageUnitStatement sqlStatement, final
if (propsMap.isEmpty()) {
return;
}
validateHandler.validate(propsMap);
validateHandler.validate(propsMap, sqlStatement.getExpectedPrivileges().stream().map(each -> PrivilegeCheckType.valueOf(each.toUpperCase())).collect(Collectors.toSet()));
try {
MetaDataContexts originalMetaDataContexts = contextManager.getMetaDataContexts();
contextManager.getPersistServiceFacade().getMetaDataManagerPersistService().registerStorageUnits(database.getName(), propsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@

package org.apache.shardingsphere.distsql.handler.validate;

import org.apache.shardingsphere.infra.database.core.checker.PrivilegeCheckType;
import org.apache.shardingsphere.infra.datasource.pool.props.domain.DataSourcePoolProperties;
import org.apache.shardingsphere.infra.datasource.pool.props.validator.DataSourcePoolPropertiesValidator;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsConnectException;
import org.apache.shardingsphere.infra.exception.kernel.metadata.resource.storageunit.StorageUnitsValidateException;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;

/**
Expand All @@ -35,7 +38,17 @@ public final class DistSQLDataSourcePoolPropertiesValidator {
* @param propsMap data source pool properties map
*/
public void validate(final Map<String, DataSourcePoolProperties> propsMap) {
Map<String, Exception> exceptions = DataSourcePoolPropertiesValidator.validate(propsMap);
ShardingSpherePreconditions.checkMustEmpty(exceptions, () -> new StorageUnitsConnectException(exceptions));
validate(propsMap, Collections.emptySet());
}

/**
* Validate data source properties map.
*
* @param propsMap data source pool properties map
* @param expectedPrivileges expected privileges
*/
public void validate(final Map<String, DataSourcePoolProperties> propsMap, final Collection<PrivilegeCheckType> expectedPrivileges) {
Map<String, Exception> exceptions = DataSourcePoolPropertiesValidator.validate(propsMap, expectedPrivileges);
ShardingSpherePreconditions.checkMustEmpty(exceptions, () -> new StorageUnitsValidateException(exceptions));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,14 @@ private ContextManager mockContextManager(final MetaDataContexts metaDataContext
}

private AlterStorageUnitStatement createAlterStorageUnitStatement(final String resourceName) {
return new AlterStorageUnitStatement(Collections.singleton(new URLBasedDataSourceSegment(resourceName, "jdbc:mysql://127.0.0.1:3306/ds_0", "root", "", new Properties())));
return new AlterStorageUnitStatement(Collections.singleton(new URLBasedDataSourceSegment(resourceName, "jdbc:mysql://127.0.0.1:3306/ds_0", "root", "", new Properties())),
Collections.emptySet());
}

private AlterStorageUnitStatement createAlterStorageUnitStatementWithDuplicateStorageUnitNames() {
return new AlterStorageUnitStatement(Arrays.asList(
new HostnameAndPortBasedDataSourceSegment("ds_0", "127.0.0.1", "3306", "ds_0", "root", "", new Properties()),
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())));
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())), Collections.emptySet());
}

private ConnectionProperties mockConnectionProperties(final String catalog) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,13 @@ void assertExecuteUpdateWithDuplicateStorageUnitNamesWithDataSourceContainedRule
}

private RegisterStorageUnitStatement createRegisterStorageUnitStatement() {
return new RegisterStorageUnitStatement(false, Collections.singleton(new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/test0", "root", "", new Properties())));
return new RegisterStorageUnitStatement(false, Collections.singleton(new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/test0", "root", "", new Properties())),
Collections.emptySet());
}

private RegisterStorageUnitStatement createRegisterStorageUnitStatementWithDuplicateStorageUnitNames() {
return new RegisterStorageUnitStatement(false, Arrays.asList(
new HostnameAndPortBasedDataSourceSegment("ds_0", "127.0.0.1", "3306", "ds_0", "root", "", new Properties()),
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())));
new URLBasedDataSourceSegment("ds_0", "jdbc:mysql://127.0.0.1:3306/ds_1", "root", "", new Properties())), Collections.emptySet());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,23 @@

import org.apache.shardingsphere.distsql.statement.rdl.resource.unit.type.RegisterStorageUnitStatement;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.apache.shardingsphere.sql.parser.statement.mysql.ddl.MySQLCreateTableStatement;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLInsertStatement;
import org.apache.shardingsphere.sql.parser.statement.mysql.dml.MySQLSelectStatement;
import org.junit.jupiter.api.Test;

import java.util.LinkedList;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;

class AutoCommitUtilsTest {

@Test
void assertNeedOpenTransactionForSelectStatement() {
SelectStatement selectStatement = new MySQLSelectStatement();
assertFalse(AutoCommitUtils.needOpenTransaction(selectStatement));
selectStatement.setFrom(new SimpleTableSegment(new TableNameSegment(0, 5, new IdentifierValue("foo"))));
selectStatement.setFrom(mock(SimpleTableSegment.class));
assertTrue(AutoCommitUtils.needOpenTransaction(selectStatement));
}

Expand All @@ -50,6 +47,6 @@ void assertNeedOpenTransactionForDDLOrDMLStatement() {

@Test
void assertNeedOpenTransactionForOtherStatement() {
assertFalse(AutoCommitUtils.needOpenTransaction(new RegisterStorageUnitStatement(false, new LinkedList<>())));
assertFalse(AutoCommitUtils.needOpenTransaction(mock(RegisterStorageUnitStatement.class)));
}
}
4 changes: 4 additions & 0 deletions parser/distsql/engine/src/main/antlr4/imports/Keyword.g4
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,7 @@ ALGORITHM
FORCE
: F O R C E
;

CHECK_PRIVILEGES
: C H E C K UL_ P R I V I L E G E S
;
16 changes: 14 additions & 2 deletions parser/distsql/engine/src/main/antlr4/imports/RDLStatement.g4
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@ grammar RDLStatement;
import BaseRule;

registerStorageUnit
: REGISTER STORAGE UNIT ifNotExists? storageUnitDefinition (COMMA_ storageUnitDefinition)*
: REGISTER STORAGE UNIT ifNotExists? storageUnitsDefinition (COMMA_ checkPrivileges)?
;

alterStorageUnit
: ALTER STORAGE UNIT storageUnitDefinition (COMMA_ storageUnitDefinition)*
: ALTER STORAGE UNIT storageUnitsDefinition (COMMA_ checkPrivileges)?
;

unregisterStorageUnit
: UNREGISTER STORAGE UNIT ifExists? storageUnitName (COMMA_ storageUnitName)* ignoreTables?
;

storageUnitsDefinition
: storageUnitDefinition (COMMA_ storageUnitDefinition)*
;

storageUnitDefinition
: storageUnitName LP_ (simpleSource | urlSource) COMMA_ USER EQ_ user (COMMA_ PASSWORD EQ_ password)? (COMMA_ propertiesDefinition)? RP_
;
Expand Down Expand Up @@ -80,3 +84,11 @@ ifExists
ifNotExists
: IF NOT EXISTS
;

checkPrivileges
: CHECK_PRIVILEGES EQ_ privilegeType (COMMA_ privilegeType)*
;

privilegeType
: IDENTIFIER_
;
Loading