Skip to content

Commit ebf9703

Browse files
committed
Find comparison operation dynamically instead of using hardcoded location
1 parent 3510a3c commit ebf9703

File tree

2 files changed

+88
-31
lines changed

2 files changed

+88
-31
lines changed

pynguin/instrumentation/instrumentation.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import json
1111
import logging
12+
from itertools import count
1213
from types import CodeType
1314
from typing import TYPE_CHECKING
1415

@@ -54,11 +55,11 @@ class InstrumentationAdapter:
5455
# TODO(fk) make this more fine grained? e.g. visit_line, visit_compare etc.
5556
# Or use sub visitors?
5657

57-
def visit_entry_node(self, block: BasicBlock, code_object_id: int) -> None:
58+
def visit_entry_node(self, basic_block: BasicBlock, code_object_id: int) -> None:
5859
"""Called when we visit the entry node of a code object.
5960
6061
Args:
61-
block: The basic block of the entry node.
62+
basic_block: The basic block of the entry node.
6263
code_object_id: The code object id of the containing code object.
6364
"""
6465

@@ -238,13 +239,9 @@ class BranchCoverageInstrumentation(InstrumentationAdapter):
238239
"""Instruments code objects to enable tracking branch distances and thus
239240
branch coverage."""
240241

241-
# Conditional jump operations are the last operation within a basic block
242+
# Jump operations are the last operation within a basic block
242243
_JUMP_OP_POS = -1
243244

244-
# If a conditional jump is based on a comparison, it has to be the second-to-last
245-
# instruction within the basic block.
246-
_COMPARE_OP_POS = -2
247-
248245
_logger = logging.getLogger(__name__)
249246

