Skip to content

Commit

Permalink
CSHARP-4656 Simplify A : "$A" to A : 1 only on find (#1087)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanych-sun authored and rstam committed May 24, 2023
1 parent 5a973c2 commit 1ef7dd3
Show file tree
Hide file tree
Showing 22 changed files with 169 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/MongoDB.Driver/FindFluent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public override string ToString()

if (_options.Projection != null)
{
var renderedProjection = Render(_options.Projection.Render);
var renderedProjection = Render(_options.Projection.RenderForFind);
if (renderedProjection.Document != null)
{
sb.Append(", " + renderedProjection.Document.ToString());
Expand Down
8 changes: 3 additions & 5 deletions src/MongoDB.Driver/IFindFluentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
*/

using System;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using MongoDB.Bson;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver
Expand Down Expand Up @@ -59,7 +57,7 @@ public static IFindFluent<TDocument, TNewProjection> Project<TDocument, TProject
Ensure.IsNotNull(find, nameof(find));
Ensure.IsNotNull(projection, nameof(projection));

return find.Project<TNewProjection>(new FindExpressionProjectionDefinition<TDocument, TNewProjection>(projection));
return find.Project<TNewProjection>(new ExpressionProjectionDefinition<TDocument, TNewProjection>(projection, null));
}

/// <summary>
Expand All @@ -75,7 +73,7 @@ public static IOrderedFindFluent<TDocument, TProjection> SortBy<TDocument, TProj
Ensure.IsNotNull(find, nameof(find));
Ensure.IsNotNull(field, nameof(field));

// We require an implementation of IFindFluent<TDocument, TProjection>
// We require an implementation of IFindFluent<TDocument, TProjection>
// to also implement IOrderedFindFluent<TDocument, TProjection>
return (IOrderedFindFluent<TDocument, TProjection>)find.Sort(
new DirectionalSortDefinition<TDocument>(new ExpressionFieldDefinition<TDocument>(field), SortDirection.Ascending));
Expand All @@ -94,7 +92,7 @@ public static IOrderedFindFluent<TDocument, TProjection> SortByDescending<TDocum
Ensure.IsNotNull(find, nameof(find));
Ensure.IsNotNull(field, nameof(field));

// We require an implementation of IFindFluent<TDocument, TProjection>
// We require an implementation of IFindFluent<TDocument, TProjection>
// to also implement IOrderedFindFluent<TDocument, TProjection>
return (IOrderedFindFluent<TDocument, TProjection>)find.Sort(
new DirectionalSortDefinition<TDocument>(new ExpressionFieldDefinition<TDocument>(field), SortDirection.Descending));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Stages;

namespace MongoDB.Driver.Linq.Linq3Implementation.Ast.Optimizers
{
internal class AstFindProjectionSimplifier : AstSimplifier
{
public override AstNode VisitProjectStageSetFieldSpecification(AstProjectStageSetFieldSpecification node)
{
node = (AstProjectStageSetFieldSpecification)base.VisitProjectStageSetFieldSpecification(node);

// { path : '$path' } => { path : 1 }
if (node.Value is AstFieldPathExpression fieldPathExpression &&
fieldPathExpression.Path == $"${node.Path}")
{
return AstProject.Include(node.Path);
}

return node;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,20 +307,6 @@ static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField)
}
}

public override AstNode VisitProjectStageSetFieldSpecification(AstProjectStageSetFieldSpecification node)
{
node = (AstProjectStageSetFieldSpecification)base.VisitProjectStageSetFieldSpecification(node);

// { path : '$path' } => { path : 1 }
if (node.Value is AstFieldPathExpression fieldPathExpression &&
fieldPathExpression.Path == $"${node.Path}")
{
return AstProject.Include(node.Path);
}

return node;
}

public override AstNode VisitUnaryExpression(AstUnaryExpression node)
{
// { $first : <arg> } => { $arrayElemAt : [<arg>, 0] } (or -1 for $last)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,7 @@ internal override RenderedProjectionDefinition<TProjection> TranslateExpressionT
Expression<Func<TSource, TProjection>> expression,
IBsonSerializer<TSource> sourceSerializer,
IBsonSerializerRegistry serializerRegistry)
{
return TranslateExpressionToProjection(expression, sourceSerializer, serializerRegistry, translationOptions: null);
}
=> TranslateExpressionToProjectionInternal(expression, sourceSerializer, new AstFindProjectionSimplifier());

internal override RenderedProjectionDefinition<TOutput> TranslateExpressionToGroupProjection<TInput, TKey, TOutput>(
Expression<Func<TInput, TKey>> idExpression,
Expand All @@ -158,12 +156,18 @@ internal override RenderedProjectionDefinition<TOutput> TranslateExpressionToPro
IBsonSerializer<TInput> inputSerializer,
IBsonSerializerRegistry serializerRegistry,
ExpressionTranslationOptions translationOptions)
=> TranslateExpressionToProjectionInternal(expression, inputSerializer, new AstSimplifier());

private RenderedProjectionDefinition<TOutput> TranslateExpressionToProjectionInternal<TInput, TOutput>(
Expression<Func<TInput, TOutput>> expression,
IBsonSerializer<TInput> inputSerializer,
AstSimplifier simplifier)
{
expression = (Expression<Func<TInput, TOutput>>)PartialEvaluator.EvaluatePartially(expression);
var context = TranslationContext.Create(expression, inputSerializer);
var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, inputSerializer, asRoot: true);
var (projectStage, projectionSerializer) = ProjectionHelper.CreateProjectStage(translation);
var simplifiedProjectStage = AstSimplifier.Simplify(projectStage);
var simplifiedProjectStage = simplifier.Visit(projectStage);
var renderedProjection = simplifiedProjectStage.Render().AsBsonDocument["$project"].AsBsonDocument;

return new RenderedProjectionDefinition<TOutput>(renderedProjection, (IBsonSerializer<TOutput>)projectionSerializer);
Expand Down
2 changes: 1 addition & 1 deletion src/MongoDB.Driver/MongoCollectionImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ private FindOneAndUpdateOperation<TProjection> CreateFindOneAndUpdateOperation<T
private FindOperation<TProjection> CreateFindOperation<TProjection>(FilterDefinition<TDocument> filter, FindOptions<TDocument, TProjection> options)
{
var projection = options.Projection ?? new ClientSideDeserializationProjectionDefinition<TDocument, TProjection>();
var renderedProjection = projection.Render(_documentSerializer, _settings.SerializerRegistry, _linqProvider);
var renderedProjection = projection.RenderForFind(_documentSerializer, _settings.SerializerRegistry, _linqProvider);

return new FindOperation<TProjection>(
_collectionNamespace,
Expand Down
21 changes: 18 additions & 3 deletions src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ public static PipelineStageDefinition<TInput, TOutput> Project<TInput, TOutput>(
ExpressionTranslationOptions translationOptions = null)
{
Ensure.IsNotNull(projection, nameof(projection));
return Project(new ProjectExpressionProjection<TInput, TOutput>(projection, translationOptions));
return Project(new ExpressionProjectionDefinition<TInput, TOutput>(projection, translationOptions));
}

/// <summary>
Expand Down Expand Up @@ -1905,6 +1905,11 @@ public override RenderedProjectionDefinition<TOutput> Render(IBsonSerializer<TIn

return linqProvider.GetAdapter().TranslateExpressionToBucketOutputProjection(_valueExpression, _outputExpression, documentSerializer, serializerRegistry, _translationOptions);
}

internal override RenderedProjectionDefinition<TOutput> RenderForFind(IBsonSerializer<TInput> sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider)
{
throw new InvalidOperationException();
}
}

internal sealed class GroupExpressionProjection<TInput, TKey, TOutput> : ProjectionDefinition<TInput, TOutput>
Expand Down Expand Up @@ -1938,14 +1943,19 @@ public override RenderedProjectionDefinition<TOutput> Render(IBsonSerializer<TIn
}
return linqProvider.GetAdapter().TranslateExpressionToGroupProjection(_idExpression, _groupExpression, documentSerializer, serializerRegistry, _translationOptions);
}

internal override RenderedProjectionDefinition<TOutput> RenderForFind(IBsonSerializer<TInput> sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider)
{
throw new InvalidOperationException();
}
}

internal sealed class ProjectExpressionProjection<TInput, TOutput> : ProjectionDefinition<TInput, TOutput>
internal sealed class ExpressionProjectionDefinition<TInput, TOutput> : ProjectionDefinition<TInput, TOutput>
{
private readonly Expression<Func<TInput, TOutput>> _expression;
private readonly ExpressionTranslationOptions _translationOptions;

public ProjectExpressionProjection(Expression<Func<TInput, TOutput>> expression, ExpressionTranslationOptions translationOptions)
public ExpressionProjectionDefinition(Expression<Func<TInput, TOutput>> expression, ExpressionTranslationOptions translationOptions)
{
_expression = Ensure.IsNotNull(expression, nameof(expression));
_translationOptions = translationOptions; // can be null
Expand All @@ -1960,6 +1970,11 @@ public override RenderedProjectionDefinition<TOutput> Render(IBsonSerializer<TIn
{
return linqProvider.GetAdapter().TranslateExpressionToProjection(_expression, inputSerializer, serializerRegistry, _translationOptions);
}

internal override RenderedProjectionDefinition<TOutput> RenderForFind(IBsonSerializer<TInput> sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider)
{
return linqProvider.GetAdapter().TranslateExpressionToFindProjection(_expression, sourceSerializer, serializerRegistry);
}
}

internal class SortPipelineStageDefinition<TInput> : PipelineStageDefinition<TInput, TInput>
Expand Down
3 changes: 3 additions & 0 deletions src/MongoDB.Driver/ProjectionDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ public virtual RenderedProjectionDefinition<TProjection> Render(IBsonSerializer<
/// <returns>A <see cref="RenderedProjectionDefinition{TProjection}"/>.</returns>
public abstract RenderedProjectionDefinition<TProjection> Render(IBsonSerializer<TSource> sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider);

internal virtual RenderedProjectionDefinition<TProjection> RenderForFind(IBsonSerializer<TSource> sourceSerializer, IBsonSerializerRegistry serializerRegistry, LinqProvider linqProvider)
=> Render(sourceSerializer, serializerRegistry, linqProvider);

/// <summary>
/// Performs an implicit conversion from <see cref="BsonDocument"/> to <see cref="ProjectionDefinition{TSource, TProjection}"/>.
/// </summary>
Expand Down
2 changes: 1 addition & 1 deletion src/MongoDB.Driver/ProjectionDefinitionBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public ProjectionDefinition<TSource> Exclude(Expression<Func<TSource, object>> f
/// </returns>
public ProjectionDefinition<TSource, TProjection> Expression<TProjection>(Expression<Func<TSource, TProjection>> expression)
{
return new FindExpressionProjectionDefinition<TSource, TProjection>(expression);
return new ExpressionProjectionDefinition<TSource, TProjection>(expression, null);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Linq.Expressions;
using FluentAssertions;
using MongoDB.Bson.Serialization;
using Xunit;

namespace MongoDB.Driver.Tests
{
public class FindExpressionProjectionDefinitionTests
{
[Fact]
public void Projection_to_class_should_work()
=> AssertProjection(
x => new Projection { A = x.A, X = x.B },
"{ A : 1, X : '$B', _id : 0 }");

[Fact]
public void Projection_to_anonymous_type_should_work()
=> AssertProjection(
x => new { x.A, X = x.B },
"{ A : 1, X : '$B', _id : 0 }");

private void AssertProjection<TProjection>(
Expression<Func<Document, TProjection>> expression,
string expectedProjection)
{
var projection = new FindExpressionProjectionDefinition<Document, TProjection>(expression);

var renderedProjection = projection.Render(
BsonSerializer.LookupSerializer<Document>(),
BsonSerializer.SerializerRegistry);

renderedProjection.Document.Should().BeEquivalentTo(expectedProjection);
}

private class Document
{
public string A { get; set; }

public int B { get; set; }
}

private class Projection
{
public string A { get; set; }

public int X { get; set; }
}
}
}
2 changes: 1 addition & 1 deletion tests/MongoDB.Driver.Tests/IFindFluentExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ public void SortByDescending_ThenByDescending_should_generate_the_correct_sort()

private static void AssertProjection<TResult>(IFindFluent<Person, TResult> subject, BsonDocument expectedProjection, LinqProvider linqProvider = LinqProvider.V3)
{
Assert.Equal(expectedProjection, subject.Options.Projection.Render(BsonSerializer.SerializerRegistry.GetSerializer<Person>(), BsonSerializer.SerializerRegistry, linqProvider).Document);
Assert.Equal(expectedProjection, subject.Options.Projection.RenderForFind(BsonSerializer.SerializerRegistry.GetSerializer<Person>(), BsonSerializer.SerializerRegistry, linqProvider).Document);
}

private static void AssertSort(IFindFluent<Person, Person> subject, BsonDocument expectedSort)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ public void Distinct_document_preceded_by_select_where()

Assert(query,
1,
"{ $project: { 'A': 1, 'B': 1, '_id': 0 } }",
"{ $match: { 'A': 'Awesome' } }",
"{ $group: { '_id': '$$ROOT' } }",
"{ $project: { 'A' : '$A', 'B' : '$B', '_id': 0 } }",
"{ $match: { 'A' : 'Awesome' } }",
"{ $group: { '_id' : '$$ROOT' } }",
"{ $replaceRoot : { newRoot : '$_id' } }");
}

Expand All @@ -233,7 +233,7 @@ public void Distinct_document_preceded_by_where_select()
Assert(query,
1,
"{ $match : { 'A' : 'Awesome' } }",
"{ $project : { A : 1, B : 1, _id : 0 } }",
"{ $project : { A : '$A', B : '$B', _id : 0 } }",
"{ $group : { '_id' : '$$ROOT' } }",
"{ $replaceRoot : { newRoot : '$_id' } }");
}
Expand Down Expand Up @@ -999,7 +999,7 @@ public void Select_new_of_same()

Assert(query,
2,
"{ $project : { _id : 1, A : 1 } }");
"{ $project : { _id : '$_id', A : '$A' } }");
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public void Should_translate_just_id()
AssertStages(
result.Stages,
"{ $group : { _id : '$A' } }",
"{ $project : { _id : 1 } }");
"{ $project : { _id : '$_id' } }");

result.Value._id.Should().Be("Amazing");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void Select_new_Person_should_work()
.Select(p => new Person { Id = p.Id, Name = p.Name });

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _id : 1, Name : 1 } }");
AssertStages(stages, "{ $project : { _id : '$_id', Name : '$Name' } }");

