@@ -79,6 +79,16 @@ protected abstract bool TryAnalyzePatternCondition(
79
79
ISyntaxFacts syntaxFacts , TExpressionSyntax conditionNode ,
80
80
[ NotNullWhen ( true ) ] out TExpressionSyntax ? conditionPartToCheck , out bool isEquals ) ;
81
81
82
+ public ( INamedTypeSymbol ? expressionType , IMethodSymbol ? referenceEqualsMethod ) GetAnalysisSymbols ( Compilation compilation )
83
+ {
84
+ var expressionType = compilation . ExpressionOfTType ( ) ;
85
+ var objectType = compilation . GetSpecialType ( SpecialType . System_Object ) ;
86
+ var referenceEqualsMethod = objectType ? . GetMembers ( nameof ( ReferenceEquals ) )
87
+ . OfType < IMethodSymbol > ( )
88
+ . FirstOrDefault ( m => m is { DeclaredAccessibility : Accessibility . Public , Parameters . Length : 2 } ) ;
89
+ return ( expressionType , referenceEqualsMethod ) ;
90
+ }
91
+
82
92
protected override void InitializeWorker ( AnalysisContext context )
83
93
{
84
94
context . RegisterCompilationStartAction ( context =>
@@ -88,22 +98,33 @@ protected override void InitializeWorker(AnalysisContext context)
88
98
89
99
var expressionType = context . Compilation . ExpressionOfTType ( ) ;
90
100
91
- var objectType = context . Compilation . GetSpecialType ( SpecialType . System_Object ) ;
92
- var referenceEqualsMethod = objectType ? . GetMembers ( nameof ( ReferenceEquals ) )
93
- . OfType < IMethodSymbol > ( )
94
- . FirstOrDefault ( m => m is { DeclaredAccessibility : Accessibility . Public , Parameters . Length : 2 } ) ;
101
+ var ( objectType , referenceEqualsMethod ) = GetAnalysisSymbols ( context . Compilation ) ;
95
102
96
103
var syntaxKinds = this . SyntaxFacts . SyntaxKinds ;
97
104
context . RegisterSyntaxNodeAction (
98
- context => AnalyzeTernaryConditionalExpression ( context , expressionType , referenceEqualsMethod ) ,
105
+ context => AnalyzeTernaryConditionalExpressionAndReportDiagnostic ( context , expressionType , referenceEqualsMethod ) ,
99
106
syntaxKinds . Convert < TSyntaxKind > ( syntaxKinds . TernaryConditionalExpression ) ) ;
100
107
context . RegisterSyntaxNodeAction (
101
- context => AnalyzeIfStatement ( context , referenceEqualsMethod ) ,
108
+ context => AnalyzeIfStatementAndReportDiagnostic ( context , referenceEqualsMethod ) ,
102
109
IfStatementSyntaxKind ) ;
103
110
} ) ;
104
111
}
105
112
106
- private void AnalyzeTernaryConditionalExpression (
113
+ public ( TExpressionSyntax conditionalPart , SyntaxNode whenPart ) ? GetPartsOfConditionalExpression (
114
+ SemanticModel semanticModel ,
115
+ TConditionalExpressionSyntax conditionalExpression ,
116
+ CancellationToken cancellationToken )
117
+ {
118
+ var ( objectType , referenceEqualsMethod ) = GetAnalysisSymbols ( semanticModel . Compilation ) ;
119
+ var analysisResult = AnalyzeTernaryConditionalExpression (
120
+ semanticModel , objectType , referenceEqualsMethod , conditionalExpression , cancellationToken ) ;
121
+ if ( analysisResult is null )
122
+ return null ;
123
+
124
+ return ( analysisResult . Value . ConditionPartToCheck , analysisResult . Value . WhenPartToCheck ) ;
125
+ }
126
+
127
+ private void AnalyzeTernaryConditionalExpressionAndReportDiagnostic (
107
128
SyntaxNodeAnalysisContext context ,
108
129
INamedTypeSymbol ? expressionType ,
109
130
IMethodSymbol ? referenceEqualsMethod )
@@ -115,6 +136,27 @@ private void AnalyzeTernaryConditionalExpression(
115
136
if ( ! option . Value || ShouldSkipAnalysis ( context , option . Notification ) )
116
137
return ;
117
138
139
+ var analysisResult = AnalyzeTernaryConditionalExpression (
140
+ context . SemanticModel , expressionType , referenceEqualsMethod , conditionalExpression , cancellationToken ) ;
141
+ if ( analysisResult is null )
142
+ return ;
143
+
144
+ context . ReportDiagnostic ( DiagnosticHelper . Create (
145
+ Descriptor ,
146
+ conditionalExpression . GetLocation ( ) ,
147
+ option . Notification ,
148
+ context . Options ,
149
+ additionalLocations : [ conditionalExpression . GetLocation ( ) ] ,
150
+ analysisResult . Value . Properties ) ) ;
151
+ }
152
+
153
+ public ConditionalExpressionAnalysisResult ? AnalyzeTernaryConditionalExpression (
154
+ SemanticModel semanticModel ,
155
+ INamedTypeSymbol ? expressionType ,
156
+ IMethodSymbol ? referenceEqualsMethod ,
157
+ TConditionalExpressionSyntax conditionalExpression ,
158
+ CancellationToken cancellationToken )
159
+ {
118
160
var syntaxFacts = this . SyntaxFacts ;
119
161
syntaxFacts . GetPartsOfConditionalExpression (
120
162
conditionalExpression , out var condition , out var whenTrue , out var whenFalse ) ;
@@ -125,32 +167,31 @@ private void AnalyzeTernaryConditionalExpression(
125
167
var whenFalseNode = ( TExpressionSyntax ) syntaxFacts . WalkDownParentheses ( whenFalse ) ;
126
168
127
169
if ( ! TryAnalyzeCondition (
128
- context , syntaxFacts , referenceEqualsMethod , conditionNode ,
129
- out var conditionPartToCheck , out var isEquals ) )
170
+ semanticModel , referenceEqualsMethod , conditionNode ,
171
+ out var conditionPartToCheck , out var isEquals , cancellationToken ) )
130
172
{
131
- return ;
173
+ return null ;
132
174
}
133
175
134
176
// Needs to be of the form:
135
177
// x == null ? null : ... or
136
178
// x != null ? ... : null;
137
179
if ( isEquals && ! syntaxFacts . IsNullLiteralExpression ( whenTrueNode ) )
138
- return ;
180
+ return null ;
139
181
140
182
if ( ! isEquals && ! syntaxFacts . IsNullLiteralExpression ( whenFalseNode ) )
141
- return ;
183
+ return null ;
142
184
143
185
var whenPartToCheck = isEquals ? whenFalseNode : whenTrueNode ;
144
186
145
- var semanticModel = context . SemanticModel ;
146
187
var whenPartMatch = GetWhenPartMatch ( syntaxFacts , semanticModel , conditionPartToCheck , whenPartToCheck , cancellationToken ) ;
147
188
if ( whenPartMatch == null )
148
- return ;
189
+ return null ;
149
190
150
191
// can't use ?. on a pointer
151
192
var whenPartType = semanticModel . GetTypeInfo ( whenPartMatch , cancellationToken ) . Type ;
152
193
if ( whenPartType is IPointerTypeSymbol )
153
- return ;
194
+ return null ;
154
195
155
196
var type = semanticModel . GetTypeInfo ( conditionalExpression , cancellationToken ) . Type ;
156
197
if ( type ? . IsValueType == true )
@@ -160,7 +201,7 @@ private void AnalyzeTernaryConditionalExpression(
160
201
// User has something like: If(str is nothing, nothing, str.Length)
161
202
// In this case, converting to str?.Length changes the type of this from
162
203
// int to int?
163
- return ;
204
+ return null ;
164
205
}
165
206
// But for a nullable type, such as If(c is nothing, nothing, c.nullable)
166
207
// converting to c?.nullable doesn't affect the type
@@ -172,7 +213,7 @@ private void AnalyzeTernaryConditionalExpression(
172
213
// `x == null ? x : x.M` cannot be converted to `x?.M` when M is a method symbol.
173
214
var memberSymbol = semanticModel . GetSymbolInfo ( whenPartToCheck , cancellationToken ) . GetAnySymbol ( ) ;
174
215
if ( memberSymbol is IMethodSymbol )
175
- return ;
216
+ return null ;
176
217
177
218
// `x == null ? x : x.Value` will be converted to just 'x'.
178
219
if ( UseNullPropagationHelpers . IsSystemNullableValueProperty ( memberSymbol ) )
@@ -181,12 +222,7 @@ private void AnalyzeTernaryConditionalExpression(
181
222
182
223
// ?. is not available in expression-trees. Disallow the fix in that case.
183
224
if ( this . SemanticFacts . IsInExpressionTree ( semanticModel , conditionNode , expressionType , cancellationToken ) )
184
- return ;
185
-
186
- var locations = ImmutableArray . Create (
187
- conditionalExpression . GetLocation ( ) ,
188
- conditionPartToCheck . GetLocation ( ) ,
189
- whenPartToCheck . GetLocation ( ) ) ;
225
+ return null ;
190
226
191
227
var whenPartIsNullable = whenPartType ? . OriginalDefinition . SpecialType == SpecialType . System_Nullable_T ;
192
228
var properties = whenPartIsNullable
@@ -196,23 +232,21 @@ private void AnalyzeTernaryConditionalExpression(
196
232
if ( isTrivialNullableValueAccess )
197
233
properties = properties . Add ( UseNullPropagationHelpers . IsTrivialNullableValueAccess , UseNullPropagationHelpers . IsTrivialNullableValueAccess ) ;
198
234
199
- context . ReportDiagnostic ( DiagnosticHelper . Create (
200
- Descriptor ,
201
- conditionalExpression . GetLocation ( ) ,
202
- option . Notification ,
203
- context . Options ,
204
- locations ,
205
- properties ) ) ;
235
+ return new (
236
+ conditionPartToCheck ,
237
+ whenPartToCheck ,
238
+ properties ) ;
206
239
}
207
240
208
241
private bool TryAnalyzeCondition (
209
- SyntaxNodeAnalysisContext context ,
210
- ISyntaxFacts syntaxFacts ,
242
+ SemanticModel semanticModel ,
211
243
IMethodSymbol ? referenceEqualsMethod ,
212
244
TExpressionSyntax condition ,
213
245
[ NotNullWhen ( true ) ] out TExpressionSyntax ? conditionPartToCheck ,
214
- out bool isEquals )
246
+ out bool isEquals ,
247
+ CancellationToken cancellationToken )
215
248
{
249
+ var syntaxFacts = this . SyntaxFacts ;
216
250
condition = ( TExpressionSyntax ) syntaxFacts . WalkDownParentheses ( condition ) ;
217
251
var conditionIsNegated = false ;
218
252
if ( syntaxFacts . IsLogicalNotExpression ( condition ) )
@@ -228,8 +262,7 @@ private bool TryAnalyzeCondition(
228
262
syntaxFacts , binaryExpression , out conditionPartToCheck , out isEquals ) ,
229
263
230
264
TInvocationExpressionSyntax invocation => TryAnalyzeInvocationCondition (
231
- context , syntaxFacts , referenceEqualsMethod , invocation ,
232
- out conditionPartToCheck , out isEquals ) ,
265
+ semanticModel , syntaxFacts , referenceEqualsMethod , invocation , out conditionPartToCheck , out isEquals , cancellationToken ) ,
233
266
234
267
_ => TryAnalyzePatternCondition ( syntaxFacts , condition , out conditionPartToCheck , out isEquals ) ,
235
268
} ;
@@ -261,12 +294,13 @@ private static bool TryAnalyzeBinaryExpressionCondition(
261
294
}
262
295
263
296
private static bool TryAnalyzeInvocationCondition (
264
- SyntaxNodeAnalysisContext context ,
297
+ SemanticModel semanticModel ,
265
298
ISyntaxFacts syntaxFacts ,
266
299
IMethodSymbol ? referenceEqualsMethod ,
267
300
TInvocationExpressionSyntax invocation ,
268
301
[ NotNullWhen ( true ) ] out TExpressionSyntax ? conditionPartToCheck ,
269
- out bool isEquals )
302
+ out bool isEquals ,
303
+ CancellationToken cancellationToken )
270
304
{
271
305
conditionPartToCheck = null ;
272
306
isEquals = true ;
@@ -311,8 +345,6 @@ private static bool TryAnalyzeInvocationCondition(
311
345
return false ;
312
346
}
313
347
314
- var semanticModel = context . SemanticModel ;
315
- var cancellationToken = context . CancellationToken ;
316
348
var symbol = semanticModel . GetSymbolInfo ( invocation , cancellationToken ) . Symbol ;
317
349
return referenceEqualsMethod . Equals ( symbol ) ;
318
350
}
@@ -337,7 +369,8 @@ private static bool TryAnalyzeInvocationCondition(
337
369
return conditionRightIsNull ? conditionLeft : conditionRight ;
338
370
}
339
371
340
- internal static TExpressionSyntax ? GetWhenPartMatch (
372
+ #pragma warning disable CA1822 // Mark members as static. Helper method that doesn't want to call through generic form.
373
+ public TExpressionSyntax ? GetWhenPartMatch (
341
374
ISyntaxFacts syntaxFacts ,
342
375
SemanticModel semanticModel ,
343
376
TExpressionSyntax expressionToMatch ,
@@ -361,6 +394,7 @@ private static bool TryAnalyzeInvocationCondition(
361
394
current = unwrapped ;
362
395
}
363
396
}
397
+ #pragma warning restore CA1822 // Mark members as static
364
398
365
399
private static TExpressionSyntax RemoveObjectCastIfAny (
366
400
ISyntaxFacts syntaxFacts , SemanticModel semanticModel , TExpressionSyntax node , CancellationToken cancellationToken )
0 commit comments