Skip to content

Commit

Permalink
CSHARP-4550: Make MemberInitExpression work with struct also as long …
Browse files Browse the repository at this point in the history
…as suitable constructor exists.
  • Loading branch information
rstam committed May 24, 2023
1 parent 1ef7dd3 commit 3720217
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,41 @@ internal static class MemberInitExpressionToAggregationExpressionTranslator
public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression)
{
var newExpression = expression.NewExpression;
var constructorInfo = newExpression.Constructor;
var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct
var constructorArguments = newExpression.Arguments;
var computedFields = new List<AstComputedField>();

var classMap = CreateClassMap(expression.Type, constructorInfo, out var creatorMap);
var creatorMapParameters = creatorMap.Arguments?.ToArray();
if (constructorInfo.GetParameters().Length > 0 && creatorMapParameters == null )
if (constructorInfo != null && creatorMap != null)
{
throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters.");
}
var creatorMapParameters = creatorMap.Arguments?.ToArray();
if (constructorInfo.GetParameters().Length > 0 && creatorMapParameters == null)
{
throw new ExpressionNotSupportedException(expression, because: $"couldn't find matching properties for constructor parameters.");
}

var computedFields = new List<AstComputedField>();
for (var i = 0; i < creatorMapParameters.Length; i++)
{
var creatorMapParameter = creatorMapParameters[i];
var constructorArgumentExpression = constructorArguments[i];
var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression);
var constructorArgumentType = constructorArgumentExpression.Type;
var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType);
var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter);
memberMap.SetSerializer(constructorArgumentSerializer);
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast));
for (var i = 0; i < creatorMapParameters.Length; i++)
{
var creatorMapParameter = creatorMapParameters[i];
var constructorArgumentExpression = constructorArguments[i];
var constructorArgumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, constructorArgumentExpression);
var constructorArgumentType = constructorArgumentExpression.Type;
var constructorArgumentSerializer = constructorArgumentTranslation.Serializer ?? BsonSerializer.LookupSerializer(constructorArgumentType);
var memberMap = EnsureMemberMap(expression, classMap, creatorMapParameter);
memberMap.SetSerializer(constructorArgumentSerializer);
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, constructorArgumentTranslation.Ast));
}
}

foreach (var binding in expression.Bindings)
{
var memberAssignment = (MemberAssignment)binding;
var member = memberAssignment.Member;
var memberMap = FindMemberMap(expression, classMap, member.Name);

var valueExpression = memberAssignment.Expression;
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast));

memberMap.SetSerializer(valueTranslation.Serializer);
computedFields.Add(AstExpression.ComputedField(memberMap.ElementName, valueTranslation.Ast));
}

var ast = AstExpression.ComputedDocument(computedFields);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,51 @@

using System;
using FluentAssertions;
using MongoDB.Bson.Serialization.Attributes;
using MongoDB.Driver.Linq;
using Xunit;

