Skip to content

Commit

Permalink
Construct correct Namespace for Classifier2xxAnd4xxDefinition (#4625)
Browse files Browse the repository at this point in the history
This PR fixes an issue when a client's namespace is customized via the
`[CodeGenClient("ClientName")]` attribute and the private type
`Classifier2xxAnd4xx` within a client is still generated using the
default root namespace leading to compilation issues.

fixes: #4619
  • Loading branch information
jorgerangel-msft authored Oct 7, 2024
1 parent 258b327 commit 2ba6c7a
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ protected override MethodProvider[] BuildMethods()
return [BuildTryClassifyErrorMethod(), BuildTryClassifyRetryMethod()];
}

protected override string GetNamespace() => DeclaringTypeProvider!.Type.Namespace;

protected override CSharpType[] BuildImplements()
{
return [typeof(PipelineMessageClassifier)];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// <auto-generated/>

#nullable disable

using System;
using System.ClientModel.Primitives;

namespace Sample.Custom
{
/// <summary></summary>
public partial class TestClient
{
private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier200;
private static global::System.ClientModel.Primitives.PipelineMessageClassifier _pipelineMessageClassifier204;
private static global::Sample.Custom.TestClient.Classifier2xxAnd4xx _pipelineMessageClassifier2xxAnd4xx;

private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 200 });

private static global::System.ClientModel.Primitives.PipelineMessageClassifier PipelineMessageClassifier204 => _pipelineMessageClassifier204 = global::System.ClientModel.Primitives.PipelineMessageClassifier.Create(stackalloc ushort[] { 204 });

private static global::Sample.Custom.TestClient.Classifier2xxAnd4xx PipelineMessageClassifier2xxAnd4xx => _pipelineMessageClassifier2xxAnd4xx ??= new global::Sample.Custom.TestClient.Classifier2xxAnd4xx();

private class Classifier2xxAnd4xx : global::System.ClientModel.Primitives.PipelineMessageClassifier
{
public override bool TryClassify(global::System.ClientModel.Primitives.PipelineMessage message, out bool isError)
{
isError = false;
if ((message.Response == null))
{
return false;
}
isError = message.Response.Status switch
{
((>= 200) and (< 300)) => false,
((>= 400) and (< 500)) => false,
_ => true
};
return true;
}

public override bool TryClassify(global::System.ClientModel.Primitives.PipelineMessage message, global::System.Exception exception, out bool isRetryable)
{
isRetryable = false;
return false;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#nullable disable

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading.Tasks;
using Microsoft.Generator.CSharp.Customization;

namespace Sample.Custom;

[CodeGenClient("TestClient")]
public partial class TestClient { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Generator.CSharp.ClientModel.Providers;
using Microsoft.Generator.CSharp.Tests.Common;
using NUnit.Framework;

namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.Definitions
{
public class Classifier2xxAnd4xxDefinitionTests
{
[TestCaseSource(nameof(GetTypeNamespaceTestCases))]
public void TestGetTypeNamespace(string mockJson)
{
MockHelpers.LoadMockPlugin(configuration: mockJson);
var inputClient = InputFactory.Client("TestClient");
var restClientProvider = new ClientProvider(inputClient).RestClient;
Assert.IsNotNull(restClientProvider);

var classifier2xxAnd4xxDefinition = new Classifier2xxAnd4xxDefinition(restClientProvider);
var result = classifier2xxAnd4xxDefinition.Type.Namespace;

Assert.AreEqual(restClientProvider.Type.Namespace, result);
}

[Test]
public async Task TestGetTypeCustomNamespace()
{
var inputClient = InputFactory.Client("TestClient");
var plugin = await MockHelpers.LoadMockPluginAsync(
clients: () => [inputClient],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

// Find the rest client provider
var clientProvider = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientProvider);
Assert.IsNotNull(clientProvider);
var restClientProvider = (clientProvider as ClientProvider)!.RestClient;
Assert.IsNotNull(restClientProvider);

var classifier2xxAnd4xxDefinition = new Classifier2xxAnd4xxDefinition(restClientProvider!);
var result = classifier2xxAnd4xxDefinition.Type.Namespace;

Assert.AreEqual(restClientProvider!.Type.Namespace, result);
}

public static IEnumerable<TestCaseData> GetTypeNamespaceTestCases
{
get
{
yield return new TestCaseData(@"{
""output-folder"": ""outputFolder"",
""library-name"": ""libraryName"",
""namespace"": ""testNamespace""
}");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#nullable disable

using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading.Tasks;
using Microsoft.Generator.CSharp.Customization;

namespace Sample.Custom;

[CodeGenClient("TestClient")]
public partial class TestClient { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.Linq;
using System.Threading.Tasks;
using Microsoft.Generator.CSharp.ClientModel.Providers;
using Microsoft.Generator.CSharp.Primitives;
using Microsoft.Generator.CSharp.Tests.Common;
using NUnit.Framework;

namespace Microsoft.Generator.CSharp.ClientModel.Tests.Providers.ClientProviders
{
public class RestClientProviderCustomizationTests
{
// Validates the client is generated using the correct namespace
[Test]
public async Task CanChangeClientNamespace()
{
var inputClient = InputFactory.Client("TestClient");
var plugin = await MockHelpers.LoadMockPluginAsync(
clients: () => [inputClient],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var clientProvider = plugin.Object.OutputLibrary.TypeProviders.SingleOrDefault(t => t is ClientProvider);
Assert.IsNotNull(clientProvider);
var restClientProvider = (clientProvider as ClientProvider)!.RestClient;
Assert.IsNotNull(restClientProvider);

var writer = new TypeProviderWriter(restClientProvider);
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ public static async Task<Mock<ClientModelPlugin>> LoadMockPluginAsync(
Func<IReadOnlyList<InputEnumType>>? inputEnums = null,
Func<IReadOnlyList<InputModelType>>? inputModels = null,
Func<IReadOnlyList<InputClient>>? clients = null,
Func<Task<Compilation>>? compilation = null)
Func<Task<Compilation>>? compilation = null,
string? configuration = null)
{
var mockPlugin = LoadMockPlugin(
inputEnums: inputEnums,
inputModels: inputModels,
clients: clients);
clients: clients,
configuration: configuration);

var compilationResult = compilation == null ? null : await compilation();

Expand All @@ -53,7 +55,8 @@ public static Mock<ClientModelPlugin> LoadMockPlugin(
Func<IReadOnlyList<InputModelType>>? inputModels = null,
Func<IReadOnlyList<InputClient>>? clients = null,
Func<InputLibrary>? createInputLibrary = null,
Func<InputClient, ClientProvider>? createClientCore = null)
Func<InputClient, ClientProvider>? createClientCore = null,
string? configuration = null)
{
IReadOnlyList<string> inputNsApiVersions = apiVersions?.Invoke() ?? [];
IReadOnlyList<InputEnumType> inputNsEnums = inputEnums?.Invoke() ?? [];
Expand Down Expand Up @@ -106,7 +109,7 @@ public static Mock<ClientModelPlugin> LoadMockPlugin(
var clientModelInstance = typeof(ClientModelPlugin).GetField("_instance", BindingFlags.Static | BindingFlags.NonPublic);
// invoke the load method with the config file path
var loadMethod = typeof(Configuration).GetMethod("Load", BindingFlags.Static | BindingFlags.NonPublic);
object?[] parameters = [_configFilePath, null];
object?[] parameters = [_configFilePath, configuration];
var config = loadMethod?.Invoke(null, parameters);
var mockGeneratorContext = new Mock<GeneratorContext>(config!);
var mockPluginInstance = new Mock<ClientModelPlugin>(mockGeneratorContext.Object) { CallBase = true };
Expand Down

0 comments on commit 2ba6c7a

Please sign in to comment.