From 8e1af6771f4d771d02c0e38a93cf0bc31935817f Mon Sep 17 00:00:00 2001 From: maca88 Date: Wed, 14 Sep 2022 10:41:52 +0200 Subject: [PATCH] Add option for specifying extension methods from referenced assemblies --- ...ojectAsyncExtensionMethodsConfiguration.cs | 15 ++--- .../Extensions/SymbolExtensions.cs | 15 +++-- .../FileConfiguration/FileConfiguration.cs | 19 ++++++ .../FileConfiguration/FileConfigurator.cs | 5 ++ .../Plugins/AsyncExtensionMethodsFinder.cs | 61 +++++++++++++++---- .../AsyncGenerator.TestCases/IFileReader.cs | 22 +++++++ .../AsyncMethodFinder/Fixture.cs | 30 ++++++--- .../Input/ExternalExtensionMethods.cs | 12 ++++ .../Output/ExternalExtensionMethods.txt | 23 +++++++ Source/AsyncGenerator/AsyncCodeGenerator.cs | 10 ++- ...ojectAsyncExtensionMethodsConfiguration.cs | 24 +++++++- 11 files changed, 201 insertions(+), 35 deletions(-) create mode 100644 Source/AsyncGenerator.TestCases/IFileReader.cs create mode 100644 Source/AsyncGenerator.Tests/AsyncMethodFinder/Input/ExternalExtensionMethods.cs create mode 100644 Source/AsyncGenerator.Tests/AsyncMethodFinder/Output/ExternalExtensionMethods.txt diff --git a/Source/AsyncGenerator.Core/Configuration/IFluentProjectAsyncExtensionMethodsConfiguration.cs b/Source/AsyncGenerator.Core/Configuration/IFluentProjectAsyncExtensionMethodsConfiguration.cs index 788f5917..d49a51fb 100644 --- a/Source/AsyncGenerator.Core/Configuration/IFluentProjectAsyncExtensionMethodsConfiguration.cs +++ b/Source/AsyncGenerator.Core/Configuration/IFluentProjectAsyncExtensionMethodsConfiguration.cs @@ -13,16 +13,13 @@ public interface IFluentProjectAsyncExtensionMethodsConfiguration /// /// Name of the project where async extension methods are located /// Name of the file which contains the async extension methods - /// IFluentProjectAsyncExtensionMethodsConfiguration ProjectFile(string projectName, string fileName); - // TODO - ///// - ///// Add an external type that contains async extension methods - ///// - ///// Name of the assembly where async extension methods are located - ///// Full name of the type which contains the async extension methods - ///// - //IFluentProjectExtensionMethodsConfiguration ExternalType(string assemblyName, string type); + /// + /// Add an external type that contains async extension methods + /// + /// Name of the assembly where async extension methods are located + /// Full name of the type which contains the async extension methods + IFluentProjectAsyncExtensionMethodsConfiguration ExternalType(string assemblyName, string fullTypeName); } } diff --git a/Source/AsyncGenerator.Core/Extensions/SymbolExtensions.cs b/Source/AsyncGenerator.Core/Extensions/SymbolExtensions.cs index 9e432278..e660bd65 100644 --- a/Source/AsyncGenerator.Core/Extensions/SymbolExtensions.cs +++ b/Source/AsyncGenerator.Core/Extensions/SymbolExtensions.cs @@ -43,16 +43,23 @@ public static bool IsAsyncCounterpart(this IMethodSymbol syncMethod, ITypeSymbol syncMethod = syncMethod.ReducedFrom ?? syncMethod; } + var candidateIndexOffset = 0; + if (candidateAsyncMethod.IsExtensionMethod && !syncMethod.IsExtensionMethod) + { + candidateIndexOffset = 1; + } + if (syncMethod.OverriddenMethod != null && candidateAsyncMethod.EqualTo(syncMethod.OverriddenMethod)) { return false; } + var candidateParameterLength = candidateAsyncMethod.Parameters.Length - candidateIndexOffset; // Check if the length of the parameters matches - if (syncMethod.Parameters.Length != candidateAsyncMethod.Parameters.Length) + if (syncMethod.Parameters.Length != candidateParameterLength) { // For symplicity, we suppose that the sync method does not have a cancellation token as a parameter - if (!hasCancellationToken || syncMethod.Parameters.Length + 1 != candidateAsyncMethod.Parameters.Length) + if (!hasCancellationToken || syncMethod.Parameters.Length + 1 != candidateParameterLength) { return false; } @@ -93,7 +100,7 @@ public static bool IsAsyncCounterpart(this IMethodSymbol syncMethod, ITypeSymbol for (var i = 0; i < syncMethod.Parameters.Length; i++) { var param = syncMethod.Parameters[i]; - var candidateParam = candidateAsyncMethod.Parameters[i]; + var candidateParam = candidateAsyncMethod.Parameters[i + candidateIndexOffset]; if (param.IsOptional != candidateParam.IsOptional || param.IsParams != candidateParam.IsParams || param.RefKind != candidateParam.RefKind) @@ -152,7 +159,7 @@ public static bool IsAsyncCounterpart(this IMethodSymbol syncMethod, ITypeSymbol } result = true; } - if (syncMethod.Parameters.Length >= candidateAsyncMethod.Parameters.Length) + if (syncMethod.Parameters.Length >= candidateParameterLength) { return result; } diff --git a/Source/AsyncGenerator.Core/FileConfiguration/FileConfiguration.cs b/Source/AsyncGenerator.Core/FileConfiguration/FileConfiguration.cs index a6a63103..6d486577 100644 --- a/Source/AsyncGenerator.Core/FileConfiguration/FileConfiguration.cs +++ b/Source/AsyncGenerator.Core/FileConfiguration/FileConfiguration.cs @@ -125,9 +125,13 @@ public class AsyncExtensionMethods [XmlArrayItem("ProjectFile", IsNullable = false)] public List ProjectFiles { get; set; } + [XmlArrayItem("AssemblyType", IsNullable = false)] + public List AssemblyTypes { get; set; } + public AsyncExtensionMethods() { ProjectFiles = new List(); + AssemblyTypes = new List(); } } @@ -379,6 +383,21 @@ public class ProjectFile [XmlAttribute(AttributeName = "projectName")] public string ProjectName { get; set; } } + + [Serializable] + [DebuggerStepThrough] + [DesignerCategory("code")] + [XmlType(Namespace = "https://github.com/maca88/AsyncGenerator")] + [XmlRoot("AssemblyType")] + [EditorBrowsable(EditorBrowsableState.Never)] + public class AssemblyType + { + [XmlAttribute(AttributeName = "assemblyName")] + public string AssemblyName { get; set; } + + [XmlAttribute(AttributeName = "fullTypeName")] + public string FullTypeName { get; set; } + } [Serializable] [DebuggerStepThrough] diff --git a/Source/AsyncGenerator.Core/FileConfiguration/FileConfigurator.cs b/Source/AsyncGenerator.Core/FileConfiguration/FileConfigurator.cs index dfc5c125..e405627c 100644 --- a/Source/AsyncGenerator.Core/FileConfiguration/FileConfigurator.cs +++ b/Source/AsyncGenerator.Core/FileConfiguration/FileConfigurator.cs @@ -224,6 +224,11 @@ private static void Configure(AsyncExtensionMethods config, IFluentProjectAsyncE { fluentConfig.ProjectFile(projectFile.ProjectName, projectFile.FileName); } + + foreach (var assemblyType in config.AssemblyTypes) + { + fluentConfig.ExternalType(assemblyType.AssemblyName, assemblyType.FullTypeName); + } } private static void Configure(AsyncGenerator configuration, Diagnostics config, IFluentProjectDiagnosticsConfiguration fluentConfig) diff --git a/Source/AsyncGenerator.Core/Plugins/AsyncExtensionMethodsFinder.cs b/Source/AsyncGenerator.Core/Plugins/AsyncExtensionMethodsFinder.cs index eb0a058b..6668f987 100644 --- a/Source/AsyncGenerator.Core/Plugins/AsyncExtensionMethodsFinder.cs +++ b/Source/AsyncGenerator.Core/Plugins/AsyncExtensionMethodsFinder.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text; using System.Threading.Tasks; using AsyncGenerator.Core.Configuration; using AsyncGenerator.Core.Extensions; @@ -10,6 +9,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.CSharp; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static AsyncGenerator.Core.AsyncCounterpartsSearchOptions; namespace AsyncGenerator.Core.Plugins { @@ -19,23 +19,56 @@ public class AsyncExtensionMethodsFinder : IAsyncCounterpartsFinder, IDocumentTr private ILookup _extensionMethodsLookup; private readonly string _fileName; private readonly string _projectName; + private readonly bool _findByReference; - public AsyncExtensionMethodsFinder(string projectName, string fileName) + public AsyncExtensionMethodsFinder(string projectName, string fileName, bool findByReference) { _projectName = projectName; _fileName = fileName; + _findByReference = findByReference; } public async Task Initialize(Project project, IProjectConfiguration configuration, Compilation compilation) { - var extProject = project.Solution.Projects.First(o => o.Name == _projectName); - var doc = extProject.Documents.First(o => o.Name == _fileName); - var rootNode = await doc.GetSyntaxRootAsync().ConfigureAwait(false); - var semanticModel = await doc.GetSemanticModelAsync().ConfigureAwait(false); - _extensionMethods = new HashSet(rootNode.DescendantNodes() - .OfType() - .Where(o => o.Identifier.ValueText.EndsWith("Async")) - .Select(o => semanticModel.GetDeclaredSymbol(o)), SymbolEqualityComparer.Default); + _extensionMethods = new HashSet(SymbolEqualityComparer.Default); + if (_findByReference) + { + var test = compilation.References + .Select(compilation.GetAssemblyOrModuleSymbol) + .OfType() + .FirstOrDefault(o => o.Name == _projectName); + var type = test?.GetTypeByMetadataName(_fileName); + if (type == null) + { + throw new InvalidOperationException($"Type {_fileName} was not found in assembly {_projectName}"); + } + + foreach (var asyncMethod in type.GetMembers().OfType() + .Where(o => o.Name.EndsWith("Async") && o.IsExtensionMethod)) + { + _extensionMethods.Add(asyncMethod); + } + } + else + { + var extProject = project.Solution.Projects.First(o => o.Name == _projectName); + var docs = extProject.Documents.Where(o => o.Name == _fileName); + foreach (var doc in docs) + { + var rootNode = await doc.GetSyntaxRootAsync().ConfigureAwait(false); + var semanticModel = await doc.GetSemanticModelAsync().ConfigureAwait(false); + var asyncMethods = rootNode.DescendantNodes() + .OfType() + .Where(o => o.Identifier.ValueText.EndsWith("Async")) + .Select(o => semanticModel.GetDeclaredSymbol(o)) + .Where(o => o?.IsExtensionMethod == true); + foreach (var asyncMethod in asyncMethods) + { + _extensionMethods.Add(asyncMethod); + } + } + } + _extensionMethodsLookup = _extensionMethods.ToLookup(o => o.Name); } @@ -77,10 +110,16 @@ public IEnumerable FindAsyncCounterparts(IMethodSymbol symbol, IT var asyncName = symbol.GetAsyncName(); foreach (var asyncCandidate in _extensionMethodsLookup[asyncName]) { - if (!symbol.IsAsyncCounterpart(invokedFromType, asyncCandidate, true, true, false)) + if (!symbol.IsAsyncCounterpart( + invokedFromType, + asyncCandidate, + true, + options.HasFlag(HasCancellationToken), + options.HasFlag(IgnoreReturnType))) { continue; } + yield return asyncCandidate; yield break; } diff --git a/Source/AsyncGenerator.TestCases/IFileReader.cs b/Source/AsyncGenerator.TestCases/IFileReader.cs new file mode 100644 index 00000000..29d1e8e2 --- /dev/null +++ b/Source/AsyncGenerator.TestCases/IFileReader.cs @@ -0,0 +1,22 @@ +using System.Threading.Tasks; + +namespace AsyncGenerator.TestCases +{ + public class FileResult + { + + } + + public interface IFileReader + { + FileResult Read(string path); + } + + public static class FileReaderExtensions + { + public static Task ReadAsync(this IFileReader reader, string path) + { + return Task.FromResult(new FileResult()); + } + } +} \ No newline at end of file diff --git a/Source/AsyncGenerator.Tests/AsyncMethodFinder/Fixture.cs b/Source/AsyncGenerator.Tests/AsyncMethodFinder/Fixture.cs index a515d62e..ba2cf8ba 100644 --- a/Source/AsyncGenerator.Tests/AsyncMethodFinder/Fixture.cs +++ b/Source/AsyncGenerator.Tests/AsyncMethodFinder/Fixture.cs @@ -1,11 +1,5 @@ -using System; -using System.Linq; -using System.Threading.Tasks; -using AsyncGenerator.Analyzation; +using System.Threading.Tasks; using AsyncGenerator.Core; -using AsyncGenerator.Core.Plugins; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; using NUnit.Framework; using AsyncGenerator.Tests.AsyncMethodFinder.Input; @@ -161,5 +155,27 @@ public Task TestExtensionMethodsAfterTransformation() ) ); } + + [Test] + public Task TestExternalExtensionMethodsAfterTransformation() + { + return ReadonlyTest(nameof(ExternalExtensionMethods), p => p + .ConfigureAnalyzation(a => a + .MethodConversion(symbol => MethodConversion.Smart) + .AsyncExtensionMethods(o => o + .ExternalType("AsyncGenerator.TestCases", "AsyncGenerator.TestCases.FileReaderExtensions")) + ) + .ConfigureTransformation(t => t + .AfterTransformation(result => + { + AssertValidAnnotations(result); + Assert.AreEqual(1, result.Documents.Count); + var document = result.Documents[0]; + Assert.NotNull(document.OriginalModified); + Assert.AreEqual(GetOutputFile(nameof(ExternalExtensionMethods)), document.Transformed.ToFullString()); + }) + ) + ); + } } } diff --git a/Source/AsyncGenerator.Tests/AsyncMethodFinder/Input/ExternalExtensionMethods.cs b/Source/AsyncGenerator.Tests/AsyncMethodFinder/Input/ExternalExtensionMethods.cs new file mode 100644 index 00000000..dc5b8f48 --- /dev/null +++ b/Source/AsyncGenerator.Tests/AsyncMethodFinder/Input/ExternalExtensionMethods.cs @@ -0,0 +1,12 @@ +using AsyncGenerator.TestCases; + +namespace AsyncGenerator.Tests.AsyncMethodFinder.Input +{ + public class ExternalExtensionMethods + { + public void External(IFileReader reader) + { + reader.Read("test"); + } + } +} \ No newline at end of file diff --git a/Source/AsyncGenerator.Tests/AsyncMethodFinder/Output/ExternalExtensionMethods.txt b/Source/AsyncGenerator.Tests/AsyncMethodFinder/Output/ExternalExtensionMethods.txt new file mode 100644 index 00000000..dc54a95a --- /dev/null +++ b/Source/AsyncGenerator.Tests/AsyncMethodFinder/Output/ExternalExtensionMethods.txt @@ -0,0 +1,23 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using AsyncGenerator.TestCases; + +namespace AsyncGenerator.Tests.AsyncMethodFinder.Input +{ + using System.Threading.Tasks; + public partial class ExternalExtensionMethods + { + public Task ExternalAsync(IFileReader reader) + { + return reader.ReadAsync("test"); + } + } +} \ No newline at end of file diff --git a/Source/AsyncGenerator/AsyncCodeGenerator.cs b/Source/AsyncGenerator/AsyncCodeGenerator.cs index af673ef7..85ca0af8 100644 --- a/Source/AsyncGenerator/AsyncCodeGenerator.cs +++ b/Source/AsyncGenerator/AsyncCodeGenerator.cs @@ -123,7 +123,15 @@ internal static async Task GenerateProject(ProjectData projectData, ILoggerFacto { foreach (var fileName in pair.Value) { - RegisterPlugin(projectData.Configuration, new AsyncExtensionMethodsFinder(pair.Key, fileName)); + RegisterPlugin(projectData.Configuration, new AsyncExtensionMethodsFinder(pair.Key, fileName, false)); + } + } + + foreach (var pair in analyzeConfig.AsyncExtensionMethods.AssemblyTypes) + { + foreach (var fullTypeName in pair.Value) + { + RegisterPlugin(projectData.Configuration, new AsyncExtensionMethodsFinder(pair.Key, fullTypeName, true)); } } diff --git a/Source/AsyncGenerator/Configuration/Internal/ProjectAsyncExtensionMethodsConfiguration.cs b/Source/AsyncGenerator/Configuration/Internal/ProjectAsyncExtensionMethodsConfiguration.cs index ea5a4e58..e17624d8 100644 --- a/Source/AsyncGenerator/Configuration/Internal/ProjectAsyncExtensionMethodsConfiguration.cs +++ b/Source/AsyncGenerator/Configuration/Internal/ProjectAsyncExtensionMethodsConfiguration.cs @@ -1,8 +1,5 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; using AsyncGenerator.Core.Configuration; namespace AsyncGenerator.Configuration.Internal @@ -10,6 +7,8 @@ namespace AsyncGenerator.Configuration.Internal internal class ProjectAsyncExtensionMethodsConfiguration : IFluentProjectAsyncExtensionMethodsConfiguration { public Dictionary> ProjectFiles { get; } = new Dictionary>(); + + public Dictionary> AssemblyTypes { get; } = new Dictionary>(); IFluentProjectAsyncExtensionMethodsConfiguration IFluentProjectAsyncExtensionMethodsConfiguration.ProjectFile(string projectName, string fileName) { @@ -29,5 +28,24 @@ IFluentProjectAsyncExtensionMethodsConfiguration IFluentProjectAsyncExtensionMet ProjectFiles[projectName].Add(fileName); return this; } + + IFluentProjectAsyncExtensionMethodsConfiguration IFluentProjectAsyncExtensionMethodsConfiguration.ExternalType(string assemblyName, string fullTypeName) + { + if (assemblyName == null) + { + throw new ArgumentNullException(nameof(assemblyName)); + } + if (fullTypeName == null) + { + throw new ArgumentNullException(nameof(fullTypeName)); + } + + if (!AssemblyTypes.ContainsKey(assemblyName)) + { + AssemblyTypes.Add(assemblyName, new HashSet()); + } + AssemblyTypes[assemblyName].Add(fullTypeName); + return this; + } } }