namespace MongoDB.Driver.Tests.Linq.Linq3ImplementationTests.Translators.ExpressionToAggregationExpressionTranslators
{
public class MemberInitExpressionToAggregationExpressionTranslatorTests : Linq3IntegrationTest
{
private readonly IMongoCollection<MyData> _collection;

public MemberInitExpressionToAggregationExpressionTranslatorTests()
[Fact]
public void Should_project_class_via_parameterless_constructor()
{
_collection = CreateCollection(LinqProvider.V3);
var collection = CreateCollection();

var queryable = collection.AsQueryable()
.Select(x => new SpawnDataClassParameterless
{
Identifier = x.Id,
SpawnDate = x.Date,
SpawnText = x.Text
});

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");

var results = queryable.Single();

results.SpawnDate.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
results.SpawnText.Should().Be("data text");
results.Identifier.Should().Be(1);
}

[Fact]
public void Should_project_via_parameterless_constructor()
public void Should_project_struct_via_parameterless_constructor()
{
var queryable = _collection.AsQueryable()
.Select(x => new SpawnDataParameterless
var collection = CreateCollection();

var queryable = collection.AsQueryable()
.Select(x => new SpawnDataStructParameterless
{
Identifier = x.Id,
SpawnDate = x.Date,
SpawnText = x.Text
});

var stages = Translate(_collection, queryable);
var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");

var results = queryable.Single();
Expand All @@ -51,15 +70,38 @@ public void Should_project_via_parameterless_constructor()
}

[Fact]
public void Should_project_via_constructor()
public void Should_project_class_via_constructor()
{
var queryable = _collection.AsQueryable()
.Select(x => new SpawnData(x.Id, x.Date)
var collection = CreateCollection();

var queryable = collection.AsQueryable()
.Select(x => new SpawnDataClass(x.Id, x.Date)
{
SpawnText = x.Text
});

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");

var results = queryable.Single();

results.SpawnDate.Should().Be(new DateTime(2023, 1, 2, 3, 4, 5, DateTimeKind.Utc));
results.SpawnText.Should().Be("data text");
results.Identifier.Should().Be(1);
}

[Fact]
public void Should_project_struct_via_constructor()
{
var collection = CreateCollection();

var queryable = collection.AsQueryable()
.Select(x => new SpawnDataStruct(x.Id, x.Date)
{
SpawnText = x.Text
});

var stages = Translate(_collection, queryable);
var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");

var results = queryable.Single();
Expand All @@ -72,13 +114,15 @@ public void Should_project_via_constructor()
[Fact]
public void Should_project_via_constructor_with_inheritance()
{
var queryable = _collection.AsQueryable()
var collection = CreateCollection();

var queryable = collection.AsQueryable()
.Select(x => new InheritedSpawnData(x.Id, x.Date)
{
SpawnText = x.Text
});

var stages = Translate(_collection, queryable);
var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { Identifier : '$_id', SpawnDate : '$Date', SpawnText : '$Text', _id : 0 } }");

var results = queryable.Single();
Expand All @@ -88,9 +132,9 @@ public void Should_project_via_constructor_with_inheritance()
results.Identifier.Should().Be(1);
}

private IMongoCollection<MyData> CreateCollection(LinqProvider linqProvider)
private IMongoCollection<MyData> CreateCollection()
{
var collection = GetCollection<MyData>("data", linqProvider);
var collection = GetCollection<MyData>("data");

CreateCollection(
collection,
Expand All @@ -106,23 +150,70 @@ public class MyData
public string Text;
}

public class SpawnDataParameterless
public class SpawnDataClassParameterless
{
public int Identifier;
public DateTime SpawnDate;
public string SpawnText;
}

public class SpawnData
public struct SpawnDataStructParameterless
{
public int Identifier;
public DateTime SpawnDate;
public string SpawnText;

// this constructor is required to be able to deserialize instances of this struct
[BsonConstructor]
public SpawnDataStructParameterless(int identifier, DateTime spawnDate, string spawnText)
{
Identifier = identifier;
SpawnDate = spawnDate;
SpawnText = spawnText;
}
}

public class SpawnDataClass
{
public readonly int Identifier;
public DateTime SpawnDate;
private string spawnText;

public SpawnDataClass(int identifier, DateTime spawnDate)
{
Identifier = identifier;
SpawnDate = spawnDate;
}

public string SpawnText
{
get => spawnText;
set => spawnText = value;
}
}

public struct SpawnDataStruct
{
[BsonElement]
public readonly int Identifier;
public DateTime SpawnDate;
private string spawnText;

public SpawnData(int identifier, DateTime spawnDate)
// this constructor is required for the test to compile
public SpawnDataStruct(int identifier, DateTime spawnDate)
{
Identifier = identifier;
SpawnDate = spawnDate;
spawnText = default;
}

// this constructor is required to be able to deserialize instances of this struct
[BsonConstructor]
public SpawnDataStruct(int identifier, DateTime spawnDate, string spawnText)
{
Identifier = identifier;
SpawnDate = spawnDate;
this.spawnText = spawnText;
}

public string SpawnText
Expand All @@ -132,7 +223,7 @@ public string SpawnText
}
}

public class InheritedSpawnData : SpawnData
public class InheritedSpawnData : SpawnDataClass
{
public InheritedSpawnData(int identifier, DateTime spawnDate)
: base(identifier, spawnDate)
Expand Down

0 comments on commit 3720217

Please sign in to comment.