250247
def __init__(self, tracer: ExecutionTracer) -> None:
@@ -270,30 +267,59 @@ def visit_node(
270267

271268
assert len(basic_block) > 0, "Empty basic block in CFG."
272269
maybe_jump: Instr = basic_block[self._JUMP_OP_POS]
273-
maybe_compare: Instr | None = (
274-
basic_block[self._COMPARE_OP_POS] if len(basic_block) > 1 else None
270+
maybe_compare_idx: int | None = self._find_index_of_potential_compare_instr(
271+
basic_block
275272
)
276273
if isinstance(maybe_jump, Instr):
277274
predicate_id: int | None = None
278275
if maybe_jump.name == "FOR_ITER":
279276
predicate_id = self._instrument_for_loop(
280-
cfg, node, basic_block, code_object_id
277+
cfg=cfg,
278+
node=node,
279+
basic_block=basic_block,
280+
code_object_id=code_object_id,
281281
)
282282
elif maybe_jump.is_cond_jump():
283283
predicate_id = self._instrument_cond_jump(
284-
code_object_id,
285-
maybe_compare,
286-
maybe_jump,
287-
basic_block,
288-
node,
284+
code_object_id=code_object_id,
285+
maybe_compare_idx=maybe_compare_idx,
286+
jump=maybe_jump,
287+
block=basic_block,
288+
node=node,
289289
)
290290
if predicate_id is not None:
291291
node.predicate_id = predicate_id
292292

293+
@staticmethod
294+
def _find_index_of_potential_compare_instr(basic_block: BasicBlock) -> int | None:
295+
"""It may happen that another instrumentation added artificial instructions
296+
between the conditional jump and the preceding comparison. Find the index of the
297+
first non-artificial instruction that precedes the jump at the end of a basic
298+
block.
299+
300+
Args:
301+
basic_block: The block to search
302+
303+
Returns:
304+
The index of the first non-artificial instruction that precedes the jump.
305+
The index is negative, i.e., it indexes from the end.
306+
"""
307+
block_without_jump = basic_block[: BranchCoverageInstrumentation._JUMP_OP_POS]
308+
for idx, instr in zip(
309+
count(BranchCoverageInstrumentation._JUMP_OP_POS - 1, -1),
310+
reversed(block_without_jump),
311+
):
312+
if isinstance(instr, ArtificialInstr):
313+
# Skip over artificial instructions
314+
continue
315+
# Return first result
316+
return idx
317+
return None
318+
293319
def _instrument_cond_jump(
294320
self,
295321
code_object_id: int,
296-
maybe_compare: Instr | None,
322+
maybe_compare_idx: int | None,
297323
jump: Instr,
298324
block: BasicBlock,
299325
node: ProgramGraphNode,
@@ -307,27 +333,34 @@ def _instrument_cond_jump(
307333
308334
Args:
309335
code_object_id: The id of the containing Code Object.
310-
maybe_compare: The comparison operation, if any.
336+
maybe_compare_idx: The index of the comparison operation, if any.
311337
jump: The jump operation.
312338
block: The containing basic block.
313339
node: The associated node from the CFG.
314340
315341
Returns:
316342
The id that was assigned to the predicate.
317343
"""
344+
maybe_compare = block[maybe_compare_idx]
318345
if (
319346
maybe_compare is not None
320347
and isinstance(maybe_compare, Instr)
321348
and maybe_compare.name in ("COMPARE_OP", "IS_OP", "CONTAINS_OP")
322349
):
350+
assert maybe_compare_idx is not None
323351
return self._instrument_compare_based_conditional_jump(
324-
block, code_object_id, node
352+
block=block,
353+
code_object_id=code_object_id,
354+
compare_idx=maybe_compare_idx,
355+
node=node,
325356
)
326357
if jump.name == "JUMP_IF_NOT_EXC_MATCH":
327358
return self._instrument_exception_based_conditional_jump(
328-
block, code_object_id, node
359+
basic_block=block, code_object_id=code_object_id, node=node
329360
)
330-
return self._instrument_bool_based_conditional_jump(block, code_object_id, node)
361+
return self._instrument_bool_based_conditional_jump(
362+
block=block, code_object_id=code_object_id, node=node
363+
)
331364

332365
def _instrument_bool_based_conditional_jump(
333366
self, block: BasicBlock, code_object_id: int, node: ProgramGraphNode
@@ -369,7 +402,11 @@ def _instrument_bool_based_conditional_jump(
369402
return predicate_id
370403

371404
def _instrument_compare_based_conditional_jump(
372-
self, block: BasicBlock, code_object_id: int, node: ProgramGraphNode
405+
self,
406+
block: BasicBlock,
407+
compare_idx: int,
408+
code_object_id: int,
409+
node: ProgramGraphNode,
373410
) -> int:
374411
"""Instrument compare-based conditional jumps.
375412
@@ -378,6 +415,7 @@ def _instrument_compare_based_conditional_jump(
378415
379416
Args:
380417
block: The containing basic block.
418+
compare_idx: The index of the comparison index
381419
code_object_id: The id of the containing Code Object.
382420
node: The associated node from the CFG.
383421
@@ -391,7 +429,7 @@ def _instrument_compare_based_conditional_jump(
391429
predicate_id = self._tracer.register_predicate(
392430
PredicateMetaData(line_no=lineno, code_object_id=code_object_id, node=node)
393431
)
394-
operation = block[self._COMPARE_OP_POS]
432+
operation = block[compare_idx]
395433

396434
match operation.name:
397435
case "COMPARE_OP":
@@ -409,7 +447,7 @@ def _instrument_compare_based_conditional_jump(
409447
# Insert instructions right before the comparison.
410448
# We duplicate the values on top of the stack and report
411449
# them to the tracer.
412-
block[self._COMPARE_OP_POS : self._COMPARE_OP_POS] = [
450+
block[compare_idx:compare_idx] = [
413451
ArtificialInstr("DUP_TOP_TWO", lineno=lineno),
414452
ArtificialInstr("LOAD_CONST", self._tracer, lineno=lineno),
415453
ArtificialInstr(
@@ -427,29 +465,29 @@ def _instrument_compare_based_conditional_jump(
427465
return predicate_id
428466

429467
def _instrument_exception_based_conditional_jump(
430-
self, block: BasicBlock, code_object_id: int, node: ProgramGraphNode
468+
self, basic_block: BasicBlock, code_object_id: int, node: ProgramGraphNode
431469
) -> int:
432470
"""Instrument exception-based conditional jumps.
433471
434472
We add a call to the tracer which reports the values that will be used
435473
in the following exception matching case.
436474
437475
Args:
438-
block: The containing basic block.
476+
basic_block: The containing basic block.
439477
code_object_id: The id of the containing Code Object.
440478
node: The associated node from the CFG.
441479
442480
Returns:
443481
The id assigned to the predicate.
444482
"""
445-
lineno = block[self._JUMP_OP_POS].lineno
483+
lineno = basic_block[self._JUMP_OP_POS].lineno
446484
predicate_id = self._tracer.register_predicate(
447485
PredicateMetaData(line_no=lineno, code_object_id=code_object_id, node=node)
448486
)
449487
# Insert instructions right before the conditional jump.
450488
# We duplicate the values on top of the stack and report
451489
# them to the tracer.
452-
block[self._JUMP_OP_POS : self._JUMP_OP_POS] = [
490+
basic_block[self._JUMP_OP_POS : self._JUMP_OP_POS] = [
453491
ArtificialInstr("DUP_TOP_TWO", lineno=lineno),
454492
ArtificialInstr("LOAD_CONST", self._tracer, lineno=lineno),
455493
ArtificialInstr(
@@ -465,19 +503,20 @@ def _instrument_exception_based_conditional_jump(
465503
]
466504
return predicate_id
467505

468-
def visit_entry_node(self, block: BasicBlock, code_object_id: int) -> None:
506+
def visit_entry_node(self, basic_block: BasicBlock, code_object_id: int) -> None:
469507
"""Add instructions at the beginning of the given basic block which inform
470508
the tracer, that the code object with the given id has been entered.
471509
472510
Args:
473-
block: The entry basic block of a code object, i.e. the first basic block.
511+
basic_block: The entry basic block of a code object, i.e. the first basic
512+
block.
474513
code_object_id: The id that the tracer has assigned to the code object
475514
which contains the given basic block.
476515
"""
477516
# Use line number of first instruction
478-
lineno = block[0].lineno
517+
lineno = basic_block[0].lineno
479518
# Insert instructions at the beginning.
480-
block[0:0] = [
519+
basic_block[0:0] = [
481520
ArtificialInstr("LOAD_CONST", self._tracer, lineno=lineno),
482521
ArtificialInstr(
483522
"LOAD_METHOD",

tests/instrumentation/test_instrumentation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pynguin.analyses.seeding import DynamicConstantSeeding
1717
from pynguin.instrumentation.instrumentation import (
18+
ArtificialInstr,
1819
BranchCoverageInstrumentation,
1920
DynamicSeedingInstrumentation,
2021
InstrumentationTransformer,
@@ -218,6 +219,23 @@ def test_avoid_duplicate_instrumentation(simple_module):
218219
transformer.instrument_module(already_instrumented)
219220

220221

222+
@pytest.mark.parametrize(
223+
"block,expected",
224+
[
225+
([], None),
226+
([MagicMock()], None),
227+
([MagicMock(), MagicMock()], -2),
228+
([MagicMock(), ArtificialInstr("POP_TOP"), MagicMock()], -3),
229+
([ArtificialInstr("POP_TOP"), ArtificialInstr("POP_TOP"), MagicMock()], None),
230+
],
231+
)
232+
def test__find_index_of_potential_compare_instr(block, expected):
233+
assert (
234+
BranchCoverageInstrumentation._find_index_of_potential_compare_instr(block)
235+
== expected
236+
)
237+
238+
221239
@pytest.mark.parametrize(
222240
"function_name, branchless_function_count, branches_count",
223241
[

0 commit comments

Comments
 (0)