@@ -14,13 +14,25 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.ChangeTracking.Internal;
14
14
/// </summary>
15
15
public sealed class StringDictionaryComparer < TDictionary , TElement > : ValueComparer < object > , IInfrastructure < ValueComparer >
16
16
{
17
+ private static readonly bool UseOldBehavior35239 =
18
+ AppContext . TryGetSwitch ( "Microsoft.EntityFrameworkCore.Issue35239" , out var enabled35239 ) && enabled35239 ;
19
+
17
20
private static readonly MethodInfo CompareMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
21
+ nameof ( Compare ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( object ) , typeof ( Func < TElement , TElement , bool > ) ] ) ! ;
22
+
23
+ private static readonly MethodInfo LegacyCompareMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
18
24
nameof ( Compare ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( object ) , typeof ( ValueComparer ) ] ) ! ;
19
25
20
26
private static readonly MethodInfo GetHashCodeMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
27
+ nameof ( GetHashCode ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( IEnumerable ) , typeof ( Func < TElement , int > ) ] ) ! ;
28
+
29
+ private static readonly MethodInfo LegacyGetHashCodeMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
21
30
nameof ( GetHashCode ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( IEnumerable ) , typeof ( ValueComparer ) ] ) ! ;
22
31
23
32
private static readonly MethodInfo SnapshotMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
33
+ nameof ( Snapshot ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( Func < TElement , TElement > ) ] ) ! ;
34
+
35
+ private static readonly MethodInfo LegacySnapshotMethod = typeof ( StringDictionaryComparer < TDictionary , TElement > ) . GetMethod (
24
36
nameof ( Snapshot ) , BindingFlags . Static | BindingFlags . NonPublic , [ typeof ( object ) , typeof ( ValueComparer ) ] ) ! ;
25
37
26
38
/// <summary>
@@ -52,9 +64,23 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
52
64
var prm1 = Expression . Parameter ( typeof ( object ) , "a" ) ;
53
65
var prm2 = Expression . Parameter ( typeof ( object ) , "b" ) ;
54
66
67
+ if ( elementComparer is ValueComparer < TElement > && ! UseOldBehavior35239 )
68
+ {
69
+ // (a, b) => Compare(a, b, elementComparer.Equals)
70
+ return Expression . Lambda < Func < object ? , object ? , bool > > (
71
+ Expression . Call (
72
+ CompareMethod ,
73
+ prm1 ,
74
+ prm2 ,
75
+ elementComparer . EqualsExpression ) ,
76
+ prm1 ,
77
+ prm2 ) ;
78
+ }
79
+
80
+ // (a, b) => Compare(a, b, new Comparer(...))
55
81
return Expression . Lambda < Func < object ? , object ? , bool > > (
56
82
Expression . Call (
57
- CompareMethod ,
83
+ LegacyCompareMethod ,
58
84
prm1 ,
59
85
prm2 ,
60
86
#pragma warning disable EF9100
@@ -68,9 +94,23 @@ private static Expression<Func<object, int>> GetHashCodeLambda(ValueComparer ele
68
94
{
69
95
var prm = Expression . Parameter ( typeof ( object ) , "o" ) ;
70
96
97
+ if ( elementComparer is ValueComparer < TElement > && ! UseOldBehavior35239 )
98
+ {
99
+ // o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode)
100
+ return Expression . Lambda < Func < object , int > > (
101
+ Expression . Call (
102
+ GetHashCodeMethod ,
103
+ Expression . Convert (
104
+ prm ,
105
+ typeof ( IEnumerable ) ) ,
106
+ elementComparer . HashCodeExpression ) ,
107
+ prm ) ;
108
+ }
109
+
110
+ // o => GetHashCode((IEnumerable)o, new Comparer(...))
71
111
return Expression . Lambda < Func < object , int > > (
72
112
Expression . Call (
73
- GetHashCodeMethod ,
113
+ LegacyGetHashCodeMethod ,
74
114
Expression . Convert (
75
115
prm ,
76
116
typeof ( IEnumerable ) ) ,
@@ -84,16 +124,70 @@ private static Expression<Func<object, object>> SnapshotLambda(ValueComparer ele
84
124
{
85
125
var prm = Expression . Parameter ( typeof ( object ) , "source" ) ;
86
126
127
+ if ( elementComparer is ValueComparer < TElement > && ! UseOldBehavior35239 )
128
+ {
129
+ // source => Snapshot(source, elementComparer.Snapshot)
130
+ return Expression . Lambda < Func < object , object > > (
131
+ Expression . Call (
132
+ SnapshotMethod ,
133
+ prm ,
134
+ elementComparer . SnapshotExpression ) ,
135
+ prm ) ;
136
+ }
137
+
138
+ // source => Snapshot(source, new Comparer(..))
87
139
return Expression . Lambda < Func < object , object > > (
88
140
Expression . Call (
89
- SnapshotMethod ,
141
+ LegacySnapshotMethod ,
90
142
prm ,
91
143
#pragma warning disable EF9100
92
144
elementComparer . ConstructorExpression ) ,
93
145
#pragma warning restore EF9100
94
146
prm ) ;
95
147
}
96
148
149
+ private static bool Compare ( object ? a , object ? b , Func < TElement ? , TElement ? , bool > elementCompare )
150
+ {
151
+ if ( ReferenceEquals ( a , b ) )
152
+ {
153
+ return true ;
154
+ }
155
+
156
+ if ( a is null )
157
+ {
158
+ return b is null ;
159
+ }
160
+
161
+ if ( b is null )
162
+ {
163
+ return false ;
164
+ }
165
+
166
+ if ( a is IReadOnlyDictionary < string , TElement ? > aDictionary && b is IReadOnlyDictionary < string , TElement ? > bDictionary )
167
+ {
168
+ if ( aDictionary . Count != bDictionary . Count )
169
+ {
170
+ return false ;
171
+ }
172
+
173
+ foreach ( var pair in aDictionary )
174
+ {
175
+ if ( ! bDictionary . TryGetValue ( pair . Key , out var bValue )
176
+ || ! elementCompare ( pair . Value , bValue ) )
177
+ {
178
+ return false ;
179
+ }
180
+ }
181
+
182
+ return true ;
183
+ }
184
+
185
+ throw new InvalidOperationException (
186
+ CosmosStrings . BadDictionaryType (
187
+ ( a is IDictionary < string , TElement ? > ? b : a ) . GetType ( ) . ShortDisplayName ( ) ,
188
+ typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
189
+ }
190
+
97
191
private static bool Compare ( object ? a , object ? b , ValueComparer elementComparer )
98
192
{
99
193
if ( ReferenceEquals ( a , b ) )
@@ -136,6 +230,27 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer)
136
230
typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , elementComparer . Type ) . ShortDisplayName ( ) ) ) ;
137
231
}
138
232
233
+ private static int GetHashCode ( IEnumerable source , Func < TElement ? , int > elementGetHashCode )
234
+ {
235
+ if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
236
+ {
237
+ throw new InvalidOperationException (
238
+ CosmosStrings . BadDictionaryType (
239
+ source . GetType ( ) . ShortDisplayName ( ) ,
240
+ typeof ( IList < > ) . MakeGenericType ( typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
241
+ }
242
+
243
+ var hash = new HashCode ( ) ;
244
+
245
+ foreach ( var pair in sourceDictionary )
246
+ {
247
+ hash . Add ( pair . Key ) ;
248
+ hash . Add ( pair . Value == null ? 0 : elementGetHashCode ( pair . Value ) ) ;
249
+ }
250
+
251
+ return hash . ToHashCode ( ) ;
252
+ }
253
+
139
254
private static int GetHashCode ( IEnumerable source , ValueComparer elementComparer )
140
255
{
141
256
if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
@@ -157,6 +272,25 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer
157
272
return hash . ToHashCode ( ) ;
158
273
}
159
274
275
+ private static IReadOnlyDictionary < string , TElement ? > Snapshot ( object source , Func < TElement ? , TElement ? > elementSnapshot )
276
+ {
277
+ if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
278
+ {
279
+ throw new InvalidOperationException (
280
+ CosmosStrings . BadDictionaryType (
281
+ source . GetType ( ) . ShortDisplayName ( ) ,
282
+ typeof ( IDictionary < , > ) . MakeGenericType ( typeof ( string ) , typeof ( TElement ) ) . ShortDisplayName ( ) ) ) ;
283
+ }
284
+
285
+ var snapshot = new Dictionary < string , TElement ? > ( ) ;
286
+ foreach ( var pair in sourceDictionary )
287
+ {
288
+ snapshot [ pair . Key ] = pair . Value == null ? default : ( TElement ? ) elementSnapshot ( pair . Value ) ;
289
+ }
290
+
291
+ return snapshot ;
292
+ }
293
+
160
294
private static IReadOnlyDictionary < string , TElement ? > Snapshot ( object source , ValueComparer elementComparer )
161
295
{
162
296
if ( source is not IReadOnlyDictionary < string , TElement ? > sourceDictionary )
0 commit comments