Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Internal;
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Infrastructure.ExpressionExtensions;

Expand Down Expand Up @@ -597,6 +598,33 @@ protected override Expression VisitMember(MemberExpression memberExpression)
{
return memberExpression;
}

if (memberExpression.Expression is NavigationTreeExpression navigationTreeExpression
&& navigationTreeExpression.Value is NewExpression newExpression
&& newExpression.Members != null)
{
for (var i = 0; i < newExpression.Members.Count; i++)
{
if (newExpression.Members[i] == memberExpression.Member)
{
var argument = newExpression.Arguments[i];
var newRoot = Expression.MakeMemberAccess(navigationTreeExpression, newExpression.Members[i]);

if (argument is EntityReference entityReference)
{
return ExpandInclude(newRoot, entityReference);
}

if (argument is NewExpression innerNewExpression
&& ReconstructAnonymousType(newRoot, innerNewExpression, out var replacement))
{
return replacement;
}

return newRoot;
}
}
}
}

return base.VisitMember(memberExpression);
Expand Down Expand Up @@ -1052,6 +1080,88 @@ MethodCallExpression e when e.Method.IsEFPropertyMethod()
}
}

private sealed class NavigationTreeMemberPruningVisitor : ExpressionVisitor
{
private readonly Stack<Expression> _knownFalseTests = new();

protected override Expression VisitConditional(ConditionalExpression node)
{
if (node.IfTrue is ConstantExpression { Value: null })
{
var test = Visit(node.Test);

_knownFalseTests.Push(test);
var ifFalse = Visit(node.IfFalse);
_knownFalseTests.Pop();

return node.Update(test, node.IfTrue, ifFalse);
}

return base.VisitConditional(node);
}

protected override Expression VisitMember(MemberExpression node)
{
var innerExpression = Visit(node.Expression);

if (innerExpression is NewExpression { Members: not null } newExpression)
{
for (var i = 0; i < newExpression.Members.Count; i++)
{
if (newExpression.Members[i] == node.Member)
{
return newExpression.Arguments[i];
}
}
}

if (innerExpression is ConditionalExpression { IfTrue: ConstantExpression { Value: null } } conditional
&& node.Member.DeclaringType!.IsAssignableFrom(conditional.IfFalse.Type))
{
foreach (var knownFalseTest in _knownFalseTests)
{
if (ExpressionEqualityComparer.Instance.Equals(knownFalseTest, conditional.Test))
{
return VisitMember(Expression.MakeMemberAccess(conditional.IfFalse, node.Member));
}
}
}
return node.Update(innerExpression);
}

protected override Expression VisitBinary(BinaryExpression node)
{
if (node.NodeType is ExpressionType.Equal or ExpressionType.NotEqual)
{
var left = Visit(node.Left);
var right = Visit(node.Right);

Expression? conditionalTest = null;
if (left is ConditionalExpression { IfTrue: ConstantExpression { Value: null }, IfFalse: NewExpression } leftCond
&& right is ConstantExpression { Value: null })
{
conditionalTest = leftCond.Test;
}
else if (right is ConditionalExpression { IfTrue: ConstantExpression { Value: null }, IfFalse: NewExpression } rightCond
&& left is ConstantExpression { Value: null })
{
conditionalTest = rightCond.Test;
}

if (conditionalTest != null)
{
return node.NodeType == ExpressionType.Equal
? conditionalTest
: Expression.Not(conditionalTest);
}

return node.Update(left, node.Conversion, right);
}

return base.VisitBinary(node);
}
}

/// <summary>
/// Marks <see cref="EntityReference" /> as nullable when coming from a left join.
/// Nullability is required to figure out if the navigation from this entity should be a left join or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,8 @@ private static NavigationExpansionExpression ProcessSelect(NavigationExpansionEx
source.PendingSelector,
selector.Body);

selectorBody = new NavigationTreeMemberPruningVisitor().Visit(selectorBody);
Comment thread
AndriySvyryd marked this conversation as resolved.

source.ApplySelector(selectorBody);

return source;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,169 @@ public class OtherEntity
}

#endregion

#region ConditionalProjection

[Theory, InlineData(false), InlineData(true)]
public virtual async Task Consecutive_selects_with_conditional_projection_should_not_include_unnecessary_joins(bool async)
{
var contextFactory = await InitializeNonSharedTest<ContextConditionalProjection>(
seed: c => c.SeedAsync());

using var context = contextFactory.CreateDbContext();

var query = context.Users
.Select(x => new
{
x.Id,
Job = x.Job == null ? null : new
{
x.Job.Id,
Address = new
{
x.Job.Address.Id,
x.Job.Address.Street
}
}
})
.Select(x => new
{
x.Id,
Job = x.Job == null ? null : new
{
x.Job.Id
}
})
.Where(x => x.Id == 1);

var result = async ? await query.FirstOrDefaultAsync() : query.FirstOrDefault();

Assert.NotNull(result);
Assert.Equal(1, result.Id);
Assert.NotNull(result.Job);
Assert.Equal(1, result.Job!.Id);
}

