Skip to content

Handle comparers for nullable value types in primitive collections #35235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/EFCore/Storage/TypeMappingSourceBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ protected virtual bool TryFindJsonCollectionMapping(
elementReader);

elementComparer = (ValueComparer?)Activator.CreateInstance(
elementType.IsNullableValueType()
elementType.IsNullableValueType() || elementMapping.Comparer.Type.IsNullableValueType()
? typeof(ListOfNullableValueTypesComparer<,>).MakeGenericType(typeToInstantiate, elementType.UnwrapNullableType())
: elementType.IsValueType
? typeof(ListOfValueTypesComparer<,>).MakeGenericType(typeToInstantiate, elementType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2100,6 +2100,82 @@ FROM root c

#endregion Cosmos-specific tests

public override async Task Parameter_collection_of_structs_Contains_struct(bool async)
{
// Always throws for sync before getting to the exception to test.
if (async)
{
// Requires collections of converted elements
await Assert.ThrowsAsync<InvalidOperationException>(() => base.Parameter_collection_of_structs_Contains_struct(async));

AssertSql();
}
}

public override async Task Parameter_collection_of_structs_Contains_nullable_struct(bool async)
{
// Always throws for sync before getting to the exception to test.
if (async)
{
// Requires collections of converted elements
await Assert.ThrowsAsync<InvalidOperationException>(() => base.Parameter_collection_of_structs_Contains_nullable_struct(async));

AssertSql();
}
}

public override async Task Parameter_collection_of_structs_Contains_nullable_struct_with_nullable_comparer(bool async)
{
// Always throws for sync before getting to the exception to test.
if (async)
{
// Requires collections of converted elements
await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Parameter_collection_of_structs_Contains_nullable_struct_with_nullable_comparer(async));

AssertSql();
}
}

public override async Task Parameter_collection_of_nullable_structs_Contains_struct(bool async)
{
// Always throws for sync before getting to the exception to test.
if (async)
{
// Requires collections of converted elements
await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Parameter_collection_of_nullable_structs_Contains_struct(async));

AssertSql();
}
}

public override async Task Parameter_collection_of_nullable_structs_Contains_nullable_struct(bool async)
{
// Always throws for sync before getting to the exception to test.
if (async)
{
// Requires collections of converted elements
await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Parameter_collection_of_nullable_structs_Contains_nullable_struct(async));

AssertSql();
}
}

public override async Task Parameter_collection_of_nullable_structs_Contains_nullable_struct_with_nullable_comparer(bool async)
{
// Always throws for sync before getting to the exception to test.
if (async)
{
// Requires collections of converted elements
await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Parameter_collection_of_nullable_structs_Contains_nullable_struct_with_nullable_comparer(async));

AssertSql();
}
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,16 @@ await AssertQuery(
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_ints_Contains_nullable_int(bool async)
{
var ints = new int?[] { 10, 999 };
var ints = new[] { 10, 999 };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => ints.Contains(c.NullableInt)));
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => ints.Contains(c.NullableInt!.Value)),
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => c.NullableInt != null && ints.Contains(c.NullableInt!.Value)));
await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.NullableInt)));
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !ints.Contains(c.NullableInt!.Value)),
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => c.NullableInt == null || !ints.Contains(c.NullableInt!.Value)));
}

[ConditionalTheory]
Expand Down Expand Up @@ -405,6 +407,114 @@ await AssertQuery(
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !nullableInts.Contains(c.NullableInt)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_structs_Contains_struct(bool async)
{
var values = new List<WrappedId> { new(22), new(33) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => values.Contains(c.WrappedId)));

values = new List<WrappedId> { new(11), new(44) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !values.Contains(c.WrappedId)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_structs_Contains_nullable_struct(bool async)
{
var values = new List<WrappedId> { new(22), new(33) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => values.Contains(c.NullableWrappedId!.Value)),
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => c.NullableWrappedId != null && values.Contains(c.NullableWrappedId.Value)));

values = new List<WrappedId> { new(11), new(44) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !values.Contains(c.NullableWrappedId!.Value)),
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => c.NullableWrappedId == null || !values.Contains(c.NullableWrappedId!.Value)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))] // Issue #35117
public virtual async Task Parameter_collection_of_structs_Contains_nullable_struct_with_nullable_comparer(bool async)
{
var values = new List<WrappedId> { new(22), new(33) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => values.Contains(c.NullableWrappedIdWithNullableComparer!.Value)),
ss => ss.Set<PrimitiveCollectionsEntity>().Where(
c => c.NullableWrappedIdWithNullableComparer != null && values.Contains(c.NullableWrappedIdWithNullableComparer.Value)));