var result = queryable.ToList().Single();
result.ShouldBeEquivalentTo(new Person { Id = 1, Name = "A" });
Expand All @@ -57,7 +57,7 @@ public void Select_new_Person_without_Name_should_work()
.Select(p => new Person { Id = p.Id });

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _id : 1 } }");
AssertStages(stages, "{ $project : { _id : '$_id' } }");

var result = queryable.ToList().Single();
result.ShouldBeEquivalentTo(new Person { Id = 1, Name = null });
Expand All @@ -71,7 +71,7 @@ public void Select_new_Person_without_Id_should_work()
.Select(p => new Person { Name = p.Name });

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { Name : 1, _id : 0 } }");
AssertStages(stages, "{ $project : { Name : '$Name', _id : 0 } }");

var result = queryable.ToList().Single();
result.ShouldBeEquivalentTo(new Person { Id = 0, Name = "A" });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public void Nested_Select_should_work()
}",
@"{
'$project':{
'_id':1,
'_id':'$_id',
'ParentName':'$Name',
'Children':{
'$map':{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void Test()
});

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _id : 1, PageCount : 1, Author : { $cond : { if : { $eq : ['$Author', null] }, then : null, else : { _id : '$Author._id', Name : '$Author.Name' } } } } }");
AssertStages(stages, "{ $project : { _id : '$_id', PageCount : '$PageCount', Author : { $cond : { if : { $eq : ['$Author', null] }, then : null, else : { _id : '$Author._id', Name : '$Author.Name' } } } } }");

var results = queryable.ToList().OrderBy(r => r.Id).ToList();
results.Should().HaveCount(2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void Select_with_constructor_call_should_work()
var stages = Translate(collection, queryable);
AssertStages(
stages,
"{ $project : { X : 1, _id : 0 } }",
"{ $project : { X : '$X', _id : 0 } }",
"{ $project : { R : '$X', S : '$Y', _id : 0 } }",
"{ $project : { T : '$R', U : '$S', _id : 0 } }");
}
Expand All @@ -67,7 +67,7 @@ public void Select_with_constructor_call_and_property_set_should_work()
var stages = Translate(collection, queryable);
AssertStages(
stages,
"{ $project : { X : 1, Y : { $literal : 123 }, _id : 0 } }",
"{ $project : { X : '$X', Y : { $literal : 123 }, _id : 0 } }",
"{ $project : { R : '$X', S : '$Y', _id : 0 } }",
"{ $project : { T : '$R', U : '$S', _id : 0 } }");
}
Expand Down
Loading

0 comments on commit 1ef7dd3

Please sign in to comment.