diff --git a/crates/ruff_linter/src/rules/pycodestyle/rules/literal_comparisons.rs b/crates/ruff_linter/src/rules/pycodestyle/rules/literal_comparisons.rs index 265d0c7046a6c..cffcde998fddf 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/rules/literal_comparisons.rs +++ b/crates/ruff_linter/src/rules/pycodestyle/rules/literal_comparisons.rs @@ -1,9 +1,9 @@ +use ruff_python_ast::parenthesize::parenthesized_range; use rustc_hash::FxHashMap; use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, ViolationMetadata}; -use ruff_python_ast::helpers; -use ruff_python_ast::helpers::generate_comparison; +use ruff_python_ast::helpers::{self, generate_comparison}; use ruff_python_ast::{self as ast, CmpOp, Expr}; use ruff_text_size::Ranged; @@ -170,6 +170,42 @@ impl AlwaysFixableViolation for TrueFalseComparison { } } +fn is_redundant_boolean_comparison(op: CmpOp, comparator: &Expr) -> Option { + let value = comparator.as_boolean_literal_expr()?.value; + match op { + CmpOp::Is | CmpOp::Eq => Some(value), + CmpOp::IsNot | CmpOp::NotEq => Some(!value), + _ => None, + } +} + +fn generate_redundant_comparison( + compare: &ast::ExprCompare, + comment_ranges: &ruff_python_trivia::CommentRanges, + source: &str, + comparator: &Expr, + kind: bool, + needs_wrap: bool, +) -> String { + let comparator_range = + parenthesized_range(comparator.into(), compare.into(), comment_ranges, source) + .unwrap_or(comparator.range()); + + let comparator_str = &source[comparator_range]; + + let result = if kind { + comparator_str.to_string() + } else { + format!("not {comparator_str}") + }; + + if needs_wrap { + format!("({result})") + } else { + result + } +} + /// E711, E712 pub(crate) fn literal_comparisons(checker: &Checker, compare: &ast::ExprCompare) { // Mapping from (bad operator index) to (replacement operator). As we iterate @@ -323,7 +359,6 @@ pub(crate) fn literal_comparisons(checker: &Checker, compare: &ast::ExprCompare) // TODO(charlie): Respect `noqa` directives. If one of the operators has a // `noqa`, but another doesn't, both will be removed here. if !bad_ops.is_empty() { - // Replace the entire comparison expression. let ops = compare .ops .iter() @@ -331,14 +366,53 @@ pub(crate) fn literal_comparisons(checker: &Checker, compare: &ast::ExprCompare) .map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op)) .copied() .collect::>(); - let content = generate_comparison( - &compare.left, - &ops, - &compare.comparators, - compare.into(), - checker.comment_ranges(), - checker.source(), - ); + + let comment_ranges = checker.comment_ranges(); + let source = checker.source(); + + let content = match (&*compare.ops, &*compare.comparators) { + ([op], [comparator]) => { + if let Some(kind) = is_redundant_boolean_comparison(*op, &compare.left) { + let needs_wrap = compare.left.range().start() != compare.range().start(); + generate_redundant_comparison( + compare, + comment_ranges, + source, + comparator, + kind, + needs_wrap, + ) + } else if let Some(kind) = is_redundant_boolean_comparison(*op, comparator) { + let needs_wrap = comparator.range().end() != compare.range().end(); + generate_redundant_comparison( + compare, + comment_ranges, + source, + &compare.left, + kind, + needs_wrap, + ) + } else { + generate_comparison( + &compare.left, + &ops, + &compare.comparators, + compare.into(), + comment_ranges, + source, + ) + } + } + _ => generate_comparison( + &compare.left, + &ops, + &compare.comparators, + compare.into(), + comment_ranges, + source, + ), + }; + for diagnostic in &mut diagnostics { diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( content.to_string(), diff --git a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__E712_E712.py.snap b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__E712_E712.py.snap index 02f2c26951687..76d290a4b5cea 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__E712_E712.py.snap +++ b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__E712_E712.py.snap @@ -14,7 +14,7 @@ E712.py:2:4: E712 [*] Avoid equality comparisons to `True`; use `if res:` for tr ℹ Unsafe fix 1 1 | #: E712 2 |-if res == True: - 2 |+if res is True: + 2 |+if res: 3 3 | pass 4 4 | #: E712 5 5 | if res != False: @@ -35,7 +35,7 @@ E712.py:5:4: E712 [*] Avoid inequality comparisons to `False`; use `if res:` for 3 3 | pass 4 4 | #: E712 5 |-if res != False: - 5 |+if res is not False: + 5 |+if res: 6 6 | pass 7 7 | #: E712 8 8 | if True != res: @@ -56,7 +56,7 @@ E712.py:8:4: E712 [*] Avoid inequality comparisons to `True`; use `if not res:` 6 6 | pass 7 7 | #: E712 8 |-if True != res: - 8 |+if True is not res: + 8 |+if not res: 9 9 | pass 10 10 | #: E712 11 11 | if False == res: @@ -77,7 +77,7 @@ E712.py:11:4: E712 [*] Avoid equality comparisons to `False`; use `if not res:` 9 9 | pass 10 10 | #: E712 11 |-if False == res: - 11 |+if False is res: + 11 |+if not res: 12 12 | pass 13 13 | #: E712 14 14 | if res[1] == True: @@ -98,7 +98,7 @@ E712.py:14:4: E712 [*] Avoid equality comparisons to `True`; use `if res[1]:` fo 12 12 | pass 13 13 | #: E712 14 |-if res[1] == True: - 14 |+if res[1] is True: + 14 |+if res[1]: 15 15 | pass 16 16 | #: E712 17 17 | if res[1] != False: @@ -119,7 +119,7 @@ E712.py:17:4: E712 [*] Avoid inequality comparisons to `False`; use `if res[1]:` 15 15 | pass 16 16 | #: E712 17 |-if res[1] != False: - 17 |+if res[1] is not False: + 17 |+if res[1]: 18 18 | pass 19 19 | #: E712 20 20 | var = 1 if cond == True else -1 if cond == False else cond @@ -140,7 +140,7 @@ E712.py:20:12: E712 [*] Avoid equality comparisons to `True`; use `if cond:` for 18 18 | pass 19 19 | #: E712 20 |-var = 1 if cond == True else -1 if cond == False else cond - 20 |+var = 1 if cond is True else -1 if cond == False else cond + 20 |+var = 1 if cond else -1 if cond == False else cond 21 21 | #: E712 22 22 | if (True) == TrueElement or x == TrueElement: 23 23 | pass @@ -161,7 +161,7 @@ E712.py:20:36: E712 [*] Avoid equality comparisons to `False`; use `if not cond: 18 18 | pass 19 19 | #: E712 20 |-var = 1 if cond == True else -1 if cond == False else cond - 20 |+var = 1 if cond == True else -1 if cond is False else cond + 20 |+var = 1 if cond == True else -1 if not cond else cond 21 21 | #: E712 22 22 | if (True) == TrueElement or x == TrueElement: 23 23 | pass @@ -181,7 +181,7 @@ E712.py:22:4: E712 [*] Avoid equality comparisons to `True`; use `if TrueElement 20 20 | var = 1 if cond == True else -1 if cond == False else cond 21 21 | #: E712 22 |-if (True) == TrueElement or x == TrueElement: - 22 |+if (True) is TrueElement or x == TrueElement: + 22 |+if (TrueElement) or x == TrueElement: 23 23 | pass 24 24 | 25 25 | if res == True != False: @@ -241,7 +241,7 @@ E712.py:28:3: E712 [*] Avoid equality comparisons to `True`; use `if TrueElement 26 26 | pass 27 27 | 28 |-if(True) == TrueElement or x == TrueElement: - 28 |+if(True) is TrueElement or x == TrueElement: + 28 |+if(TrueElement) or x == TrueElement: 29 29 | pass 30 30 | 31 31 | if (yield i) == True: @@ -261,7 +261,7 @@ E712.py:31:4: E712 [*] Avoid equality comparisons to `True`; use `if yield i:` f 29 29 | pass 30 30 | 31 |-if (yield i) == True: - 31 |+if (yield i) is True: + 31 |+if (yield i): 32 32 | print("even") 33 33 | 34 34 | #: Okay diff --git a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__constant_literals.snap b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__constant_literals.snap index 04831e37b3bd6..7adb6732e7bf3 100644 --- a/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__constant_literals.snap +++ b/crates/ruff_linter/src/rules/pycodestyle/snapshots/ruff_linter__rules__pycodestyle__tests__constant_literals.snap @@ -1,6 +1,5 @@ --- source: crates/ruff_linter/src/rules/pycodestyle/mod.rs -snapshot_kind: text --- constant_literals.py:4:4: F632 [*] Use `==` to compare constant literals | @@ -123,7 +122,7 @@ constant_literals.py:14:4: E712 [*] Avoid equality comparisons to `False`; use ` 12 12 | if False is "abc": # F632 (fix, but leaves behind unfixable E712) 13 13 | pass 14 |-if False == None: # E711, E712 (fix) - 14 |+if False is None: # E711, E712 (fix) + 14 |+if not None: # E711, E712 (fix) 15 15 | pass 16 16 | if None == False: # E711, E712 (fix) 17 17 | pass @@ -144,7 +143,7 @@ constant_literals.py:14:13: E711 [*] Comparison to `None` should be `cond is Non 12 12 | if False is "abc": # F632 (fix, but leaves behind unfixable E712) 13 13 | pass 14 |-if False == None: # E711, E712 (fix) - 14 |+if False is None: # E711, E712 (fix) + 14 |+if not None: # E711, E712 (fix) 15 15 | pass 16 16 | if None == False: # E711, E712 (fix) 17 17 | pass @@ -164,7 +163,7 @@ constant_literals.py:16:4: E711 [*] Comparison to `None` should be `cond is None 14 14 | if False == None: # E711, E712 (fix) 15 15 | pass 16 |-if None == False: # E711, E712 (fix) - 16 |+if None is False: # E711, E712 (fix) + 16 |+if not None: # E711, E712 (fix) 17 17 | pass 18 18 | 19 19 | named_var = [] @@ -184,7 +183,7 @@ constant_literals.py:16:4: E712 [*] Avoid equality comparisons to `False`; use ` 14 14 | if False == None: # E711, E712 (fix) 15 15 | pass 16 |-if None == False: # E711, E712 (fix) - 16 |+if None is False: # E711, E712 (fix) + 16 |+if not None: # E711, E712 (fix) 17 17 | pass 18 18 | 19 19 | named_var = []