@@ -16,6 +16,9 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
16
16
/// </summary>
17
17
public class CosmosQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
18
18
{
19
+ private static readonly bool UseOldBehavior35094 =
20
+ AppContext . TryGetSwitch ( "Microsoft.EntityFrameworkCore.Issue35094" , out var enabled ) && enabled ;
21
+
19
22
private readonly CosmosQueryCompilationContext _queryCompilationContext ;
20
23
private readonly ISqlExpressionFactory _sqlExpressionFactory ;
21
24
private readonly ITypeMappingSource _typeMappingSource ;
@@ -445,23 +448,29 @@ private ShapedQueryExpression CreateShapedQueryExpression(SelectExpression selec
445
448
/// </summary>
446
449
protected override ShapedQueryExpression ? TranslateAverage ( ShapedQueryExpression source , LambdaExpression ? selector , Type resultType )
447
450
{
448
- var selectExpression = ( SelectExpression ) source . QueryExpression ;
449
- if ( selectExpression . IsDistinct
450
- || selectExpression . Limit != null
451
- || selectExpression . Offset != null )
451
+ if ( UseOldBehavior35094 )
452
452
{
453
- return null ;
454
- }
453
+ var selectExpression = ( SelectExpression ) source . QueryExpression ;
454
+ if ( selectExpression . IsDistinct
455
+ || selectExpression . Limit != null
456
+ || selectExpression . Offset != null )
457
+ {
458
+ return null ;
459
+ }
455
460
456
- if ( selector != null )
457
- {
458
- source = TranslateSelect ( source , selector ) ;
459
- }
461
+ if ( selector != null )
462
+ {
463
+ source = TranslateSelect ( source , selector ) ;
464
+ }
460
465
461
- var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
462
- projection = _sqlExpressionFactory . Function ( "AVG" , new [ ] { projection } , projection . Type , projection . TypeMapping ) ;
466
+ var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
467
+ projection = _sqlExpressionFactory . Function ( "AVG" , new [ ] { projection } , projection . Type , projection . TypeMapping ) ;
463
468
464
- return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
469
+ return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
470
+
471
+ }
472
+
473
+ return TranslateAggregate ( source , selector , resultType , "AVG" ) ;
465
474
}
466
475
467
476
/// <summary>
@@ -843,24 +852,29 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
843
852
/// </summary>
844
853
protected override ShapedQueryExpression ? TranslateMax ( ShapedQueryExpression source , LambdaExpression ? selector , Type resultType )
845
854
{
846
- var selectExpression = ( SelectExpression ) source . QueryExpression ;
847
- if ( selectExpression . IsDistinct
848
- || selectExpression . Limit != null
849
- || selectExpression . Offset != null )
855
+ if ( UseOldBehavior35094 )
850
856
{
851
- return null ;
852
- }
857
+ var selectExpression = ( SelectExpression ) source . QueryExpression ;
858
+ if ( selectExpression . IsDistinct
859
+ || selectExpression . Limit != null
860
+ || selectExpression . Offset != null )
861
+ {
862
+ return null ;
863
+ }
853
864
854
- if ( selector != null )
855
- {
856
- source = TranslateSelect ( source , selector ) ;
857
- }
865
+ if ( selector != null )
866
+ {
867
+ source = TranslateSelect ( source , selector ) ;
868
+ }
858
869
859
- var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
870
+ var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
860
871
861
- projection = _sqlExpressionFactory . Function ( "MAX" , new [ ] { projection } , resultType , projection . TypeMapping ) ;
872
+ projection = _sqlExpressionFactory . Function ( "MAX" , new [ ] { projection } , resultType , projection . TypeMapping ) ;
862
873
863
- return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
874
+ return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
875
+ }
876
+
877
+ return TranslateAggregate ( source , selector , resultType , "MAX" ) ;
864
878
}
865
879
866
880
/// <summary>
@@ -871,24 +885,29 @@ protected override ShapedQueryExpression TranslateCast(ShapedQueryExpression sou
871
885
/// </summary>
872
886
protected override ShapedQueryExpression ? TranslateMin ( ShapedQueryExpression source , LambdaExpression ? selector , Type resultType )
873
887
{
874
- var selectExpression = ( SelectExpression ) source . QueryExpression ;
875
- if ( selectExpression . IsDistinct
876
- || selectExpression . Limit != null
877
- || selectExpression . Offset != null )
888
+ if ( UseOldBehavior35094 )
878
889
{
879
- return null ;
880
- }
890
+ var selectExpression = ( SelectExpression ) source . QueryExpression ;
891
+ if ( selectExpression . IsDistinct
892
+ || selectExpression . Limit != null
893
+ || selectExpression . Offset != null )
894
+ {
895
+ return null ;
896
+ }
881
897
882
- if ( selector != null )
883
- {
884
- source = TranslateSelect ( source , selector ) ;
885
- }
898
+ if ( selector != null )
899
+ {
900
+ source = TranslateSelect ( source , selector ) ;
901
+ }
886
902
887
- var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
903
+ var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
888
904
889
- projection = _sqlExpressionFactory . Function ( "MIN" , new [ ] { projection } , resultType , projection . TypeMapping ) ;
905
+ projection = _sqlExpressionFactory . Function ( "MIN" , new [ ] { projection } , resultType , projection . TypeMapping ) ;
890
906
891
- return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
907
+ return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
908
+ }
909
+
910
+ return TranslateAggregate ( source , selector , resultType , "MIN" ) ;
892
911
}
893
912
894
913
/// <summary>
@@ -1520,6 +1539,35 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
1520
1539
1521
1540
#endregion Queryable collection support
1522
1541
1542
+ private ShapedQueryExpression ? TranslateAggregate ( ShapedQueryExpression source , LambdaExpression ? selector , Type resultType , string functionName )
1543
+ {
1544
+ var selectExpression = ( SelectExpression ) source . QueryExpression ;
1545
+ if ( selectExpression . IsDistinct
1546
+ || selectExpression . Limit != null
1547
+ || selectExpression . Offset != null )
1548
+ {
1549
+ return null ;
1550
+ }
1551
+
1552
+ if ( selector != null )
1553
+ {
1554
+ source = TranslateSelect ( source , selector ) ;
1555
+ }
1556
+
1557
+ if ( ! _subquery && resultType . IsNullableType ( ) )
1558
+ {
1559
+ // For nullable types, we want to return null from Max, Min, and Average, rather than throwing. See Issue #35094.
1560
+ // Note that relational databases typically return null, which propagates. Cosmos will instead return no elements,
1561
+ // and hence for Cosmos only we need to change no elements into null.
1562
+ source = source . UpdateResultCardinality ( ResultCardinality . SingleOrDefault ) ;
1563
+ }
1564
+
1565
+ var projection = ( SqlExpression ) selectExpression . GetMappedProjection ( new ProjectionMember ( ) ) ;
1566
+ projection = _sqlExpressionFactory . Function ( functionName , [ projection ] , resultType , _typeMappingSource . FindMapping ( resultType ) ) ;
1567
+
1568
+ return AggregateResultShaper ( source , projection , throwOnNullResult : true , resultType ) ;
1569
+ }
1570
+
1523
1571
private bool TryApplyPredicate ( ShapedQueryExpression source , LambdaExpression predicate )
1524
1572
{
1525
1573
var select = ( SelectExpression ) source . QueryExpression ;
0 commit comments