Skip to content

Commit 5a719f2

Browse files
[pycodestyle] Auto-fix redundant boolean comparison (E712) (#17090)
This pull request fixes #17014 changes this ```python from __future__ import annotations flag1 = True flag2 = True if flag1 == True or flag2 == True: pass if flag1 == False and flag2 == False: pass flag3 = True if flag1 == flag3 and (flag2 == False or flag3 == True): # Should become: if flag1==flag3 and (not flag2 or flag3) pass if flag1 == True and (flag2 == False or not flag3 == True): # Should become: if flag1 and (not flag2 or not flag3) pass if flag1 != True and (flag2 != False or not flag3 == True): # Should become: if not flag1 and (flag2 or not flag3) pass flag = True while flag == True: # Should become: while flag flag = False flag = True x = 5 if flag == True and x > 0: # Should become: if flag and x > 0 print("ok") flag = True result = "yes" if flag == True else "no" # Should become: result = "yes" if flag else "no" x = flag == True < 5 x = (flag == True) == False < 5 ``` to this ```python from __future__ import annotations flag1 = True flag2 = True if flag1 or flag2: pass if not flag1 and not flag2: pass flag3 = True if flag1 == flag3 and (not flag2 or flag3): # Should become: if flag1 == flag3 and (not flag2 or flag3) pass if flag1 and (not flag2 or not flag3): # Should become: if flag1 and (not flag2 or not flag3) pass if not flag1 and (flag2 or not flag3): # Should become: if not flag1 and (flag2 or not flag3) pass flag = True while flag: # Should become: while flag flag = False flag = True x = 5 if flag and x > 0: # Should become: if flag and x > 0 print("ok") flag = True result = "yes" if flag else "no" # Should become: result = "yes" if flag else "no" x = flag is True < 5 x = (flag) is False < 5 ``` --------- Co-authored-by: Brent Westbrook <[email protected]>
1 parent e7f38fe commit 5a719f2

File tree

3 files changed

+100
-27
lines changed

3 files changed

+100
-27
lines changed

crates/ruff_linter/src/rules/pycodestyle/rules/literal_comparisons.rs

Lines changed: 85 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
use ruff_python_ast::parenthesize::parenthesized_range;
12
use rustc_hash::FxHashMap;
23

34
use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix};
45
use ruff_macros::{derive_message_formats, ViolationMetadata};
5-
use ruff_python_ast::helpers;
6-
use ruff_python_ast::helpers::generate_comparison;
6+
use ruff_python_ast::helpers::{self, generate_comparison};
77
use ruff_python_ast::{self as ast, CmpOp, Expr};
88
use ruff_text_size::Ranged;
99

@@ -170,6 +170,42 @@ impl AlwaysFixableViolation for TrueFalseComparison {
170170
}
171171
}
172172