values = new List<WrappedId> { new(11), new(44) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !values.Contains(c.NullableWrappedId!.Value)),
ss => ss.Set<PrimitiveCollectionsEntity>().Where(
c => c.NullableWrappedIdWithNullableComparer == null || !values.Contains(c.NullableWrappedIdWithNullableComparer!.Value)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_nullable_structs_Contains_struct(bool async)
{
var values = new List<WrappedId?> { null, new(22) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => values.Contains(c.WrappedId)));

values = new List<WrappedId?> { new(11), new(44) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !values.Contains(c.WrappedId)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_nullable_structs_Contains_nullable_struct(bool async)
{
var values = new List<WrappedId?> { null, new(22) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => values.Contains(c.NullableWrappedId)));

values = new List<WrappedId?> { new(11), new(44) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !values.Contains(c.NullableWrappedId)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_nullable_structs_Contains_nullable_struct_with_nullable_comparer(bool async)
{
var values = new List<WrappedId?> { null, new(22) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => values.Contains(c.NullableWrappedIdWithNullableComparer)));

values = new List<WrappedId?> { new(11), new(44) };

await AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(c => !values.Contains(c.NullableWrappedIdWithNullableComparer)));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Parameter_collection_of_strings_Contains_string(bool async)
Expand Down Expand Up @@ -1370,7 +1480,21 @@ public Func<DbContext> GetContextCreator()
=> () => CreateContext();

protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext context)
=> modelBuilder.Entity<PrimitiveCollectionsEntity>().Property(p => p.Id).ValueGeneratedNever();
=> modelBuilder.Entity<PrimitiveCollectionsEntity>(
b =>
{
b.Property(e => e.Id).ValueGeneratedNever();
b.Property(e => e.WrappedId).HasConversion<WrappedIdConverter>();
b.Property(e => e.NullableWrappedId).HasConversion<WrappedIdConverter>();
b.Property(e => e.NullableWrappedIdWithNullableComparer).HasConversion<NullableWrappedIdConverter>();
});

protected class WrappedIdConverter() : ValueConverter<WrappedId, int>(v => v.Value, v => new(v));

// Note that value comparers over nullable value types are not a good idea, unless the comparer is handling nulls itself.
protected class NullableWrappedIdConverter() : ValueConverter<WrappedId?, int?>(
id => id == null ? null : id.Value.Value,
value => value == null ? null : new WrappedId(value.Value));

protected override Task SeedAsync(PrimitiveCollectionsContext context)
{
Expand Down Expand Up @@ -1420,9 +1544,12 @@ public class PrimitiveCollectionsEntity
public int Int { get; set; }
public DateTime DateTime { get; set; }
public bool Bool { get; set; }
public WrappedId WrappedId { get; set; }
public MyEnum Enum { get; set; }
public int? NullableInt { get; set; }
public string? NullableString { get; set; }
public WrappedId? NullableWrappedId { get; set; }
public WrappedId? NullableWrappedIdWithNullableComparer { get; set; }

public required string[] Strings { get; set; }
public required int[] Ints { get; set; }
Expand All @@ -1433,6 +1560,8 @@ public class PrimitiveCollectionsEntity
public required string?[] NullableStrings { get; set; }
}

public readonly record struct WrappedId(int Value);

public enum MyEnum { Value1, Value2, Value3, Value4 }

public class PrimitiveCollectionsData : ISetSource
Expand Down Expand Up @@ -1464,8 +1593,11 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
DateTime = new DateTime(2020, 1, 10, 12, 30, 0, DateTimeKind.Utc),
Bool = true,
Enum = MyEnum.Value1,
WrappedId = new(22),
NullableInt = 10,
NullableString = "10",
NullableWrappedId = new(22),
NullableWrappedIdWithNullableComparer = new(22),
Ints = [1, 10],
Strings = ["1", "10"],
DateTimes =
Expand All @@ -1475,7 +1607,7 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
Bools = [true, false],
Enums = [MyEnum.Value1, MyEnum.Value2],
NullableInts = [1, 10],
NullableStrings = ["1", "10"]
NullableStrings = ["1", "10"],
},
new()
{
Expand All @@ -1485,8 +1617,11 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
DateTime = new DateTime(2020, 1, 11, 12, 30, 0, DateTimeKind.Utc),
Bool = false,
Enum = MyEnum.Value2,
WrappedId = new(22),
NullableInt = null,
NullableString = null,
NullableWrappedId = null,
NullableWrappedIdWithNullableComparer = null,
Ints = [1, 11, 111],
Strings = ["1", "11", "111"],
DateTimes =
Expand All @@ -1508,8 +1643,11 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
DateTime = new DateTime(2022, 1, 10, 12, 30, 0, DateTimeKind.Utc),
Bool = true,
Enum = MyEnum.Value1,
WrappedId = new(22),
NullableInt = 20,
NullableString = "20",
NullableWrappedId = new(22),
NullableWrappedIdWithNullableComparer = new(22),
Ints = [1, 1, 10, 10, 10, 1, 10],
Strings = ["1", "10", "10", "1", "1"],
DateTimes =
Expand All @@ -1533,8 +1671,11 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
DateTime = new DateTime(2024, 1, 11, 12, 30, 0, DateTimeKind.Utc),
Bool = false,
Enum = MyEnum.Value2,
WrappedId = new(22),
NullableInt = null,
NullableString = null,
NullableWrappedId = null,
NullableWrappedIdWithNullableComparer = null,
Ints = [1, 1, 111, 11, 1, 111],
Strings = ["1", "11", "111", "11"],
DateTimes =
Expand All @@ -1551,7 +1692,7 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
Bools = [false],
Enums = [MyEnum.Value2, MyEnum.Value3],
NullableInts = [null, null],
NullableStrings = [null, null]
NullableStrings = [null, null],
},
new()
{
Expand All @@ -1561,8 +1702,11 @@ private static IReadOnlyList<PrimitiveCollectionsEntity> CreatePrimitiveArrayEnt
DateTime = new DateTime(2000, 1, 1, 0, 0, 0, DateTimeKind.Utc),
Bool = false,
Enum = MyEnum.Value1,
WrappedId = new(22),
NullableWrappedIdWithNullableComparer = null,
NullableInt = null,
NullableString = null,
NullableWrappedId = null,
Ints = [],
Strings = [],
DateTimes = [],
Expand Down
Loading