diff --git a/src/Accounts/Accounts.Test/AutosaveTests.cs b/src/Accounts/Accounts.Test/AutosaveTests.cs index fd3a2df838bb..50ea60b43fe5 100644 --- a/src/Accounts/Accounts.Test/AutosaveTests.cs +++ b/src/Accounts/Accounts.Test/AutosaveTests.cs @@ -51,9 +51,7 @@ private AzKeyStore SetMockedAzKeyStore() storageMocker.Setup(f => f.Create()).Returns(storageMocker.Object); storageMocker.Setup(f => f.ReadData()).Returns(new byte[0]); storageMocker.Setup(f => f.WriteData(It.IsAny())).Callback((byte[] s) => {}); - var keyStore = new AzKeyStore(AzureSession.Instance.ARMProfileDirectory, "azkeystore", false, false, storageMocker.Object); - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); + var keyStore = new AzKeyStore(AzureSession.Instance.ARMProfileDirectory, "azkeystore", storageMocker.Object); return keyStore; } diff --git a/src/Accounts/Accounts.Test/ContextCmdletTests.cs b/src/Accounts/Accounts.Test/ContextCmdletTests.cs index dcbebaf3f36e..7b977a5a7482 100644 --- a/src/Accounts/Accounts.Test/ContextCmdletTests.cs +++ b/src/Accounts/Accounts.Test/ContextCmdletTests.cs @@ -23,17 +23,21 @@ using Microsoft.WindowsAzure.Commands.ScenarioTest; using Microsoft.WindowsAzure.Commands.Test.Utilities.Common; using Microsoft.WindowsAzure.Commands.Utilities.Common; -using Xunit; -using Xunit.Abstractions; using Microsoft.Azure.Commands.Common.Authentication.Abstractions; -using System; +using Microsoft.Azure.Commands.Common.Authentication.Properties; using Microsoft.Azure.Commands.Profile.Context; -using System.Linq; using Microsoft.Azure.Commands.Common.Authentication.ResourceManager; using Microsoft.Azure.Commands.Profile.Common; using Microsoft.Azure.Commands.ScenarioTest.Mocks; using Microsoft.Azure.Commands.TestFx.Mocks; using Microsoft.Azure.Commands.TestFx; +using Microsoft.Azure.Commands.ResourceManager.Common; +using Moq; +using System; +using System.IO; +using System.Linq; +using Xunit; +using Xunit.Abstractions; namespace Microsoft.Azure.Commands.Profile.Test { @@ -56,6 +60,12 @@ public ContextCmdletTests(ITestOutputHelper output) tokenCacheProviderMock = new MockPowerShellTokenCacheProvider(); AzureSession.Instance.RegisterComponent(PowerShellTokenCacheProvider.PowerShellTokenCacheProviderKey, () => tokenCacheProviderMock); Environment.SetEnvironmentVariable("Azure_PS_Data_Collection", "True"); + + Mock storageMocker = new Mock(); + AzKeyStore azKeyStore = null; + string profilePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), Resources.AzureDirectoryName); + azKeyStore = new AzKeyStore(profilePath, AzureSession.Instance.KeyStoreFile, storageMocker.Object); + AzureSession.Instance.RegisterComponent(AzKeyStore.Name, () => azKeyStore, true); } [Fact] diff --git a/src/Accounts/Accounts.Test/ProfileCmdletTests.cs b/src/Accounts/Accounts.Test/ProfileCmdletTests.cs index 1dab2d754f7b..386998a78ffd 100644 --- a/src/Accounts/Accounts.Test/ProfileCmdletTests.cs +++ b/src/Accounts/Accounts.Test/ProfileCmdletTests.cs @@ -55,7 +55,7 @@ private AzKeyStore SetMockedAzKeyStore() storageMocker.Setup(f => f.Create()).Returns(storageMocker.Object); storageMocker.Setup(f => f.ReadData()).Returns(new byte[0]); storageMocker.Setup(f => f.WriteData(It.IsAny())).Callback((byte[] s) => { }); - var keyStore = new AzKeyStore(AzureSession.Instance.ARMProfileDirectory, "azkeystore", false, false, storageMocker.Object); + var keyStore = new AzKeyStore(AzureSession.Instance.ARMProfileDirectory, "azkeystore", storageMocker.Object); return keyStore; } diff --git a/src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs b/src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs index 5597cb0e98ec..218fc28abe8f 100644 --- a/src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs +++ b/src/Accounts/Accounts/Account/ConnectAzureRmAccount.cs @@ -425,7 +425,7 @@ public override void ExecuteCmdlet() azureAccount.SetProperty(AzureAccount.Property.CertificatePath, resolvedPath); if (CertificatePassword != null) { - keyStore?.SaveKey(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, azureAccount.Id, Tenant), CertificatePassword); + keyStore?.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, azureAccount.Id, Tenant), CertificatePassword); if (GetContextModificationScope() == ContextModificationScope.CurrentUser && !keyStore.IsProtected) { WriteWarning(string.Format(Resources.ServicePrincipalWarning, AzureSession.Instance.KeyStoreFile, AzureSession.Instance.ARMProfileDirectory)); @@ -451,7 +451,7 @@ public override void ExecuteCmdlet() if (azureAccount.Type == AzureAccount.AccountType.ServicePrincipal && password != null) { - keyStore?.SaveKey(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret + keyStore?.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret ,azureAccount.Id, Tenant), password); if (GetContextModificationScope() == ContextModificationScope.CurrentUser && !keyStore.IsProtected) { @@ -713,9 +713,7 @@ public void OnImport() } AzKeyStore keyStore = null; - keyStore = new AzKeyStore(AzureSession.Instance.ARMProfileDirectory, AzureSession.Instance.KeyStoreFile, false, autoSaveEnabled); - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); + keyStore = new AzKeyStore(AzureSession.Instance.ARMProfileDirectory, AzureSession.Instance.KeyStoreFile); AzureSession.Instance.RegisterComponent(AzKeyStore.Name, () => keyStore); if (!InitializeProfileProvider(autoSaveEnabled)) @@ -724,11 +722,6 @@ public void OnImport() autoSaveEnabled = false; } - if (!keyStore.LoadStorage()) - { - WriteInitializationWarnings(Resources.KeyStoreLoadingError); - } - IAuthenticatorBuilder builder = null; if (!AzureSession.Instance.TryGetComponent(AuthenticatorBuilder.AuthenticatorBuilderKey, out builder)) { diff --git a/src/Accounts/Accounts/AutoSave/DisableAzureRmContextAutosave.cs b/src/Accounts/Accounts/AutoSave/DisableAzureRmContextAutosave.cs index c9f768733cb7..1d69aba78ce5 100644 --- a/src/Accounts/Accounts/AutoSave/DisableAzureRmContextAutosave.cs +++ b/src/Accounts/Accounts/AutoSave/DisableAzureRmContextAutosave.cs @@ -92,11 +92,6 @@ void DisableAutosave(IAzureSession session, bool writeAutoSaveFile, out ContextA builder.Reset(); } - if (AzureSession.Instance.TryGetComponent(AzKeyStore.Name, out AzKeyStore keystore)) - { - keystore.DisableAutoSaving(); - } - if (writeAutoSaveFile) { FileUtilities.EnsureDirectoryExists(session.ProfileDirectory); diff --git a/src/Accounts/Accounts/AutoSave/EnableAzureRmContextAutosave.cs b/src/Accounts/Accounts/AutoSave/EnableAzureRmContextAutosave.cs index fae6a7aed5b5..66325469ea71 100644 --- a/src/Accounts/Accounts/AutoSave/EnableAzureRmContextAutosave.cs +++ b/src/Accounts/Accounts/AutoSave/EnableAzureRmContextAutosave.cs @@ -102,13 +102,6 @@ void EnableAutosave(IAzureSession session, bool writeAutoSaveFile, out ContextAu AzureSession.Instance.RegisterComponent(PowerShellTokenCacheProvider.PowerShellTokenCacheProviderKey, () => newCacheProvider, true); } - if (AzureSession.Instance.TryGetComponent(AzKeyStore.Name, out AzKeyStore keystore)) - { - keystore.Flush(); - keystore.DisableAutoSaving(); - } - - if (writeAutoSaveFile) { try diff --git a/src/Accounts/Accounts/Context/ImportAzureRMContext.cs b/src/Accounts/Accounts/Context/ImportAzureRMContext.cs index 18cf276a6f42..9304d88d13ba 100644 --- a/src/Accounts/Accounts/Context/ImportAzureRMContext.cs +++ b/src/Accounts/Accounts/Context/ImportAzureRMContext.cs @@ -78,13 +78,13 @@ void CopyProfile(AzureRmProfile source, IProfileOperations target) var secret = account.GetProperty(AzureAccount.Property.ServicePrincipalSecret); if (!string.IsNullOrEmpty(secret)) { - keyStore.SaveKey(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, context.Value.Tenant?.Id) + keyStore.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, context.Value.Tenant?.Id) , secret.ConvertToSecureString()); } var password = account.GetProperty(AzureAccount.Property.CertificatePassword); if (!string.IsNullOrEmpty(password)) { - keyStore.SaveKey(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, context.Value.Tenant?.Id) + keyStore.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, context.Value.Tenant?.Id) ,password.ConvertToSecureString()); } } diff --git a/src/Accounts/Accounts/Context/SetAzureRMContext.cs b/src/Accounts/Accounts/Context/SetAzureRMContext.cs index 716d2e940b75..fab411685fea 100644 --- a/src/Accounts/Accounts/Context/SetAzureRMContext.cs +++ b/src/Accounts/Accounts/Context/SetAzureRMContext.cs @@ -97,13 +97,13 @@ public override void ExecuteCmdlet() var secret = account.GetProperty(AzureAccount.Property.ServicePrincipalSecret); if (!string.IsNullOrEmpty(secret)) { - keyStore.SaveKey(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, Context.Tenant?.Id) + keyStore.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, Context.Tenant?.Id) , secret.ConvertToSecureString()); } var password = account.GetProperty(AzureAccount.Property.CertificatePassword); if (!string.IsNullOrEmpty(password)) { - keyStore.SaveKey(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, Context.Tenant?.Id) + keyStore.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, Context.Tenant?.Id) , password.ConvertToSecureString()); } } diff --git a/src/Accounts/Authentication.ResourceManager/AzureRmProfile.cs b/src/Accounts/Authentication.ResourceManager/AzureRmProfile.cs index d6738055f6d6..9c81702be477 100644 --- a/src/Accounts/Authentication.ResourceManager/AzureRmProfile.cs +++ b/src/Accounts/Authentication.ResourceManager/AzureRmProfile.cs @@ -225,13 +225,13 @@ private IAzureContext MigrateSecretToKeyStore(IAzureContext context, AzKeyStore var account = context.Account; if (account.IsPropertySet(AzureAccount.Property.ServicePrincipalSecret)) { - keystore?.SaveKey(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, account.GetTenants().First()) + keystore?.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, account.GetTenants().First()) , account.ExtendedProperties.GetProperty(AzureAccount.Property.ServicePrincipalSecret).ConvertToSecureString()); account.ExtendedProperties.Remove(AzureAccount.Property.ServicePrincipalSecret); } if (account.IsPropertySet(AzureAccount.Property.CertificatePassword)) { - keystore?.SaveKey(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, account.GetTenants().First()) + keystore?.SaveCredential(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, account.GetTenants().First()) , account.ExtendedProperties.GetProperty(AzureAccount.Property.CertificatePassword).ConvertToSecureString()); account.ExtendedProperties.Remove(AzureAccount.Property.CertificatePassword); } @@ -336,10 +336,6 @@ public void Save(IFileProvider provider, bool serializeCache = true) // so that previous data is overwritten provider.Stream.SetLength(provider.Stream.Position); } - - AzKeyStore keystore = null; - AzureSession.Instance.TryGetComponent(AzKeyStore.Name, out keystore); - keystore?.Flush(); } finally { diff --git a/src/Accounts/Authentication.Test/AzKeyStorageTest.cs b/src/Accounts/Authentication.Test/AzKeyStorageTest.cs index 64f8a6eadb83..3e6165a547b3 100644 --- a/src/Accounts/Authentication.Test/AzKeyStorageTest.cs +++ b/src/Accounts/Authentication.Test/AzKeyStorageTest.cs @@ -11,6 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // ---------------------------------------------------------------------------------- +using Microsoft.Azure.Commands.Common.Authentication.Properties; using Microsoft.Azure.Commands.ResourceManager.Common; using Microsoft.WindowsAzure.Commands.Common; using Microsoft.WindowsAzure.Commands.ScenarioTest; @@ -18,8 +19,8 @@ using Newtonsoft.Json; using System; using System.Collections.Generic; +using System.IO; using System.Linq; -using System.Security; using System.Text; using Xunit; @@ -29,7 +30,7 @@ public class AzKeyStorageTest { private Mock storageMocker = null; private List storageChecker = null; - private string dummpyPath = "/home/dummy/.Azure"; + private string dummpyPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), Resources.AzureDirectoryName); private string keyStoreFileName = "azkeystore"; public AzKeyStorageTest() @@ -55,25 +56,23 @@ private static bool CompareJsonObjects(string expected, string acutal) [Trait(Category.AcceptanceType, Category.CheckIn)] public void SaveKey() { - using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, false, true, storageMocker.Object)) + using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, storageMocker.Object)) { - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); - IKeyStoreKey servicePrincipalKey = new ServicePrincipalKey("ServicePrincipalSecret", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"); var secret = "secret".ConvertToSecureString(); - store.SaveKey(servicePrincipalKey, secret); + store.SaveCredential(servicePrincipalKey, secret); IKeyStoreKey certificatePassword = new ServicePrincipalKey("CertificatePassword", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"); var passowrd = "password".ConvertToSecureString(); - store.SaveKey(certificatePassword, passowrd); + store.SaveCredential(certificatePassword, passowrd); + + var result = Encoding.UTF8.GetString(storageChecker.ToArray()); + const string EXPECTEDSTRING = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""CertificatePassword\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""password\""""},{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secret\""""}]"; + Assert.True(CompareJsonObjects(EXPECTEDSTRING, result)); - store.Flush(); + store.Clear(); } storageMocker.Verify(); - var result = Encoding.UTF8.GetString(storageChecker.ToArray()); - const string EXPECTEDSTRING = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""CertificatePassword\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""password\""""},{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secret\""""}]"; - Assert.True(CompareJsonObjects(EXPECTEDSTRING, result)); } [Fact] @@ -82,16 +81,34 @@ public void FindKey() { const string EXPECTED = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secret\""""}]"; storageChecker.AddRange(Encoding.UTF8.GetBytes(EXPECTED)); - using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, false, true, storageMocker.Object)) + using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, storageMocker.Object)) { - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); storageMocker.Setup(f => f.ReadData()).Returns(storageChecker.ToArray()); - store.LoadStorage(); IKeyStoreKey servicePrincipalKey = new ServicePrincipalKey("ServicePrincipalSecret", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"); - var secret = store.GetKey(servicePrincipalKey); + var secret = store.GetCredential(servicePrincipalKey); Assert.Equal("secret", secret.ConvertToString()); + + store.Clear(); + } + storageMocker.Verify(); + } + + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void FindFallbackKey() + { + const string EXPECTED = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secretFallback\""""}]"; + storageChecker.AddRange(Encoding.UTF8.GetBytes(EXPECTED)); + using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, storageMocker.Object)) + { + storageMocker.Setup(f => f.ReadData()).Returns(storageChecker.ToArray()); + + IKeyStoreKey servicePrincipalKey = new ServicePrincipalKey("ServicePrincipalSecret", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-0000-bad9-b7b93a3e9c5a"); + var secret = store.GetCredential(servicePrincipalKey); + Assert.Equal("secretFallback", secret.ConvertToString()); + + store.Clear(); } storageMocker.Verify(); } @@ -103,15 +120,14 @@ public void FindNoKey() { const string EXPECTED = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secret\""""}]"; storageChecker.AddRange(Encoding.UTF8.GetBytes(EXPECTED)); - using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, false, true, storageMocker.Object)) + using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, storageMocker.Object)) { - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); storageMocker.Setup(f => f.ReadData()).Returns(storageChecker.ToArray()); - store.LoadStorage(); IKeyStoreKey servicePrincipalKey = new ServicePrincipalKey("CertificatePassword", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"); - Assert.Throws(() => store.GetKey(servicePrincipalKey)); + Assert.Throws(() => store.GetCredential(servicePrincipalKey)); + + store.Clear(); } storageMocker.Verify(); } @@ -122,21 +138,20 @@ public void RemoveKey() { const string EXPECTED = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secret\""""}]"; storageChecker.AddRange(Encoding.UTF8.GetBytes(EXPECTED)); - using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, false, true, storageMocker.Object)) + using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, storageMocker.Object)) { - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); storageMocker.Setup(f => f.ReadData()).Returns(storageChecker.ToArray()); - store.LoadStorage(); IKeyStoreKey servicePrincipalKey = new ServicePrincipalKey("ServicePrincipalSecret", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"); - store.DeleteKey(servicePrincipalKey); - store.Flush(); + store.RemoveCredential(servicePrincipalKey); + + var result = Encoding.UTF8.GetString(storageChecker.ToArray()); + var objects = JsonConvert.DeserializeObject>(result); + Assert.Empty(objects); + + store.Clear(); } storageMocker.Verify(); - var result = Encoding.UTF8.GetString(storageChecker.ToArray()); - var objects = JsonConvert.DeserializeObject>(result); - Assert.Empty(objects); } [Fact] @@ -145,21 +160,20 @@ public void RemoveNoKey() { const string EXPECTED = @"[{""keyType"":""ServicePrincipalKey"",""keyStoreKey"":""{\""appId\"":\""6c984d31-5b4f-4734-b548-e230a248e347\"",\""tenantId\"":\""54826b22-38d6-4fb2-bad9-b7b93a3e9c5a\"",\""name\"":\""ServicePrincipalSecret\""}"",""valueType"":""SecureString"",""keyStoreValue"":""\""secret\""""}]"; storageChecker.AddRange(Encoding.UTF8.GetBytes(EXPECTED)); - using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, false, true, storageMocker.Object)) + using (var store = new AzKeyStore(dummpyPath, keyStoreFileName, storageMocker.Object)) { - AzKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); - AzKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); storageMocker.Setup(f => f.ReadData()).Returns(storageChecker.ToArray()); - store.LoadStorage(); IKeyStoreKey servicePrincipalKey = new ServicePrincipalKey("CertificatePassword", "6c984d31-5b4f-4734-b548-e230a248e347", "54826b22-38d6-4fb2-bad9-b7b93a3e9c5a"); - store.DeleteKey(servicePrincipalKey); - store.Flush(); + store.RemoveCredential(servicePrincipalKey); + + var result = Encoding.UTF8.GetString(storageChecker.ToArray()); + var objects = JsonConvert.DeserializeObject>(result); + Assert.Single(objects); + + store.Clear(); } storageMocker.Verify(); - var result = Encoding.UTF8.GetString(storageChecker.ToArray()); - var objects = JsonConvert.DeserializeObject>(result); - Assert.Single(objects); } } } diff --git a/src/Accounts/Authentication/Factories/AuthenticationFactory.cs b/src/Accounts/Authentication/Factories/AuthenticationFactory.cs index 721841e2d8dd..830d0520d6b4 100644 --- a/src/Accounts/Authentication/Factories/AuthenticationFactory.cs +++ b/src/Accounts/Authentication/Factories/AuthenticationFactory.cs @@ -433,9 +433,9 @@ public void RemoveUser(IAzureAccount account, IAzureTokenCache tokenCache) case AzureAccount.AccountType.ServicePrincipal: try { - KeyStore.DeleteKey(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, + KeyStore.RemoveCredential(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret, account.Id, account.GetTenants().FirstOrDefault())); - KeyStore.DeleteKey(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, + KeyStore.RemoveCredential(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword, account.Id, account.GetTenants().FirstOrDefault())); } catch @@ -577,7 +577,7 @@ private AuthenticationParameters GetAuthenticationParameters( { try { - password = KeyStore.GetKey(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret + password = KeyStore.GetCredential(new ServicePrincipalKey(AzureAccount.Property.ServicePrincipalSecret , account.Id, tenant)); } catch @@ -591,7 +591,7 @@ private AuthenticationParameters GetAuthenticationParameters( { try { - certificatePassword = KeyStore.GetKey(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword + certificatePassword = KeyStore.GetCredential(new ServicePrincipalKey(AzureAccount.Property.CertificatePassword , account.Id, tenant)); } catch diff --git a/src/Accounts/Authentication/Identity/AsyncLockWithValue.cs b/src/Accounts/Authentication/Identity/AsyncLockWithValue.cs index e89ca5eb2394..6ac99a950115 100644 --- a/src/Accounts/Authentication/Identity/AsyncLockWithValue.cs +++ b/src/Accounts/Authentication/Identity/AsyncLockWithValue.cs @@ -126,6 +126,24 @@ private void SetValue(T value) } } + /// + /// Try to reset value and fail if value is locked. + /// + /// + public bool TryClearValue() + { + lock (_syncObj) + { + if (!_isLocked) + { + _value = default(T); + _hasValue = false; + return true; + } + } + return false; + } + /// /// Release the lock and allow next waiter acquire it /// diff --git a/src/Accounts/Authentication/KeyStore/AzKeyStore.cs b/src/Accounts/Authentication/KeyStore/AzKeyStore.cs index 506f535ba225..7c610f20fd6d 100644 --- a/src/Accounts/Authentication/KeyStore/AzKeyStore.cs +++ b/src/Accounts/Authentication/KeyStore/AzKeyStore.cs @@ -11,12 +11,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ---------------------------------------------------------------------------------- -using Newtonsoft.Json; using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Text; +using System.Security; namespace Microsoft.Azure.Commands.ResourceManager.Common { @@ -24,215 +20,69 @@ public class AzKeyStore : IDisposable { public const string Name = "AzKeyStore"; - internal class KeyStoreElement - { - public string keyType; - public string keyStoreKey; - public string valueType; - public string keyStoreValue; - } - - private static IDictionary _typeNameMap = new ConcurrentDictionary(); - - private static IDictionary _elementConverterMap = new ConcurrentDictionary(); + public string FileName { get; set; } + public string Directory { get; set; } - public static void RegisterJsonConverter(Type type, string typeName, JsonConverter converter = null) + private IKeyStore _inMemoryStore = null; + public IKeyStore InMemoryStore { - if (string.IsNullOrEmpty(typeName)) - { - throw new ArgumentNullException($"typeName cannot be empty."); - } - if (_typeNameMap.ContainsKey(type)) - { - if (string.Compare(_typeNameMap[type], typeName) != 0) - { - throw new ArgumentException($"{typeName} has conflict with {_typeNameMap[type]} with reference to {type}."); - } - } - else - { - _typeNameMap[type] = typeName; - } - if (converter != null) - { - _elementConverterMap[_typeNameMap[type]] = converter; - } + get => _inMemoryStore; + set => _inMemoryStore = value; } - private IDictionary _credentials = new ConcurrentDictionary(); - private IStorage _storage = null; - - private bool autoSave = true; - private Exception lastError = null; - - public IStorage Storage - { - get => _storage; - set => _storage = value; - } + private IStorage inputStorage = null; public bool IsProtected { - get => Storage.IsProtected; + get; + private set; } - public AzKeyStore() + public AzKeyStore(string directory, string fileName, IStorage storage = null) { + InMemoryStore = new InMemoryKeyStore(); + InMemoryStore.SetBeforeAccess(LoadStorage); - } + FileName = fileName; + Directory = directory; - public AzKeyStore(string directory, string fileName, bool loadStorage = true, bool autoSaveEnabled = true, IStorage inputStorage = null) - { - autoSave = autoSaveEnabled; - Storage = inputStorage ?? new StorageWrapper() - { - FileName = fileName, - Directory = directory - }; - Storage.Create(); - - if (loadStorage&&!LoadStorage()) - { - throw new InvalidOperationException("Failed to load keystore from storage."); - } - } + inputStorage = storage; - private object Deserialize(string typeName, string value) - { - Type t = null; - t = _typeNameMap.FirstOrDefault(item => item.Value == typeName).Key; - - if (t != null) - { - if (_elementConverterMap.ContainsKey(typeName)) - { - return JsonConvert.DeserializeObject(value, t, _elementConverterMap[typeName]); - } - else - { - return JsonConvert.DeserializeObject(value, t); - } - } - return null; + InMemoryKeyStore.RegisterJsonConverter(typeof(ServicePrincipalKey), typeof(ServicePrincipalKey).Name); + InMemoryKeyStore.RegisterJsonConverter(typeof(SecureString), typeof(SecureString).Name, new SecureStringConverter()); } - public bool LoadStorage() - { - try - { - var data = Storage.ReadData(); - if (data != null && data.Length > 0) - { - var rawJsonString = Encoding.UTF8.GetString(data); - var serializableKeyStore = JsonConvert.DeserializeObject(rawJsonString, typeof(List)) as List; - if (serializableKeyStore != null) - { - foreach (var item in serializableKeyStore) - { - IKeyStoreKey keyStoreKey = Deserialize(item.keyType, item.keyStoreKey) as IKeyStoreKey; - if (keyStoreKey == null) - { - throw new ArgumentException($"Cannot parse the keystore {item.keyStoreKey} with the type {item.keyType}."); - } - var keyStoreValue = Deserialize(item.valueType, item.keyStoreValue); - if (keyStoreValue == null) - { - throw new ArgumentException($"Cannot parse the keystore {item.keyStoreValue} with the type {item.valueType}."); - } - _credentials[keyStoreKey] = keyStoreValue; - } - } - } - } - catch (Exception e) - { - lastError = e; - return false; - } - return true; - } - public void ClearCache() + private void LoadStorage(KeyStoreNotificationArgs args) { - _credentials.Clear(); + var asyncHelper = StorageHelper.GetStorageHelperAsync(true, FileName, Directory, args.KeyStore, inputStorage); + var helper = asyncHelper.GetAwaiter().GetResult(); + IsProtected = helper.IsProtected; } public void Clear() { - ClearCache(); - Storage.Clear(); - } - - public void Flush() - { - IList serializableKeyStore = new List(); - foreach (var item in _credentials) - { - var keyType = _typeNameMap[item.Key.GetType()]; - var key = _elementConverterMap.ContainsKey(keyType) ? - JsonConvert.SerializeObject(item.Key, _elementConverterMap[keyType]) : JsonConvert.SerializeObject(item.Key); - if (!string.IsNullOrEmpty(key)) - { - var valueType = _typeNameMap[item.Value.GetType()]; - serializableKeyStore.Add(new KeyStoreElement() - { - keyType = keyType, - keyStoreKey = key, - valueType = valueType, - keyStoreValue = _elementConverterMap.ContainsKey(valueType) ? - JsonConvert.SerializeObject(item.Value, _elementConverterMap[valueType]) : JsonConvert.SerializeObject(item.Value), - }) ; - } - } - var JsonString = JsonConvert.SerializeObject(serializableKeyStore); - Storage.WriteData(Encoding.UTF8.GetBytes(JsonString)); + InMemoryStore.Clear(); } public void Dispose() { - if (autoSave) - { - Flush(); - } - ClearCache(); - } - - public void SaveKey(IKeyStoreKey key, T value) - { - if (!_typeNameMap.ContainsKey(key.GetType()) || !_typeNameMap.ContainsKey(value.GetType())) - { - throw new InvalidOperationException("Please register key & values type before save it."); - } - _credentials[key] = value; - } - - public T GetKey(IKeyStoreKey key) - { - if (!_credentials.ContainsKey(key)) - { - throw new ArgumentException($"{key.ToString()} is not stored in AzKeyStore yet."); - } - return (T)_credentials[key]; - } - - public bool DeleteKey(IKeyStoreKey key) - { - return _credentials.Remove(key); + StorageHelper.TryClearLockedStorageHelper(); } - public void EnableAutoSaving() + public void SaveCredential(IKeyStoreKey key, SecureString value) { - autoSave = true; + InMemoryStore.SaveKey(key, value); } - public void DisableAutoSaving() + public SecureString GetCredential(IKeyStoreKey key) { - autoSave = false; + return InMemoryStore.GetKey(key); } - public Exception GetLastError() + public bool RemoveCredential(IKeyStoreKey key) { - return lastError; + return InMemoryStore.DeleteKey(key); } } } diff --git a/src/Accounts/Authentication/KeyStore/IKeyStore.cs b/src/Accounts/Authentication/KeyStore/IKeyStore.cs new file mode 100644 index 000000000000..f461ac77533f --- /dev/null +++ b/src/Accounts/Authentication/KeyStore/IKeyStore.cs @@ -0,0 +1,34 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace Microsoft.Azure.Commands.ResourceManager.Common +{ + public interface IKeyStore + { + void SaveKey(IKeyStoreKey key, T value); + + T GetKey(IKeyStoreKey key); + + bool DeleteKey(IKeyStoreKey key); + + void Clear(); + + void SetBeforeAccess(KeyStoreCallbak beforeAccess); + + void SetOnUpdate(KeyStoreCallbak onUpdate); + void Deserialize(byte[] Data); + + byte[] Serialize(); + } +} diff --git a/src/Accounts/Authentication/KeyStore/IKeyStoreKey.cs b/src/Accounts/Authentication/KeyStore/IKeyStoreKey.cs index 7916b87bc3ef..4dbac64496d3 100644 --- a/src/Accounts/Authentication/KeyStore/IKeyStoreKey.cs +++ b/src/Accounts/Authentication/KeyStore/IKeyStoreKey.cs @@ -23,5 +23,7 @@ public abstract class IKeyStoreKey public override abstract bool Equals(object obj); public override abstract int GetHashCode(); + + public abstract bool BeEquivalent(object obj); } } diff --git a/src/Accounts/Authentication/KeyStore/IStorage.cs b/src/Accounts/Authentication/KeyStore/IStorage.cs index be9c8187e35c..ab8b6653973c 100644 --- a/src/Accounts/Authentication/KeyStore/IStorage.cs +++ b/src/Accounts/Authentication/KeyStore/IStorage.cs @@ -18,7 +18,7 @@ namespace Microsoft.Azure.Commands.ResourceManager.Common public interface IStorage { IStorage Create(); - + void Clear(); byte[] ReadData(); @@ -26,12 +26,5 @@ public interface IStorage void VerifyPersistence(); void WriteData(byte[] data); - - Exception GetLastError(); - - bool IsProtected - { - get; - } } -} +} \ No newline at end of file diff --git a/src/Accounts/Authentication/KeyStore/IStorageHelper.cs b/src/Accounts/Authentication/KeyStore/IStorageHelper.cs new file mode 100644 index 000000000000..eaa4c1619726 --- /dev/null +++ b/src/Accounts/Authentication/KeyStore/IStorageHelper.cs @@ -0,0 +1,35 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- +using System; + +namespace Microsoft.Azure.Commands.ResourceManager.Common +{ + public interface IStorageHelper + { + void Clear(); + + byte[] LoadUnencryptedTokenCache(); + + void SaveUnencryptedTokenCache(byte[] tokenCache); + + void LoadFromCachedStorage(IKeyStore keystore); + + void WriteToCachedStorage(KeyStoreNotificationArgs args); + + bool IsProtected + { + get; + } + } +} \ No newline at end of file diff --git a/src/Accounts/Authentication/KeyStore/InMemoryKeyStore.cs b/src/Accounts/Authentication/KeyStore/InMemoryKeyStore.cs new file mode 100644 index 000000000000..34363cc0cc05 --- /dev/null +++ b/src/Accounts/Authentication/KeyStore/InMemoryKeyStore.cs @@ -0,0 +1,218 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- +using Microsoft.Identity.Client; +using Newtonsoft.Json; +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.Azure.Commands.ResourceManager.Common +{ + internal class InMemoryKeyStore : IKeyStore + { + internal class KeyStoreElement + { + public string keyType; + public string keyStoreKey; + public string valueType; + public string keyStoreValue; + } + + private static IDictionary _typeNameMap = new ConcurrentDictionary(); + + private static IDictionary _elementConverterMap = new ConcurrentDictionary(); + + private IDictionary _credentials = new ConcurrentDictionary(); + + private readonly object lockObj = new object(); + + internal KeyStoreCallbak BeforeAccess = null; + + internal KeyStoreCallbak OnUpdate = null; + + public void SaveKey(IKeyStoreKey key, T value) + { + var args = new KeyStoreNotificationArgs() + { + KeyStore = this + }; + BeforeAccess?.Invoke(args) ; + if (!_typeNameMap.ContainsKey(key.GetType()) || !_typeNameMap.ContainsKey(value.GetType())) + { + throw new InvalidOperationException("Please register key & values type before save it."); + } + _credentials[key] = value; + OnUpdate?.Invoke(args); + } + + public T GetKey(IKeyStoreKey key) + { + var args = new KeyStoreNotificationArgs() + { + KeyStore = this + }; + BeforeAccess?.Invoke(args); + + object value = null; + if ( _credentials.TryGetValue(key, out value)) + { + return (T) value; + } + + try + { + var fallBackKey = _credentials.Keys.First(x => x.BeEquivalent(key)); + return (T)_credentials[fallBackKey]; + } + catch (InvalidOperationException) + { + throw new ArgumentException($"{key.ToString()} is not stored in AzKeyStore yet."); + } + } + + public bool DeleteKey(IKeyStoreKey key) + { + var args = new KeyStoreNotificationArgs() + { + KeyStore = this + }; + BeforeAccess?.Invoke(args); + bool ret = false; + ret = _credentials.Remove(key); + OnUpdate?.Invoke(args); + return ret; + } + + public void Deserialize(byte[] data) + { + lock(lockObj) + { + if (data != null && data.Length > 0) + { + var rawJsonString = Encoding.UTF8.GetString(data); + var serializableKeyStore = JsonConvert.DeserializeObject(rawJsonString, typeof(List)) as List; + if (serializableKeyStore != null) + { + foreach (var item in serializableKeyStore) + { + IKeyStoreKey keyStoreKey = DeserializeItem(item.keyType, item.keyStoreKey) as IKeyStoreKey; + if (keyStoreKey == null) + { + throw new ArgumentException($"Cannot parse the keystore {item.keyStoreKey} with the type {item.keyType}."); + } + var keyStoreValue = DeserializeItem(item.valueType, item.keyStoreValue); + if (keyStoreValue == null) + { + throw new ArgumentException($"Cannot parse the keystore {item.keyStoreValue} with the type {item.valueType}."); + } + _credentials[keyStoreKey] = keyStoreValue; + } + } + } + } + } + + public byte[] Serialize() + { + IList serializableKeyStore = new List(); + foreach (var item in _credentials) + { + var keyType = _typeNameMap[item.Key.GetType()]; + var key = _elementConverterMap.ContainsKey(keyType) ? + JsonConvert.SerializeObject(item.Key, _elementConverterMap[keyType]) : JsonConvert.SerializeObject(item.Key); + if (!string.IsNullOrEmpty(key)) + { + var valueType = _typeNameMap[item.Value.GetType()]; + serializableKeyStore.Add(new KeyStoreElement() + { + keyType = keyType, + keyStoreKey = key, + valueType = valueType, + keyStoreValue = _elementConverterMap.ContainsKey(valueType) ? + JsonConvert.SerializeObject(item.Value, _elementConverterMap[valueType]) : JsonConvert.SerializeObject(item.Value), + }); + } + } + var JsonString = JsonConvert.SerializeObject(serializableKeyStore); + return Encoding.UTF8.GetBytes(JsonString); + } + + public void Clear() + { + var args = new KeyStoreNotificationArgs() + { + KeyStore = this + }; + BeforeAccess?.Invoke(args); + _credentials.Clear(); + OnUpdate?.Invoke(args); + + } + + private static object DeserializeItem(string typeName, string value) + { + Type t = null; + t = _typeNameMap.FirstOrDefault(item => item.Value == typeName).Key; + + if (t != null) + { + if (_elementConverterMap.ContainsKey(typeName)) + { + return JsonConvert.DeserializeObject(value, t, _elementConverterMap[typeName]); + } + else + { + return JsonConvert.DeserializeObject(value, t); + } + } + return null; + } + + public static void RegisterJsonConverter(Type type, string typeName, JsonConverter converter = null) + { + if (string.IsNullOrEmpty(typeName)) + { + throw new ArgumentNullException($"typeName cannot be empty."); + } + if (_typeNameMap.ContainsKey(type)) + { + if (string.Compare(_typeNameMap[type], typeName) != 0) + { + throw new ArgumentException($"{typeName} has conflict with {_typeNameMap[type]} with reference to {type}."); + } + } + else + { + _typeNameMap[type] = typeName; + } + if (converter != null) + { + _elementConverterMap[_typeNameMap[type]] = converter; + } + } + + public void SetBeforeAccess(KeyStoreCallbak beforeAccess) + { + BeforeAccess = beforeAccess; + } + + public void SetOnUpdate(KeyStoreCallbak onUpdate) + { + OnUpdate = onUpdate; + } + } +} diff --git a/src/Accounts/Authentication/KeyStore/KeyStoreNotificationArgs.cs b/src/Accounts/Authentication/KeyStore/KeyStoreNotificationArgs.cs new file mode 100644 index 000000000000..25cdcf22a50a --- /dev/null +++ b/src/Accounts/Authentication/KeyStore/KeyStoreNotificationArgs.cs @@ -0,0 +1,23 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace Microsoft.Azure.Commands.ResourceManager.Common +{ + public class KeyStoreNotificationArgs + { + public IKeyStore KeyStore; + } + + public delegate void KeyStoreCallbak(KeyStoreNotificationArgs args); +} diff --git a/src/Accounts/Authentication/KeyStore/SecureStringConverter.cs b/src/Accounts/Authentication/KeyStore/SecureStringConverter.cs index c399dfacd424..fbbf3f60da9a 100644 --- a/src/Accounts/Authentication/KeyStore/SecureStringConverter.cs +++ b/src/Accounts/Authentication/KeyStore/SecureStringConverter.cs @@ -10,6 +10,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// ---------------------------------------------------------------------------------- using Microsoft.WindowsAzure.Commands.Common; using Newtonsoft.Json; using System; diff --git a/src/Accounts/Authentication/KeyStore/ServicePrincipalKey.cs b/src/Accounts/Authentication/KeyStore/ServicePrincipalKey.cs index 88e8b5dfc1f7..327ce6acfb8c 100644 --- a/src/Accounts/Authentication/KeyStore/ServicePrincipalKey.cs +++ b/src/Accounts/Authentication/KeyStore/ServicePrincipalKey.cs @@ -54,5 +54,14 @@ public override bool Equals(object obj) } return false; } + + public override bool BeEquivalent(object obj) + { + if (obj is ServicePrincipalKey other) + { + return this.name == other.name && this.appId == other.appId; + } + return false; + } } } diff --git a/src/Accounts/Authentication/KeyStore/StorageHelper.cs b/src/Accounts/Authentication/KeyStore/StorageHelper.cs new file mode 100644 index 000000000000..6adbf6e0fcd8 --- /dev/null +++ b/src/Accounts/Authentication/KeyStore/StorageHelper.cs @@ -0,0 +1,292 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- +using Microsoft.Azure.PowerShell.Authenticators.Identity; +using Microsoft.Identity.Client.Extensions.Msal; +using Microsoft.IdentityModel.Abstractions; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; + +namespace Microsoft.Azure.Commands.ResourceManager.Common +{ + public class StorageHelper : IStorageHelper + { + private const string KeyChainServiceName = "Microsoft.Azure.PowerShell"; + + private static readonly Lazy s_staticLogger = new Lazy(() => + { + return new TraceSourceLogger(new TraceSource(nameof(StorageHelper))); + }); + + private readonly StorageCreationProperties _storageCreationProperties; + + internal IStorage CacheStore { get; } + + private readonly TraceSourceLogger _logger; + + private bool _protected; + public bool IsProtected + { + get => _protected; + private set => _protected = value; + } + + private static AsyncLockWithValue cacheHelperLock = new AsyncLockWithValue(); + + internal StorageHelper(StorageCreationProperties storageProperties, bool isProtected, IStorage store = null) + { + _logger = s_staticLogger.Value; + _storageCreationProperties = storageProperties; + CacheStore = store ?? new StorageWrapper() + { + StorageCreationProperties = _storageCreationProperties, + LoggerSource = _logger.Source + }; + CacheStore.Create(); + _protected = isProtected; + } + + private static StorageHelper GetProtectedStorageHelper(string fileName, string directory, IStorage storage = null) + { + var storageProperties = new StorageCreationPropertiesBuilder(fileName, directory) + .WithMacKeyChain(KeyChainServiceName + ".other_secrets", fileName) + .WithLinuxKeyring(fileName, "default", "AzKeyStoreCache", + new KeyValuePair("AzureClientID", "Microsoft.Developer.Azure.PowerShell"), + new KeyValuePair("Microsoft.Developer.Azure.PowerShell", "1.0.0.0")).Build(); + return StorageHelper.Create(storageProperties, true, storage); + } + + private static StorageHelper GetFallbackStorageHelper(string fileName, string directory, IStorage storage = null) + { + var storageProperties = new StorageCreationPropertiesBuilder(fileName, directory) + .WithUnprotectedFile().Build(); + return StorageHelper.Create(storageProperties, false, storage); + } + + private static StorageHelper Create(StorageCreationProperties storageCreationProperties, bool isProtected, IStorage storage = null) + { + if (storageCreationProperties is null) + { + throw new ArgumentNullException(nameof(storageCreationProperties)); + } + + using (CreateCrossPlatLock(storageCreationProperties)) + { + return new StorageHelper(storageCreationProperties, isProtected, storage); + } + } + + #region Public API + public static async Task GetStorageHelperAsync(bool async, string fileName, string directory, IKeyStore keystore, IStorage storage = null) + { + StorageHelper storageHelper = null; + + using (var asyncLock = await cacheHelperLock.GetLockOrValueAsync(async).ConfigureAwait(false)) + { + if (asyncLock.HasValue) + { + return asyncLock.Value; + } + + try + { + storageHelper = GetProtectedStorageHelper(fileName, directory, storage); + storageHelper.VerifyPersistence(); + } + catch (Exception) + { + storageHelper = GetFallbackStorageHelper(fileName, directory, storage); + storageHelper.VerifyPersistence(); + } + storageHelper.RegisterCache(keystore); + storageHelper.LoadFromCachedStorage(keystore); + asyncLock.SetValue(storageHelper); + } + return storageHelper; + } + + public static bool TryClearLockedStorageHelper() + { + return cacheHelperLock.TryClearValue(); + } + + public void Clear() + { + using (CreateCrossPlatLock(_storageCreationProperties)) + { + CacheStore.Clear(); + } + } + + public byte[] LoadUnencryptedTokenCache() + { + using (CreateCrossPlatLock(_storageCreationProperties)) + { + return CacheStore.ReadData(); + } + } + + public void SaveUnencryptedTokenCache(byte[] tokenCache) + { + using (CreateCrossPlatLock(_storageCreationProperties)) + { + CacheStore.WriteData(tokenCache); + } + } + + public void RegisterCache(IKeyStore keystore) + { + if (keystore == null) + { + throw new ArgumentNullException(nameof(keystore)); + } + + _logger.LogInformation($"Registering token cache with on disk storage"); + + keystore.SetOnUpdate(WriteToCachedStorage); + + _logger.LogInformation($"Done initializing"); + } + + public void UnregisterCache(IKeyStore keystore) + { + if (keystore == null) + { + throw new ArgumentNullException(nameof(keystore)); + } + keystore.SetOnUpdate(null); + } + + public void LoadFromCachedStorage(IKeyStore keystore) + { + LogMessage(EventLogLevel.Verbose, $"Before access\nAcquiring lock for token cache"); + + // OK, we have two nested locks here. We need to maintain a clear ordering to avoid deadlocks. + // 1. Use the CrossPlatLock which is respected by all processes and is used around all cache accesses. + // 2. Use _lockObject which is used in UnregisterCache, and is needed for all accesses of _registeredCaches. + + using (CreateCrossPlatLock(_storageCreationProperties)) + { + LogMessage(EventLogLevel.Verbose, $"Before access, the store has changed"); + + byte[] cachedStoreData = null; + try + { + cachedStoreData = CacheStore.ReadData(); + } + catch (Exception ex) + { + LogMessage(EventLogLevel.Error, $"Could not read the token cache. Ignoring. Exception: {ex}"); + return; + + } + LogMessage(EventLogLevel.Verbose, $"Read '{cachedStoreData?.Length}' bytes from storage"); + + try + { + LogMessage(EventLogLevel.Verbose, $"Deserializing the store"); + keystore.Deserialize(cachedStoreData); //shouldClearExistingCache: true? + } + catch (Exception e) + { + LogMessage(EventLogLevel.Error, $"An exception was encountered while deserializing the {nameof(StorageHelper)} : {e}"); + LogMessage(EventLogLevel.Error, $"No data found in the store, clearing the cache in memory."); + + // Clear the memory cache without taking the lock over again + CacheStore.Clear(); + throw; + } + } + } + + public void WriteToCachedStorage(KeyStoreNotificationArgs args) + { + using (CreateCrossPlatLock(_storageCreationProperties)) + { + LogMessage(EventLogLevel.Verbose, $"After access"); + byte[] data = null; + // if the access operation resulted in a cache update + LogMessage(EventLogLevel.Verbose, $"After access, cache in memory HasChanged"); + try + { + data = args.KeyStore.Serialize(); + } + catch (Exception e) + { + LogMessage(EventLogLevel.Error, $"An exception was encountered while serializing the {nameof(StorageHelper)} : {e}"); + LogMessage(EventLogLevel.Error, $"No data found in the store, clearing the cache in memory."); + + // The cache is corrupt clear it out + CacheStore.Clear(); + throw; + } + + if (data != null) + { + LogMessage(EventLogLevel.Verbose, $"Serializing '{data.Length}' bytes"); + + try + { + CacheStore.WriteData(data); + } + catch (Exception) + { + LogMessage(EventLogLevel.Error, $"Could not write the keystore. Ignoring. See previous error message."); + } + } + } + } + #endregion + + private static CrossPlatLock CreateCrossPlatLock(StorageCreationProperties storageCreationProperties) + { + return new CrossPlatLock( + storageCreationProperties.CacheFilePath + ".lockfile", + storageCreationProperties.LockRetryDelay, + storageCreationProperties.LockRetryCount); + } + + public void VerifyPersistence() + { + CacheStore.VerifyPersistence(); + } + + //Logs to TraceSourceLogger and Identity Logger acquired from MSAL's TokenCacheNotificationArgs + private void LogMessage(EventLogLevel level, string message) + { + LogMessage(level, message, _logger); + } + + //Logs to TraceSourceLogger and Identity Logger acquired from MSAL's TokenCacheNotificationArgs + private static void LogMessage(EventLogLevel level, string message, TraceSourceLogger traceSourceLogger) + { + message = $"[{KeyChainServiceName}] {message}"; + + //Log to TraceSourceLogger + switch (level) + { + case EventLogLevel.Warning: + traceSourceLogger.LogWarning(message); + break; + case EventLogLevel.Error: + traceSourceLogger.LogError(message); + break; + case EventLogLevel.Verbose: + traceSourceLogger.LogInformation(message); + break; + } + } + } +} \ No newline at end of file diff --git a/src/Accounts/Authentication/KeyStore/StorageWrapper.cs b/src/Accounts/Authentication/KeyStore/StorageWrapper.cs index f5c01f42e3ce..aaa15a7600d2 100644 --- a/src/Accounts/Authentication/KeyStore/StorageWrapper.cs +++ b/src/Accounts/Authentication/KeyStore/StorageWrapper.cs @@ -11,34 +11,19 @@ // See the License for the specific language governing permissions and // limitations under the License. // ---------------------------------------------------------------------------------- -using Microsoft.Azure.Commands.Common.Authentication.Properties; using Microsoft.Identity.Client.Extensions.Msal; -using System; -using System.Collections.Generic; -using System.Threading; +using System.Diagnostics; namespace Microsoft.Azure.Commands.ResourceManager.Common { class StorageWrapper : IStorage - { - private const string KeyChainServiceName = "Microsoft.Azure.PowerShell"; + { + public StorageCreationProperties StorageCreationProperties { get; set; } - public string FileName { get; set; } - public string Directory { get; set; } - - private Exception _lastError; + public TraceSource LoggerSource { get; set; } private Storage _storage = null; - private bool _protected; - public bool IsProtected - { - get => _protected; - private set => _protected = value; - } - - static ReaderWriterLockSlim storageLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion); - public StorageWrapper() { @@ -46,104 +31,28 @@ public StorageWrapper() public IStorage Create() { - StorageCreationPropertiesBuilder storageProperties = null; - if (!storageLock.TryEnterWriteLock(TimeSpan.Zero)) - { - throw new InvalidOperationException(Resources.StorageLockConflicts); - } - try - { - storageProperties = new StorageCreationPropertiesBuilder(FileName, Directory) - .WithMacKeyChain(KeyChainServiceName + ".other_secrets", FileName) - .WithLinuxKeyring(FileName, "default", "AzKeyStoreCache", - new KeyValuePair("AzureClientID", "Microsoft.Developer.Azure.PowerShell"), - new KeyValuePair("Microsoft.Developer.Azure.PowerShell", "1.0.0.0")); - _storage = Storage.Create(storageProperties.Build()); - VerifyPersistence(); - _protected = true; - } - catch (Exception e) - { - _lastError = e; - storageProperties = new StorageCreationPropertiesBuilder(FileName, Directory).WithUnprotectedFile(); - _storage = Storage.Create(storageProperties.Build()); - _protected = false; - } - finally - { - storageLock.ExitWriteLock(); - } + _storage = Storage.Create(StorageCreationProperties, LoggerSource); return this; } public void Clear() { - if (!storageLock.TryEnterWriteLock(TimeSpan.Zero)) - { - throw new InvalidOperationException(Resources.StorageLockConflicts); - } - try - { - _storage.Clear(); - } - finally - { - storageLock.ExitWriteLock(); - } + _storage.Clear(ignoreExceptions: true); } public byte[] ReadData() { - if (!storageLock.TryEnterReadLock(TimeSpan.Zero)) - { - throw new InvalidOperationException(Resources.StorageLockConflicts); - } - try - { - return _storage.ReadData(); - } - finally - { - storageLock.ExitReadLock(); - } + return _storage.ReadData(); } public void VerifyPersistence() { - if (!storageLock.TryEnterWriteLock(TimeSpan.Zero)) - { - throw new InvalidOperationException(Resources.StorageLockConflicts); - } - try - { - _storage.VerifyPersistence(); - } - finally - { - storageLock.ExitWriteLock(); - } + _storage.VerifyPersistence(); } public void WriteData(byte[] data) { - if (!storageLock.TryEnterWriteLock(TimeSpan.Zero)) - { - throw new InvalidOperationException(Resources.StorageLockConflicts); - } - - try - { - _storage.WriteData(data); - } - finally - { - storageLock.ExitWriteLock(); - } - } - - public Exception GetLastError() - { - return _lastError; + _storage.WriteData(data); } } -} +} \ No newline at end of file