Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 258 additions & 12 deletions src/GraphQL.EntityFramework/IncludeAppender.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using System.Reflection;

class IncludeAppender(
IReadOnlyDictionary<Type, IReadOnlyDictionary<string, Navigation>> navigations,
IReadOnlyDictionary<Type, List<string>> keyNames,
Expand Down Expand Up @@ -64,29 +66,85 @@ static IQueryable<TItem> AddIncludesFromProjection<TItem>(
FieldProjectionInfo projection)
where TItem : class
{
if (projection.Navigations is not { Count: > 0 })
var visitedTypes = new HashSet<Type> { typeof(TItem) };

if (projection.Navigations is { Count: > 0 })
{
return query;
foreach (var (navName, navProjection) in projection.Navigations)
{
if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes))
{
continue;
}

visitedTypes.Add(navProjection.EntityType);
query = query.Include(navName);
query = AddNestedIncludes(query, navName, navProjection.Projection, visitedTypes);
visitedTypes.Remove(navProjection.EntityType);
}
}

var visitedTypes = new HashSet<Type> { typeof(TItem) };
// Add derived-type navigation includes for TPH inline fragments
// e.g. query.Include(e => ((GroupAccessRule)e).Group)
if (projection.DerivedNavigations is { Count: > 0 })
{
query = AddDerivedTypeIncludes(query, projection.DerivedNavigations, visitedTypes);
}

foreach (var (navName, navProjection) in projection.Navigations)
return query;
}

static IQueryable<TItem> AddDerivedTypeIncludes<TItem>(
IQueryable<TItem> query,
Dictionary<Type, Dictionary<string, NavigationProjectionInfo>> derivedNavigations,
HashSet<Type> visitedTypes)
where TItem : class
{
var itemType = typeof(TItem);
var parameter = Expression.Parameter(itemType, "e");

foreach (var (derivedType, navDict) in derivedNavigations)
{
if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes))
// Cast: (DerivedType)e
var cast = Expression.Convert(parameter, derivedType);

foreach (var (navName, navProjection) in navDict)
{
continue;
}
if (IsVisitedOrBaseType(navProjection.EntityType, visitedTypes))
{
continue;
}

visitedTypes.Add(navProjection.EntityType);
query = query.Include(navName);
query = AddNestedIncludes(query, navName, navProjection.Projection, visitedTypes);
visitedTypes.Remove(navProjection.EntityType);
// Property access: ((DerivedType)e).Navigation
var property = derivedType.GetProperty(navName);
if (property == null)
{
continue;
}

var propertyAccess = Expression.Property(cast, property);

// Build lambda: e => ((DerivedType)e).Navigation
var lambda = Expression.Lambda(propertyAccess, parameter);

// Call EntityFrameworkQueryableExtensions.Include(query, lambda)
var includeMethod = GetIncludeMethod(itemType, property.PropertyType);
query = (IQueryable<TItem>)includeMethod.Invoke(null, [query, lambda])!;
}
}

return query;
}

static MethodInfo GetIncludeMethod(Type entityType, Type propertyType) =>
typeof(EntityFrameworkQueryableExtensions)
.GetMethods(BindingFlags.Static | BindingFlags.Public)
.First(_ => _.Name == "Include" &&
_.GetGenericArguments().Length == 2 &&
_.GetParameters().Length == 2 &&
_.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>))
.MakeGenericMethod(entityType, propertyType);

static IQueryable<TItem> AddNestedIncludes<TItem>(
IQueryable<TItem> query,
string includePath,
Expand Down Expand Up @@ -180,7 +238,195 @@ FieldProjectionInfo GetProjectionInfo(
}
}

return new(scalarFields, keys, foreignKeyNames, navProjections);
// Scan for derived-type navigations from inline fragments (TPH support)
var derivedNavigations = GetDerivedNavigationsFromFragments(context);

return new(scalarFields, keys, foreignKeyNames, navProjections, derivedNavigations);
}

Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? GetDerivedNavigationsFromFragments(
IResolveFieldContext context)
{
var selectionSet = GetLeafSelectionSet(context);
if (selectionSet?.Selections is null)
{
return null;
}

Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? result = null;

foreach (var selection in selectionSet.Selections)
{
GraphQLTypeCondition? typeCondition;
GraphQLSelectionSet? fragmentSelectionSet;

switch (selection)
{
case GraphQLInlineFragment inlineFragment:
typeCondition = inlineFragment.TypeCondition;
fragmentSelectionSet = inlineFragment.SelectionSet;
break;
case GraphQLFragmentSpread fragmentSpread:
{
var name = fragmentSpread.FragmentName.Name;
var fragmentDefinition = context.Document.Definitions
.OfType<GraphQLFragmentDefinition>()
.SingleOrDefault(_ => _.FragmentName.Name == name);
if (fragmentDefinition is null)
{
continue;
}

typeCondition = fragmentDefinition.TypeCondition;
fragmentSelectionSet = fragmentDefinition.SelectionSet;
break;
}
default:
continue;
}

if (typeCondition is null || fragmentSelectionSet?.Selections is null)
{
continue;
}

var typeName = typeCondition.Type.Name.StringValue;

// Find the CLR type for this GraphQL type name using the schema
if (!TryFindDerivedClrType(typeName, context.Schema, out var derivedType))
{
continue;
}

// Get navigation properties for this derived type
if (!navigations.TryGetValue(derivedType, out var derivedNavProps))
{
continue;
}

// Process fields in this fragment against the derived type's navigation properties
foreach (var field in fragmentSelectionSet.Selections.OfType<GraphQLField>())
{
var fieldName = field.Name.StringValue;
if (!derivedNavProps.TryGetValue(fieldName, out var navigation))
{
continue;
}

result ??= [];
if (!result.TryGetValue(derivedType, out var derivedNavs))
{
derivedNavs = [];
result[derivedType] = derivedNavs;
}

if (derivedNavs.ContainsKey(navigation.Name))
{
continue;
}

var navType = navigation.Type;
navigations.TryGetValue(navType, out var nestedNavProps);
keyNames.TryGetValue(navType, out var nestedKeys);
foreignKeys.TryGetValue(navType, out var nestedFks);

derivedNavs[navigation.Name] = new(
navType,
navigation.IsCollection,
GetNestedProjection(field.SelectionSet, nestedNavProps, nestedKeys, nestedFks, context));
}
}

return result;
}