173+
fn is_redundant_boolean_comparison(op: CmpOp, comparator: &Expr) -> Option<bool> {
174+
let value = comparator.as_boolean_literal_expr()?.value;
175+
match op {
176+
CmpOp::Is | CmpOp::Eq => Some(value),
177+
CmpOp::IsNot | CmpOp::NotEq => Some(!value),
178+
_ => None,
179+
}
180+
}
181+
182+
fn generate_redundant_comparison(
183+
compare: &ast::ExprCompare,
184+
comment_ranges: &ruff_python_trivia::CommentRanges,
185+
source: &str,
186+
comparator: &Expr,
187+
kind: bool,
188+
needs_wrap: bool,
189+
) -> String {
190+
let comparator_range =
191+
parenthesized_range(comparator.into(), compare.into(), comment_ranges, source)
192+
.unwrap_or(comparator.range());
193+
194+
let comparator_str = &source[comparator_range];
195+
196+
let result = if kind {
197+
comparator_str.to_string()
198+
} else {
199+
format!("not {comparator_str}")
200+
};
201+
202+
if needs_wrap {
203+
format!("({result})")
204+
} else {
205+
result
206+
}
207+
}
208+
173209
/// E711, E712
174210
pub(crate) fn literal_comparisons(checker: &Checker, compare: &ast::ExprCompare) {
175211
// Mapping from (bad operator index) to (replacement operator). As we iterate
@@ -323,22 +359,60 @@ pub(crate) fn literal_comparisons(checker: &Checker, compare: &ast::ExprCompare)
323359
// TODO(charlie): Respect `noqa` directives. If one of the operators has a
324360
// `noqa`, but another doesn't, both will be removed here.
325361
if !bad_ops.is_empty() {
326-
// Replace the entire comparison expression.
327362
let ops = compare
328363
.ops
329364
.iter()
330365
.enumerate()
331366
.map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op))
332367
.copied()
333368
.collect::<Vec<_>>();
334-
let content = generate_comparison(
335-
&compare.left,
336-
&ops,
337-
&compare.comparators,
338-
compare.into(),
339-
checker.comment_ranges(),
340-
checker.source(),
341-
);
369+
370+
let comment_ranges = checker.comment_ranges();
371+
let source = checker.source();
372+
373+
let content = match (&*compare.ops, &*compare.comparators) {
374+
([op], [comparator]) => {
375+
if let Some(kind) = is_redundant_boolean_comparison(*op, &compare.left) {
376+
let needs_wrap = compare.left.range().start() != compare.range().start();
377+
generate_redundant_comparison(
378+
compare,
379+
comment_ranges,
380+
source,
381+
comparator,
382+
kind,
383+
needs_wrap,
384+
)
385+
} else if let Some(kind) = is_redundant_boolean_comparison(*op, comparator) {
386+
let needs_wrap = comparator.range().end() != compare.range().end();
387+
generate_redundant_comparison(
388+
compare,
389+
comment_ranges,
390+
source,
391+
&compare.left,
392+
kind,
393+
needs_wrap,
394+
)
395+
} else {
396+
generate_comparison(
397+
&compare.left,
398+
&ops,
399+
&compare.comparators,
400+
compare.into(),
401+
comment_ranges,
402+
source,
403+
)
404+
}
405+
}
406+
_ => generate_comparison(
407+
&compare.left,
408+
&ops,
409+
&compare.comparators,
410+
compare.into(),
411+
comment_ranges,
412+
source,
413+
),
414+
};
415+
342416
for diagnostic in &mut diagnostics {
343417
diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement(
344418
content.to_string(),

crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__E712_E712.py.snap

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ E712.py:2:4: E712 [*] Avoid equality comparisons to `True`; use `if res:` for tr
1414
Unsafe fix
1515
1 1 | #: E712
1616
2 |-if res == True:
17-
2 |+if res is True:
17+
2 |+if res:
1818
3 3 | pass
1919
4 4 | #: E712
2020
5 5 | if res != False:
@@ -35,7 +35,7 @@ E712.py:5:4: E712 [*] Avoid inequality comparisons to `False`; use `if res:` for
3535
3 3 | pass
3636
4 4 | #: E712
3737
5 |-if res != False:
38-
5 |+if res is not False:
38+
5 |+if res:
3939
6 6 | pass
4040
7 7 | #: E712
4141
8 8 | if True != res:
@@ -56,7 +56,7 @@ E712.py:8:4: E712 [*] Avoid inequality comparisons to `True`; use `if not res:`
5656
6 6 | pass
5757
7 7 | #: E712
5858
8 |-if True != res:
59-
8 |+if True is not res:
59+
8 |+if not res:
6060
9 9 | pass
6161
10 10 | #: E712
6262
11 11 | if False == res:
@@ -77,7 +77,7 @@ E712.py:11:4: E712 [*] Avoid equality comparisons to `False`; use `if not res:`
7777
9 9 | pass
7878
10 10 | #: E712
7979
11 |-if False == res:
80-
11 |+if False is res:
80+
11 |+if not res:
8181
12 12 | pass
8282
13 13 | #: E712
8383
14 14 | if res[1] == True:
@@ -98,7 +98,7 @@ E712.py:14:4: E712 [*] Avoid equality comparisons to `True`; use `if res[1]:` fo
9898
12 12 | pass
9999
13 13 | #: E712
100100
14 |-if res[1] == True:
101-
14 |+if res[1] is True:
101+
14 |+if res[1]:
102102
15 15 | pass
103103
16 16 | #: E712
104104
17 17 | if res[1] != False:
@@ -119,7 +119,7 @@ E712.py:17:4: E712 [*] Avoid inequality comparisons to `False`; use `if res[1]:`
119119
15 15 | pass
120120
16 16 | #: E712
121121
17 |-if res[1] != False:
122-
17 |+if res[1] is not False:
122+
17 |+if res[1]:
123123
18 18 | pass
124124
19 19 | #: E712
125125
20 20 | var = 1 if cond == True else -1 if cond == False else cond
@@ -140,7 +140,7 @@ E712.py:20:12: E712 [*] Avoid equality comparisons to `True`; use `if cond:` for
140140
18 18 | pass
141141
19 19 | #: E712
142142
20 |-var = 1 if cond == True else -1 if cond == False else cond
143-
20 |+var = 1 if cond is True else -1 if cond == False else cond
143+
20 |+var = 1 if cond else -1 if cond == False else cond
144144
21 21 | #: E712
145145
22 22 | if (True) == TrueElement or x == TrueElement:
146146
23 23 | pass
@@ -161,7 +161,7 @@ E712.py:20:36: E712 [*] Avoid equality comparisons to `False`; use `if not cond:
161161
18 18 | pass
162162
19 19 | #: E712
163163
20 |-var = 1 if cond == True else -1 if cond == False else cond
164-
20 |+var = 1 if cond == True else -1 if cond is False else cond
164+
20 |+var = 1 if cond == True else -1 if not cond else cond
165165
21 21 | #: E712
166166
22 22 | if (True) == TrueElement or x == TrueElement:
167167
23 23 | pass
@@ -181,7 +181,7 @@ E712.py:22:4: E712 [*] Avoid equality comparisons to `True`; use `if TrueElement
181181
20 20 | var = 1 if cond == True else -1 if cond == False else cond
182182
21 21 | #: E712
183183
22 |-if (True) == TrueElement or x == TrueElement:
184-
22 |+if (True) is TrueElement or x == TrueElement:
184+
22 |+if (TrueElement) or x == TrueElement:
185185
23 23 | pass
186186
24 24 |
187187
25 25 | if res == True != False:
@@ -241,7 +241,7 @@ E712.py:28:3: E712 [*] Avoid equality comparisons to `True`; use `if TrueElement
241241
26 26 | pass
242242
27 27 |
243243
28 |-if(True) == TrueElement or x == TrueElement:
244-
28 |+if(True) is TrueElement or x == TrueElement:
244+
28 |+if(TrueElement) or x == TrueElement:
245245
29 29 | pass
246246
30 30 |
247247
31 31 | if (yield i) == True:
@@ -261,7 +261,7 @@ E712.py:31:4: E712 [*] Avoid equality comparisons to `True`; use `if yield i:` f
261261
29 29 | pass
262262
30 30 |
263263
31 |-if (yield i) == True:
264-
31 |+if (yield i) is True:
264+
31 |+if (yield i):
265265
32 32 | print("even")
266266
33 33 |
267267
34 34 | #: Okay

crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__constant_literals.snap

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
---
22
source: crates/ruff_linter/src/rules/pycodestyle/mod.rs
3-
snapshot_kind: text
43
---
54
constant_literals.py:4:4: F632 [*] Use `==` to compare constant literals
65
|
@@ -123,7 +122,7 @@ constant_literals.py:14:4: E712 [*] Avoid equality comparisons to `False`; use `
123122
12 12 | if False is "abc": # F632 (fix, but leaves behind unfixable E712)
124123
13 13 | pass
125124
14 |-if False == None: # E711, E712 (fix)
126-
14 |+if False is None: # E711, E712 (fix)
125+
14 |+if not None: # E711, E712 (fix)
127126
15 15 | pass
128127
16 16 | if None == False: # E711, E712 (fix)
129128
17 17 | pass
@@ -144,7 +143,7 @@ constant_literals.py:14:13: E711 [*] Comparison to `None` should be `cond is Non
144143
12 12 | if False is "abc": # F632 (fix, but leaves behind unfixable E712)
145144
13 13 | pass
146145
14 |-if False == None: # E711, E712 (fix)
147-
14 |+if False is None: # E711, E712 (fix)
146+
14 |+if not None: # E711, E712 (fix)
148147
15 15 | pass
149148
16 16 | if None == False: # E711, E712 (fix)
150149
17 17 | pass
@@ -164,7 +163,7 @@ constant_literals.py:16:4: E711 [*] Comparison to `None` should be `cond is None
164163
14 14 | if False == None: # E711, E712 (fix)
165164
15 15 | pass
166165
16 |-if None == False: # E711, E712 (fix)
167-
16 |+if None is False: # E711, E712 (fix)
166+
16 |+if not None: # E711, E712 (fix)
168167
17 17 | pass
169168
18 18 |
170169
19 19 | named_var = []
@@ -184,7 +183,7 @@ constant_literals.py:16:4: E712 [*] Avoid equality comparisons to `False`; use `
184183
14 14 | if False == None: # E711, E712 (fix)
185184
15 15 | pass
186185
16 |-if None == False: # E711, E712 (fix)
187-
16 |+if None is False: # E711, E712 (fix)
186+
16 |+if not None: # E711, E712 (fix)
188187
17 17 | pass
189188
18 18 |
190189
19 19 | named_var = []

0 commit comments

Comments
 (0)