Skip to content

Commit

Permalink
Query: Don't add DbParameter if already added
Browse files Browse the repository at this point in the history
Resolves #27427

If a FromSql with DbParameter is reused in multiple parts of query then we need to add the DbParameter only once
  • Loading branch information
smitpatel committed Jun 29, 2022
1 parent 813322f commit 0254eee
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/EFCore.Relational/Storage/Internal/RawRelationalParameter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,19 @@ public override void AddDbParameter(DbCommand command, object? value)
value is DbParameter,
$"{nameof(value)} isn't a DbParameter in {nameof(RawRelationalParameter)}.{nameof(AddDbParameter)}");

if (value is DbParameter dbParameter
&& dbParameter.Direction == ParameterDirection.Input
&& value is ICloneable cloneable)
if (value is DbParameter dbParameter)
{
value = cloneable.Clone();
if (!(AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue27427", out var enabled) && enabled)
&& command.Parameters.Contains(dbParameter.ParameterName))
{
return;
}

if (dbParameter.Direction == ParameterDirection.Input
&& value is ICloneable cloneable)
{
value = cloneable.Clone();
}
}

command.Parameters.Add(value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,25 @@ public virtual async Task FromSqlRaw_with_dbParameter_mixed_in_subquery(bool asy
Assert.Equal(26, actual.Length);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Multiple_occurrences_of_FromSql_with_db_parameter_adds_parameter_only_once(bool async)
{
using var context = CreateContext();
var city = "Seattle";
var fromSqlQuery = context.Customers.FromSqlRaw(
NormalizeDelimitersInRawString(@"SELECT * FROM [Customers] WHERE [City] = {0}"),
CreateDbParameter("city", city));

var query = fromSqlQuery.Intersect(fromSqlQuery);

var actual = async
? await query.ToArrayAsync()
: query.ToArray();

Assert.Single(actual);
}

protected string NormalizeDelimitersInRawString(string sql)
=> Fixture.TestStore.NormalizeDelimitersInRawString(sql);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,24 @@ SELECT 1
WHERE [m].[CustomerID] = [o].[CustomerID])");
}

public override async Task Multiple_occurrences_of_FromSql_with_db_parameter_adds_parameter_only_once(bool async)
{
await base.Multiple_occurrences_of_FromSql_with_db_parameter_adds_parameter_only_once(async);

AssertSql(
@"city='Seattle' (Nullable = false) (Size = 7)
SELECT [m].[CustomerID], [m].[Address], [m].[City], [m].[CompanyName], [m].[ContactName], [m].[ContactTitle], [m].[Country], [m].[Fax], [m].[Phone], [m].[PostalCode], [m].[Region]
FROM (
SELECT * FROM ""Customers"" WHERE ""City"" = @city
) AS [m]
INTERSECT
SELECT [m0].[CustomerID], [m0].[Address], [m0].[City], [m0].[CompanyName], [m0].[ContactName], [m0].[ContactTitle], [m0].[Country], [m0].[Fax], [m0].[Phone], [m0].[PostalCode], [m0].[Region]
FROM (
SELECT * FROM ""Customers"" WHERE ""City"" = @city
) AS [m0]");
}

protected override DbParameter CreateDbParameter(string name, object value)
=> new SqlParameter { ParameterName = name, Value = value };

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Linq;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore.TestUtilities;
using Xunit;

// ReSharper disable InconsistentNaming
namespace Microsoft.EntityFrameworkCore.Query
Expand Down Expand Up @@ -297,5 +300,66 @@ WHERE [c].[SomeNullableDateTime] IS NULL
) AS [t0] ON [p].[Id] = [t0].[ParentId]
WHERE [t0].[SomeOtherNullableDateTime] IS NOT NULL");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Muliple_occurrences_of_FromSql_in_group_by_aggregate(bool async)
{
var contextFactory = await InitializeAsync<Context27427>();
using var context = contextFactory.CreateContext();
var query = context.DemoEntities
.FromSqlRaw("SELECT * FROM DemoEntities WHERE Id = {0}", new SqlParameter { Value = 1 })
.Select(e => e.Id);

var query2 = context.DemoEntities
.Where(e => query.Contains(e.Id))
.GroupBy(e => e.Id)
.Select(g => new { g.Key, Aggregate = g.Count() });

if (async)
{
await query2.ToListAsync();
}
else
{
query2.ToList();
}

AssertSql(
@"p0='1'
SELECT [d].[Id] AS [Key], (
SELECT COUNT(*)
FROM [DemoEntities] AS [d0]
WHERE EXISTS (
SELECT 1
FROM (
SELECT * FROM DemoEntities WHERE Id = @p0
) AS [m0]
WHERE [m0].[Id] = [d0].[Id]) AND ([d].[Id] = [d0].[Id])) AS [Aggregate]
FROM [DemoEntities] AS [d]
WHERE EXISTS (
SELECT 1
FROM (
SELECT * FROM DemoEntities WHERE Id = @p0
) AS [m]
WHERE [m].[Id] = [d].[Id])
GROUP BY [d].[Id]");
}

protected class Context27427 : DbContext
{
public Context27427(DbContextOptions options)
: base(options)
{
}

public DbSet<DemoEntity> DemoEntities { get; set; }
}

protected class DemoEntity
{
public int Id { get; set; }
}
}
}

0 comments on commit 0254eee

Please sign in to comment.