/// <summary>
/// Navigate through connection wrapper fields (edges/items/node) to find the leaf selection set
/// that contains the actual entity fields and inline fragments.
/// </summary>
static GraphQLSelectionSet? GetLeafSelectionSet(IResolveFieldContext context)
{
var selectionSet = context.FieldAst.SelectionSet;
if (selectionSet?.Selections is null)
{
return null;
}

// Drill through connection wrapper fields
while (true)
{
var found = false;
foreach (var selection in selectionSet.Selections)
{
if (selection is GraphQLField field && IsConnectionNodeName(field.Name.StringValue))
{
if (field.SelectionSet is not null)
{
selectionSet = field.SelectionSet;
found = true;
break;
}
}
}

if (!found)
{
break;
}
}

return selectionSet;
}

bool TryFindDerivedClrType(string graphQlTypeName, ISchema schema, [NotNullWhen(true)] out Type? clrType)
{
clrType = null;

// Use the schema's type lookup to resolve GraphQL type name → CLR type
var graphType = schema.AllTypes.FirstOrDefault(_ => _.Name == graphQlTypeName);
if (graphType is not null)
{
// Walk the type hierarchy to find the CLR type from the generic arguments
var graphClrType = GetSourceType(graphType.GetType());
if (graphClrType is not null && navigations.ContainsKey(graphClrType))
{
clrType = graphClrType;
return true;
}
}

// Fallback: match CLR type name directly
foreach (var type in navigations.Keys)
{
if (string.Equals(type.Name, graphQlTypeName, StringComparison.OrdinalIgnoreCase))
{
clrType = type;
return true;
}
}

return false;
}

static Type? GetSourceType(Type graphType)
{
var type = graphType;
while (type is not null)
{
if (type.IsGenericType)
{
var genericDef = type.GetGenericTypeDefinition();
if (genericDef == typeof(ObjectGraphType<>) ||
genericDef == typeof(InterfaceGraphType<>))
{
return type.GenericTypeArguments[0];
}
}

type = type.BaseType;
}

return null;
}

void ProcessConnectionNodeFields(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ record FieldProjectionInfo(
HashSet<string> ScalarFields,
List<string>? KeyNames,
IReadOnlySet<string>? ForeignKeyNames,
Dictionary<string, NavigationProjectionInfo>? Navigations);
Dictionary<string, NavigationProjectionInfo>? Navigations,
Dictionary<Type, Dictionary<string, NavigationProjectionInfo>>? DerivedNavigations = null);
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
public class CategoryEntity
{
public Guid Id { get; set; } = Guid.NewGuid();
public string? Name { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
public class CategoryGraphType :
EfObjectGraphType<IntegrationDbContext, CategoryEntity>
{
public CategoryGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
base(graphQlService) =>
AutoMap();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
public class RegionEntity
{
public Guid Id { get; set; } = Guid.NewGuid();
public string? Name { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
public class RegionGraphType :
EfObjectGraphType<IntegrationDbContext, RegionEntity>
{
public RegionGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
base(graphQlService) =>
AutoMap();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
public abstract class TphDerivedNavBaseEntity
{
public Guid Id { get; set; } = Guid.NewGuid();
public string? Property { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
public class TphDerivedNavBaseGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
EfInterfaceGraphType<IntegrationDbContext, TphDerivedNavBaseEntity>(graphQlService);
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
public class TphDerivedNavCategoryEntity : TphDerivedNavBaseEntity
{
public Guid? CategoryId { get; set; }
public CategoryEntity? Category { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
public class TphDerivedNavCategoryGraphType :
EfObjectGraphType<IntegrationDbContext, TphDerivedNavCategoryEntity>
{
public TphDerivedNavCategoryGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
base(graphQlService)
{
AutoMap();
Interface<TphDerivedNavBaseGraphType>();
IsTypeOf = _ => _ is TphDerivedNavCategoryEntity;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
public class TphDerivedNavRegionEntity : TphDerivedNavBaseEntity
{
public Guid? RegionId { get; set; }
public RegionEntity? Region { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
public class TphDerivedNavRegionGraphType :
EfObjectGraphType<IntegrationDbContext, TphDerivedNavRegionEntity>
{
public TphDerivedNavRegionGraphType(IEfGraphQLService<IntegrationDbContext> graphQlService) :
base(graphQlService)
{
AutoMap();
Interface<TphDerivedNavBaseGraphType>();
IsTypeOf = _ => _ is TphDerivedNavRegionEntity;
}
}
Loading
Loading