[Theory, InlineData(false), InlineData(true)]
public virtual async Task Consecutive_selects_with_conditional_projection_null_navigation_returns_null(bool async)
Comment thread
AndriySvyryd marked this conversation as resolved.
{
var contextFactory = await InitializeNonSharedTest<ContextConditionalProjection>(
seed: c => c.SeedUserWithNoJobAsync());

using var context = contextFactory.CreateDbContext();

var query = context.Users
.Where(x => x.JobId == null)
.Select(x => new
{
x.Id,
Job = x.Job == null ? null : new
{
x.Job.Id,
Address = new
{
x.Job.Address.Id,
x.Job.Address.Street
}
}
})
.Select(x => new
{
x.Id,
Job = x.Job == null ? null : new
{
x.Job.Id
}
});

var result = async ? await query.FirstOrDefaultAsync() : query.FirstOrDefault();

Assert.NotNull(result);
Assert.Null(result.Job);
}

[Theory, InlineData(false), InlineData(true)]
public virtual async Task Consecutive_selects_with_conditional_projection_nested_navigation_accessed_includes_join(bool async)
{
var contextFactory = await InitializeNonSharedTest<ContextConditionalProjection>(
seed: c => c.SeedAsync());

using var context = contextFactory.CreateDbContext();

var query = context.Users
.Select(x => new
{
x.Id,
Job = x.Job == null ? null : new
{
x.Job.Id,
Address = new
{
x.Job.Address.Id,
x.Job.Address.Street
}
}
})
.Select(x => new
{
x.Id,
Job = x.Job == null ? null : new
{
x.Job.Id,
AddressId = x.Job.Address.Id
}
})
.Where(x => x.Id == 1);

var result = async ? await query.FirstOrDefaultAsync() : query.FirstOrDefault();

Assert.NotNull(result);
Assert.Equal(1, result.Id);
Assert.NotNull(result.Job);
Assert.Equal(1, result.Job.AddressId);
}

protected class ContextConditionalProjection(DbContextOptions options) : DbContext(options)
{
public DbSet<User> Users { get; set; }

public async Task SeedAsync()
{
var address = new Address { Street = "123 Main St" };
var job = new Job { Address = address };
var user = new User { Job = job };

Add(user);
await SaveChangesAsync();
}

public async Task SeedUserWithNoJobAsync()
{
var user = new User();
Add(user);
await SaveChangesAsync();
}

public class User
{
public long Id { get; set; }
public long? JobId { get; set; }
public Job Job { get; set; }
}

public class Job
{
public long Id { get; set; }
public long AddressId { get; set; }
public Address Address { get; set; }
}

public class Address
{
public long Id { get; set; }
public string Street { get; set; }
}
}

#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,55 @@ FROM [IdentityDocument] AS [i0]
""");
}

public override async Task Consecutive_selects_with_conditional_projection_should_not_include_unnecessary_joins(bool async)
{
await base.Consecutive_selects_with_conditional_projection_should_not_include_unnecessary_joins(async);

AssertSql(
"""
SELECT TOP(1) [u].[Id], CASE
WHEN [j].[Id] IS NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END, [j].[Id]
FROM [Users] AS [u]
LEFT JOIN [Job] AS [j] ON [u].[JobId] = [j].[Id]
WHERE [u].[Id] = CAST(1 AS bigint)
""");
}

public override async Task Consecutive_selects_with_conditional_projection_null_navigation_returns_null(bool async)
{
await base.Consecutive_selects_with_conditional_projection_null_navigation_returns_null(async);

AssertSql(
"""
SELECT TOP(1) [u].[Id], CASE
WHEN [j].[Id] IS NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END, [j].[Id]
FROM [Users] AS [u]
LEFT JOIN [Job] AS [j] ON [u].[JobId] = [j].[Id]
WHERE [u].[JobId] IS NULL
""");
}

public override async Task Consecutive_selects_with_conditional_projection_nested_navigation_accessed_includes_join(bool async)
{
await base.Consecutive_selects_with_conditional_projection_nested_navigation_accessed_includes_join(async);

AssertSql(
"""
SELECT TOP(1) [u].[Id], CASE
WHEN [j].[Id] IS NULL THEN CAST(1 AS bit)
ELSE CAST(0 AS bit)
END, [j].[Id], [a].[Id]
FROM [Users] AS [u]
LEFT JOIN [Job] AS [j] ON [u].[JobId] = [j].[Id]
LEFT JOIN [Address] AS [a] ON [j].[AddressId] = [a].[Id]
WHERE [u].[Id] = CAST(1 AS bigint)
""");
}

public override async Task Using_explicit_interface_implementation_as_navigation_works()
{
await base.Using_explicit_interface_implementation_as_navigation_works();
Expand Down
Loading
Loading