diff --git a/csharp/test/Drivers/Snowflake/ClientTests.cs b/csharp/test/Drivers/Snowflake/ClientTests.cs index 77d12fb462..e895865024 100644 --- a/csharp/test/Drivers/Snowflake/ClientTests.cs +++ b/csharp/test/Drivers/Snowflake/ClientTests.cs @@ -39,7 +39,7 @@ public class ClientTests { public ClientTests() { - Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)); + Skip.IfNot(Utils.CanExecuteTestConfig(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE)); } /// @@ -50,7 +50,7 @@ public void CanClientExecuteUpdate() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { string[] queries = SnowflakeTestingUtils.GetQueries(testConfiguration); @@ -68,7 +68,7 @@ public void CanClientExecuteUpdateUsingExecuteReader() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { adbcConnection.Open(); @@ -104,7 +104,7 @@ public void CanClientGetSchema() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { Tests.ClientTests.CanClientGetSchema(adbcConnection, testConfiguration); } @@ -119,7 +119,7 @@ public void CanClientExecuteQuery() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration); } @@ -136,7 +136,7 @@ public void CanClientExecuteQueryWithNoResults() testConfiguration.Query = "SELECT * WHERE 0=1"; testConfiguration.ExpectedResultsCount = 0; - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration); } @@ -151,7 +151,9 @@ public void CanClientExecuteQueryUsingPrivateKey() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) + Skip.If(testConfiguration.Authentication.SnowflakeJwt is null, "JWT authentication is not configured"); + + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration, SnowflakeAuthentication.AuthJwt)) { Tests.ClientTests.CanClientExecuteQuery(adbcConnection, testConfiguration); } @@ -166,7 +168,7 @@ public void VerifyTypesAndValues() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { SampleDataBuilder sampleDataBuilder = SnowflakeData.GetSampleData(); @@ -179,7 +181,7 @@ public void VerifySchemaTables() { SnowflakeTestConfiguration testConfiguration = Utils.LoadTestConfiguration(SnowflakeTestingUtils.SNOWFLAKE_TEST_CONFIG_VARIABLE); - using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnection(testConfiguration)) + using (Adbc.Client.AdbcConnection adbcConnection = GetSnowflakeAdbcConnectionUsingConnectionString(testConfiguration)) { adbcConnection.Open(); @@ -218,7 +220,7 @@ public void VerifySchemaTables() } } - private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionString(SnowflakeTestConfiguration testConfiguration) + private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionString(SnowflakeTestConfiguration testConfiguration, string authType = null) { // see https://arrow.apache.org/adbc/0.5.1/driver/snowflake.html @@ -228,22 +230,32 @@ private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionStri builder[SnowflakeParameters.HOST] = testConfiguration.Host; builder[SnowflakeParameters.DATABASE] = testConfiguration.Database; builder[SnowflakeParameters.USERNAME] = testConfiguration.User; - if (!string.IsNullOrEmpty(testConfiguration.AuthenticationTokenPath)) + if (authType == SnowflakeAuthentication.AuthJwt) { - builder[SnowflakeParameters.AUTH_TYPE] = testConfiguration.AuthenticationType; - string privateKey = File.ReadAllText(testConfiguration.AuthenticationTokenPath); - if (testConfiguration.AuthenticationType.Equals("auth_jwt", StringComparison.OrdinalIgnoreCase)) + string privateKey = testConfiguration.Authentication.SnowflakeJwt.PrivateKey; + builder[SnowflakeParameters.AUTH_TYPE] = SnowflakeAuthentication.AuthJwt; + builder[SnowflakeParameters.PKCS8_VALUE] = privateKey; + builder[SnowflakeParameters.USERNAME] = testConfiguration.Authentication.SnowflakeJwt.User; + if (!string.IsNullOrEmpty(testConfiguration.Authentication.SnowflakeJwt.PrivateKeyPassPhrase)) { - builder[SnowflakeParameters.PKCS8_VALUE] = privateKey; - if(!string.IsNullOrEmpty(testConfiguration.Pkcs8Passcode)) - { - builder[SnowflakeParameters.PKCS8_PASS] = testConfiguration.Pkcs8Passcode; - } + builder[SnowflakeParameters.PKCS8_PASS] = testConfiguration.Authentication.SnowflakeJwt.PrivateKeyPassPhrase; + } + } + else if (authType == SnowflakeAuthentication.AuthOAuth) + { + builder[SnowflakeParameters.AUTH_TYPE] = SnowflakeAuthentication.AuthOAuth; + builder[SnowflakeParameters.AUTH_TOKEN] = testConfiguration.Authentication.OAuth.Token; + if (testConfiguration.Authentication.OAuth.User != null) + { + builder[SnowflakeParameters.USERNAME] = testConfiguration.Authentication.OAuth.User; } } - else + else if (string.IsNullOrEmpty(authType) || authType == SnowflakeAuthentication.AuthSnowflake) { - builder[SnowflakeParameters.PASSWORD] = testConfiguration.Password; + // if no auth type is specified, use the snowflake auth + builder[SnowflakeParameters.AUTH_TYPE] = SnowflakeAuthentication.AuthSnowflake; + builder[SnowflakeParameters.USERNAME] = testConfiguration.Authentication.Default.User; + builder[SnowflakeParameters.PASSWORD] = testConfiguration.Authentication.Default.Password; } AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration); return new Adbc.Client.AdbcConnection(builder.ConnectionString) @@ -251,19 +263,5 @@ private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnectionUsingConnectionStri AdbcDriver = snowflakeDriver }; } - private Adbc.Client.AdbcConnection GetSnowflakeAdbcConnection(SnowflakeTestConfiguration testConfiguration) - { - Dictionary parameters = new Dictionary(); - - AdbcDriver snowflakeDriver = SnowflakeTestingUtils.GetSnowflakeAdbcDriver(testConfiguration, out parameters); - - Adbc.Client.AdbcConnection adbcConnection = new Adbc.Client.AdbcConnection( - snowflakeDriver, - parameters: parameters, - options: new Dictionary() - ); - - return adbcConnection; - } } } diff --git a/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json b/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json index d3b4358143..dec0c492e3 100644 --- a/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json +++ b/csharp/test/Drivers/Snowflake/Resources/snowflakeconfig.json @@ -4,19 +4,28 @@ "account": "", "host": "", "database": "", - "user": "", - "password": "", "warehouse": "", - "authenticationType": "", - "authenticationTokenPath": "", - "pkcs8Passcode": "", - "useHighPrecision": true, + "useHighPrecision": true, "metadata": { "catalog": "", "schema": "", "table": "", "expectedColumnCount": 0 }, + "authentication": { + "auth_oauth": { + "token": "" + }, + "auth_jwt": { + "private_key_file": "", + "private_key_pwd": "", + "user": "" + }, + "auth_snowflake": { + "user": "", + "password": "" + } + }, "query": "", "expectedResults": 0 } diff --git a/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs b/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs index 96bdcc22b5..b2f7b2ba89 100644 --- a/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs +++ b/csharp/test/Drivers/Snowflake/SnowflakeTestConfiguration.cs @@ -73,27 +73,65 @@ internal class SnowflakeTestConfiguration : TestConfiguration public string Warehouse { get; set; } /// - /// The Snowflake authentication type. + /// The Snowflake use high precision /// - [JsonPropertyName("authenticationType")] - public string AuthenticationType { get; set; } + [JsonPropertyName("useHighPrecision")] + public bool UseHighPrecision { get; set; } = true; /// - /// The file location of the authentication token (if using). + /// The snowflake Authentication /// - [JsonPropertyName("authenticationTokenPath"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] - public string AuthenticationTokenPath { get; set; } + [JsonPropertyName("authentication")] + public SnowflakeAuthentication Authentication { get; set; } - /// - /// The passcode to use if the JWT token is encrypted. - /// - [JsonPropertyName("pkcs8Passcode"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] - public string Pkcs8Passcode { get; set; } + } - /// - /// The Snowflake authentication type. - /// - [JsonPropertyName("useHighPrecision")] - public bool UseHighPrecision { get; set; } = true; + public class SnowflakeAuthentication + { + public const string AuthOAuth = "auth_oauth"; + public const string AuthJwt = "auth_jwt"; + public const string AuthSnowflake = "auth_snowflake"; + + [JsonPropertyName(AuthOAuth)] + public OAuthAuthentication OAuth { get; set; } + + [JsonPropertyName(AuthJwt)] + public JwtAuthentication SnowflakeJwt { get; set; } + + [JsonPropertyName(AuthSnowflake)] + public DefaultAuthentication Default { get; set; } + } + + public class OAuthAuthentication + { + [JsonPropertyName("token")] + public string Token { get; set; } + + [JsonPropertyName("user")] + public string User { get; set; } + } + + public class JwtAuthentication + { + [JsonPropertyName("private_key")] + public string PrivateKey { get; set; } + + [JsonPropertyName("private_key_file")] + public string PrivateKeyFile { get; set; } + + [JsonPropertyName("private_key_pwd")] + public string PrivateKeyPassPhrase{ get; set; } + + [JsonPropertyName("user")] + public string User { get; set; } + } + + public class DefaultAuthentication + { + [JsonPropertyName("user")] + public string User { get; set; } + + [JsonPropertyName("password")] + public string Password { get; set; } } } diff --git a/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs b/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs index 18fcafb65f..a6c63868b9 100644 --- a/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs +++ b/csharp/test/Drivers/Snowflake/SnowflakeTestingUtils.cs @@ -34,6 +34,7 @@ internal class SnowflakeParameters public const string PASSWORD = "password"; public const string WAREHOUSE = "adbc.snowflake.sql.warehouse"; public const string AUTH_TYPE = "adbc.snowflake.sql.auth_type"; + public const string AUTH_TOKEN = "adbc.snowflake.sql.client_option.auth_token"; public const string HOST = "adbc.snowflake.sql.uri.host"; public const string PKCS8_VALUE = "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_value"; public const string PKCS8_PASS = "adbc.snowflake.sql.client_option.jwt_private_key_pkcs8_password"; @@ -78,11 +79,17 @@ out Dictionary parameters { SnowflakeParameters.USERNAME, testConfiguration.User }, { SnowflakeParameters.PASSWORD, testConfiguration.Password }, { SnowflakeParameters.WAREHOUSE, testConfiguration.Warehouse }, - { SnowflakeParameters.AUTH_TYPE, testConfiguration.AuthenticationType }, { SnowflakeParameters.USE_HIGH_PRECISION, testConfiguration.UseHighPrecision.ToString().ToLowerInvariant() } }; - if (!string.IsNullOrWhiteSpace(testConfiguration.Host)) + if(testConfiguration.Authentication.Default is not null) + { + parameters[SnowflakeParameters.AUTH_TYPE] = SnowflakeAuthentication.AuthSnowflake; + parameters[SnowflakeParameters.USERNAME] = testConfiguration.Authentication.Default.User; + parameters[SnowflakeParameters.PASSWORD] = testConfiguration.Authentication.Default.Password; + } + + if(!string.IsNullOrWhiteSpace(testConfiguration.Host)) { parameters[SnowflakeParameters.HOST] = testConfiguration.Host; }