From 9b8521df6086c90ca43ca2e4ed5bb6a7dd527d0f Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui Date: Thu, 3 Jul 2025 11:44:22 -0700 Subject: [PATCH] [Mosaic GPU] Replace repeated calls to `reduce_hint` with `reduce_hints`. This allows more easily dropping `Hint`s if they become `Unsatisfiable`, which we didn't do properly before (we didn't update `reduce_hint` when `reduce_expression` became able to return `Unsatisfiable`). PiperOrigin-RevId: 778933116 --- jax/experimental/mosaic/gpu/equations.py | 34 +-- .../mosaic/gpu/layout_inference2.py | 247 +++++++++--------- tests/mosaic/gpu_equations_test.py | 32 +-- tests/mosaic/gpu_layout_inference_test.py | 24 +- 4 files changed, 163 insertions(+), 174 deletions(-) diff --git a/jax/experimental/mosaic/gpu/equations.py b/jax/experimental/mosaic/gpu/equations.py index 41be6fc44e45..9dc0f834a46e 100644 --- a/jax/experimental/mosaic/gpu/equations.py +++ b/jax/experimental/mosaic/gpu/equations.py @@ -130,25 +130,25 @@ def most_replicated_expression( return None -def simplify_expression( +def reduce_expression( expr: Expression, assignments: dict[Variable, ConstantExpression] ) -> Expression: - """Simplifies an expression as much as is possible given a set of known variable assignments.""" - simplify = simplify_expression + """Reduces an expression as much as is possible given a set of known variable assignments.""" + reduce = reduce_expression match expr: case ConstantExpression(): return expr case Variable(): return assignments.get(expr, expr) case MostReplicatedExpression(expressions=expressions): - reduced_expressions = tuple(simplify(e, assignments) for e in expressions) + reduced_expressions = tuple(reduce(e, assignments) for e in expressions) if most_replicated := most_replicated_expression(reduced_expressions): - return simplify(most_replicated, assignments) + return reduce(most_replicated, assignments) return MostReplicatedExpression(expressions=reduced_expressions) case LeastReplicatedExpression(expressions=expressions): - reduced_expressions = tuple(simplify(e, assignments) for e in expressions) + reduced_expressions = tuple(reduce(e, assignments) for e in expressions) if least_replicated := least_replicated_expression(reduced_expressions): - return simplify(least_replicated, assignments) + return reduce(least_replicated, assignments) return LeastReplicatedExpression(expressions=reduced_expressions) case _: assert_never(expr) @@ -180,8 +180,8 @@ def reduce_equation( a variable. - Unknown(): if the equation contains remaining unknown variables. """ - lhs = simplify_expression(eq.lhs, assignments) - rhs = simplify_expression(eq.rhs, assignments) + lhs = reduce_expression(eq.lhs, assignments) + rhs = reduce_expression(eq.rhs, assignments) match (lhs, rhs): case (Variable(), ConstantExpression()): return SatisfiedBy((lhs, rhs)) @@ -269,16 +269,16 @@ class Tautological: Solution = Unsatisfiable | SatisfiedBy | Unknown | Tautological -def _simplify_system_once( +def _reduce_system_once( equation_system: EquationSystem, ) -> EquationSystem | Unsatisfiable | None: - """Performs one simplification step over each equation in an equation system. + """Performs one reduction step over each equation in an equation system. Returns: - Unsatisfiable(): if the equation system is unsatisfiable. - - A new equation system if any equation was simplified. + - A new equation system if any equation was reduced. - None: if the equation system is not known unsatisfiable, but hasn't been - simplified. + reduced. """ changed = False assignments: dict[Variable, ConstantExpression] = dict() @@ -309,15 +309,15 @@ def _simplify_system_once( return None -def simplify(equation_system: EquationSystem) -> EquationSystem | Unsatisfiable: - """Simplifies an equation system until it can no longer be simplified. +def reduce(equation_system: EquationSystem) -> EquationSystem | Unsatisfiable: + """Reduces an equation system until it can no longer be reduced. Returns: - Unsatisfiable(): if the equation system is unsatisfiable. - - The maximally simplified equation system otherwise. + - The maximally reduced equation system otherwise. """ while True: - match (new_system := _simplify_system_once(equation_system)): + match (new_system := _reduce_system_once(equation_system)): case None: break case Unsatisfiable(): diff --git a/jax/experimental/mosaic/gpu/layout_inference2.py b/jax/experimental/mosaic/gpu/layout_inference2.py index db8b12944f2c..8205504d6ce7 100644 --- a/jax/experimental/mosaic/gpu/layout_inference2.py +++ b/jax/experimental/mosaic/gpu/layout_inference2.py @@ -56,21 +56,6 @@ class VariableKey: index: int -@dataclasses.dataclass(frozen=True) -class Variable(eqns.Variable): - """This variable represents an operand/result of a MLIR operation.""" - def __init__(self, operation: ir.OpView, type: VariableType, index: int): - super().__init__(VariableKey(operation, type, index)) - - @property - def is_operand(self) -> bool: - return self.key.type == VariableType.OPERAND - - @property - def is_result(self) -> bool: - return self.key.type == VariableType.RESULT - - @dataclasses.dataclass(frozen=True) class Hint: """Hints are used to model propagation of layouts across operations. @@ -80,13 +65,13 @@ class Hint: an equation-like form of "soft constraints", i.e., it suggests that `variable` should be equal to `expression`. """ - variable: Variable + variable: eqns.Variable expression: eqns.Expression def choose_variable_assignment_from_hints( hints: Sequence[Hint], -) -> tuple[Variable, eqns.ConstantExpression] | None: +) -> tuple[eqns.Variable, eqns.ConstantExpression] | None: """Attempts to choose a single variable assignment from a list of `Hint`s.""" for hint in hints: if isinstance(hint.expression, eqns.ConstantExpression): @@ -94,18 +79,30 @@ def choose_variable_assignment_from_hints( return None -def simplify_hint( - h: Hint, assignments: dict[Variable, eqns.ConstantExpression] -) -> Hint: - """Like `eqns.simplify_equation` but for `Hint`s.""" - return dataclasses.replace( - h, expression=eqns.simplify_expression(h.expression, assignments)) +def reduce_hints( + hints: Sequence[Hint], assignments: dict[eqns.Variable, eqns.ConstantExpression] +) -> Sequence[Hint]: + """Reduces a sequence of `Hint`s. + + We simplify the `Hint`s' expressions, drop `Unsatisfiable` hints, and drop + `Hint`s pertaining to pre-existing assignments. + """ + new_hints: list[Hint] = [] + for h in hints: + if h.variable not in assignments: + reduced_expression = eqns.reduce_expression(h.expression, assignments) + if isinstance(reduced_expression, eqns.Unsatisfiable): + continue + new_hints.append(dataclasses.replace(h, expression=reduced_expression)) + + return new_hints + def find_assignments_for( - unknowns: set[Variable], + unknowns: set[eqns.Variable], equation_system: eqns.EquationSystem, hints: Sequence[Hint], -) -> dict[Variable, eqns.ConstantExpression] | eqns.Unsatisfiable: +) -> dict[eqns.Variable, eqns.ConstantExpression] | eqns.Unsatisfiable: """Attempts to find assignments that satisfy `equation_system` for `unknowns`. Args: @@ -119,7 +116,7 @@ def find_assignments_for( such that the assignment satisfies the equation system otherwise. """ while True: - equation_system = eqns.simplify(equation_system) + equation_system = eqns.reduce(equation_system) if isinstance(equation_system, eqns.Unsatisfiable): return eqns.Unsatisfiable() @@ -130,15 +127,14 @@ def find_assignments_for( if not remaining_unknowns: return {v: k for v, k in equation_system.assignments.items() if v in unknowns} - # Simplify the expressions in the remaining hints based on the current + # Reduce the expressions in the remaining hints based on the current # assignments, and eliminate hints that pertain to variables that already # have an assignment. - hints = [simplify_hint(h, equation_system.assignments) for h in hints - if h.variable not in equation_system.assignments] + hints = reduce_hints(hints, equation_system.assignments) - # If unknowns remain and we have fully simplified the system, we may still + # If unknowns remain and we have fully reduced the system, we may still # be able to make progress by extracting an assignment from a `Hint`. In a - # system that has otherwise been fully simplified, it is guaranteed that + # system that has otherwise been fully reduced, it is guaranteed that # introducing a new assignment will yield a system that remains satisfiable # if the original system was satisfiable---because this is a sign of an # underdetermined system. @@ -156,7 +152,7 @@ def find_assignments_for( if variable in equation_system.assignments: continue # Try to instantiate a single variable to a strided layout and see if it - # simplifies the system. + # reduces the system. op = variable.key.operation # TODO(bchetioui): should we make variables carry a shape as well, to make # things easier? @@ -187,7 +183,21 @@ def find_assignments_for( return eqns.Unsatisfiable() -EquationSystemDerivationRule = Callable[[ir.OpView], eqns.EquationSystem] +KeysForVariable = dict[eqns.Variable, list[VariableKey]] + +# An equation system derivation rule is a function that takes an MLIR operation +# and returns an equation system, and a mapping from variables to variable keys. +# The intended meaning of the mapping is that, for each variable key in the list +# keyed by a given variable, the operand/result corresponding to that key has +# the same layout as the variable. +# +# An `EquationSystemDerivationRule` must return a mapping such that the variable +# key corresponding to each operand/result must appear in the mapping, and each +# variable key in the mapping must be keyed by exactly one variable. Lastly, +# the mapping must only refer to variables and variable keys that correspond to +# the given operation. +EquationSystemDerivationRule = Callable[ + [ir.OpView], tuple[eqns.EquationSystem, KeysForVariable]] _equation_system_derivation_rules: dict[str, EquationSystemDerivationRule] = {} @@ -208,26 +218,31 @@ def _constant_equation_system( constant_op: arith.ConstantOp ) -> eqns.EquationSystem: value = constant_op.value - variable = Variable(constant_op, VariableType.RESULT, 0) + key = VariableKey(constant_op, VariableType.RESULT, 0) + variable = eqns.Variable(key) if ( ir.DenseElementsAttr.isinstance(value) and ir.DenseElementsAttr(value).is_splat ): layout = fa.WGSplatFragLayout(shape=tuple(constant_op.result.type.shape)) - return eqns.EquationSystem(assignments={variable: eqns.ConstantExpression(layout)}) - return eqns.EquationSystem() + system = eqns.EquationSystem(assignments={variable: eqns.ConstantExpression(layout)}) + else: + system = eqns.EquationSystem() + + return system, {variable: [key]} @_add_equation_system_derivation_rule(mgpu.LayoutCastOp) def _layout_cast_equation_system( op: mgpu.LayoutCastOp -) -> eqns.EquationSystem: - in_variable = Variable(op, VariableType.OPERAND, 0) - out_variable = Variable(op, VariableType.RESULT, 0) +) -> tuple[eqns.EquationSystem, KeysForVariable]: + in_key = VariableKey(op, VariableType.OPERAND, 0) + out_key = VariableKey(op, VariableType.RESULT, 0) + variable = eqns.Variable(in_key) out_layout = eqns.ConstantExpression(layouts_lib.from_layout_attr(op.new_layout)) return eqns.EquationSystem( - assignments={out_variable: out_layout, in_variable: out_layout}, - ) + assignments={eqns.Variable(in_key): out_layout}, + ), {variable: [in_key, out_key]} def _ensure_all_layouts_are_set(op: ir.OpView): @@ -253,7 +268,7 @@ def _ensure_right_number_of_layouts( ) -def assign_layouts(solution: dict[Variable, eqns.ConstantExpression]): +def assign_layouts(solution: dict[VariableKey, eqns.ConstantExpression]): """Assigns the layouts in `solution` to the MLIR ops they belong to. This function requires that, for each MLIR op that appears in `solution`, @@ -261,18 +276,18 @@ def assign_layouts(solution: dict[Variable, eqns.ConstantExpression]): results. """ solution_sorted_by_op = sorted( - solution.items(), key=lambda kv: id(kv[0].key.operation) + solution.items(), key=lambda kv: id(kv[0].operation) ) solution_per_op = itertools.groupby( - solution_sorted_by_op, key=lambda kv: kv[0].key.operation + solution_sorted_by_op, key=lambda kv: kv[0].operation ) for op, assignments in solution_per_op: - assignments_sorted_by_type = sorted(assignments, key=lambda kv: kv[0].key.type) + assignments_sorted_by_type = sorted(assignments, key=lambda kv: kv[0].type) assignments_by_type = { ty: list(group) for ty, group in itertools.groupby( - assignments_sorted_by_type, key=lambda kv: kv[0].key.type + assignments_sorted_by_type, key=lambda kv: kv[0].type ) } @@ -280,11 +295,11 @@ def assign_layouts(solution: dict[Variable, eqns.ConstantExpression]): out_assignments = assignments_by_type.get(VariableType.RESULT, []) in_layouts = [ - ce.value for _, ce in sorted(in_assignments, key=lambda kv: kv[0].key.index) + ce.value for _, ce in sorted(in_assignments, key=lambda kv: kv[0].index) ] out_layouts = [ ce.value - for _, ce in sorted(out_assignments, key=lambda kv: kv[0].key.index) + for _, ce in sorted(out_assignments, key=lambda kv: kv[0].index) ] _ensure_right_number_of_layouts(op, in_layouts, out_layouts) @@ -294,109 +309,83 @@ def assign_layouts(solution: dict[Variable, eqns.ConstantExpression]): op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts_attrs) -def op_variables(op: ir.OpView) -> list[Variable]: - """Returns all the operand and result variables for the given op.""" - variables = [ - Variable(op, VariableType.OPERAND, i) +def op_variable_keys(op: ir.OpView) -> list[VariableKey]: + """Returns all the operand and result variable keys for the given op.""" + keys = [ + VariableKey(op, VariableType.OPERAND, i) for i, o in enumerate(op.operands) if is_vector(o) ] - variables.extend([ - Variable(op, VariableType.RESULT, i) + keys.extend([ + VariableKey(op, VariableType.RESULT, i) for i, o in enumerate(op.results) if is_vector(o) ]) - return variables + return keys -def producer_variable(variable: Variable) -> Variable: - """Given a variable, returns the corresponding result variable in its producer. +def producer_variable_key(key: VariableKey) -> VariableKey: + """Given a variable key, returns the corresponding result variable key in its producer. - The variable has to represent an operand of its operation. + The variable key has to represent an operand of its operation. """ - assert variable.is_operand - value = variable.key.operation.operands[variable.key.index] + assert key.type == VariableType.OPERAND + value = key.operation.operands[key.index] producer = value.owner if isinstance(producer, ir.Operation): index = list(producer.results).index(value) - return Variable(producer.opview, VariableType.RESULT, index) + return VariableKey(producer.opview, VariableType.RESULT, index) # Block case, useful for deriving layouts for ops # depending on function parameters, or loop block arguments. if isinstance(producer, ir.Block): index = list(cast(ir.Block, producer).arguments).index(value) - return Variable(producer, VariableType.OPERAND, index) + return VariableKey(producer, VariableType.OPERAND, index) raise TypeError( f"Producer {producer} is not an operation nor a block: {type(producer)}." ) -def consumer_variables(variable: Variable) -> Sequence[Variable]: - """Given a variable, returns the corresponding operand variables in its consumers. +def consumer_variable_keys(key: VariableKey) -> Sequence[eqns.VariableKey]: + """Given a variable key, returns the corresponding operand variable keys in its consumers. - The variable has to represent a result of its operation. + The variable key has to represent a result of its operation. """ - assert variable.is_result - consumer_variables: list[Variable] = [] + assert key.type == VariableType.RESULT + consumer_keys: list[VariableKey] = [] # The layout can also be chosen from the layout of the consumers of the # results. - for use in cast(ir.OpResult, variable.key.operation.results[variable.key.index]).uses: + for use in cast(ir.OpResult, key.operation.results[key.index]).uses: consumer = use.owner.opview # pytype: disable=attribute-error index = use.operand_number - consumer_variables.append(Variable(consumer, VariableType.OPERAND, index)) - return consumer_variables - - -def equation_system_and_hints_for_op( - op: ir.OpView, rule: EquationSystemDerivationRule -) -> tuple[eqns.EquationSystem, list[Hint]]: - """Produces an equation system and a list of hints for the given op. - - The equation system is derived directly from the given rule, and is not - further constrained. Hints are subsequently derived from this equation system - that relate the variables of the op to the producers of the op's operands and - the consumers of the op's results. - """ - equation_system = rule(op) - all_variables: list[Variable] = op_variables(op) - visited: set[Variable] = set() - hints: list[Hint] = list() + consumer_keys.append(VariableKey(consumer, VariableType.OPERAND, index)) + return consumer_keys + + +def derive_hints(keys_for_variable: KeysForVariable) -> list[Hint]: + """Derives propagation hints from the given variable mapping.""" + hints: list[Hint] = [] + variable_for_key: dict[VariableKey, eqns.Variable] = {} + for variable, keys in keys_for_variable.items(): + for k in keys: + if k in variable_for_key: + raise ValueError( + f"Key {k} is mapped to both {variable} and {variable_for_key[k]}" + ) + variable_for_key |= {key: variable for key in keys} - for variable in all_variables: - if variable in visited: - continue - # Construct a list containing all the variables that are necessary equal to - # the current variable. Consider the following pseudo-program: - # - # a = producer0() # variable v0 is producer0's out_layouts[0] - # b = producer1() # variable v1 is producer1's out_layouts[0] - # c = add(a, b) # variable v2, v3, v4 are respectively add's in_layouts[0], in_layouts[1], and out_layouts[0] - # consumer0(c) # variable v5 is consumer0's in_layouts[0] - # consumer1(c) # variable v6 is consumer1's in_layouts[0] - # - # We know that v2 = v3 = v4, and we may want to propagate a layout from v0, - # v1, v5, or v6. For that reason, we capture all the connected variables, - # and then extract their producer/consumers to construct a `Hint`. - # - # We use a list here because we care about having a deterministic iteration - # order. - union: list[Variable] = [variable] - for equation in equation_system.equations: - lhs, rhs = equation.lhs, equation.rhs - if lhs == variable and isinstance(rhs, Variable) and rhs not in union: - union.append(rhs) - if rhs == variable and isinstance(lhs, Variable) and lhs not in union: - union.append(lhs) - - producers = tuple(producer_variable(v) for v in union if v.is_operand) - consumers: list[Variable] = [] - for v in union: - if v.is_result: - consumers.extend(consumer_variables(v)) + for variable, keys in keys_for_variable.items(): + producers: list[eqns.Variable] = [] + consumers: list[eqns.Variable] = [] + for k in keys: + if k.type == VariableType.OPERAND: + producers.append(variable_for_key[producer_variable_key(k)]) + elif k.type == VariableType.RESULT: + consumers.extend(variable_for_key[c] for c in consumer_variable_keys(k)) if producers: - least_replicated_producer = eqns.LeastReplicatedExpression(producers) + least_replicated_producer = eqns.LeastReplicatedExpression(tuple(producers)) hint_expr = eqns.MostReplicatedExpression( (least_replicated_producer, *consumers) ) @@ -404,15 +393,13 @@ def equation_system_and_hints_for_op( elif consumers: hint_expr = eqns.MostReplicatedExpression(tuple(consumers)) hints.append(Hint(variable, hint_expr)) - visited.update(union) - return equation_system, [simplify_hint(h, equation_system.assignments) for h in hints] + return hints def infer_layout(module: ir.Module): global_equation_system = eqns.EquationSystem() - all_hints: list[Hint] = [] - variables: set[Variable] = set() + keys_for_variable: KeysForVariable = {} def gather_equations(op: ir.Operation): if not inference_utils.should_have_layout(op): @@ -422,25 +409,31 @@ def gather_equations(op: ir.Operation): else: raise NotImplementedError(f"No layout inference rule defined for {op}") - variables.update(op_variables(op)) + equation_system, mapping = rule(op) + keys_for_variable.update(mapping) nonlocal global_equation_system - equation_system, hints = equation_system_and_hints_for_op(op, rule) global_equation_system &= equation_system - all_hints.extend(hints) for op in module.body: inference_utils.traverse_op(op, gather_equations) + assert not isinstance(global_equation_system, eqns.Unsatisfiable) + hints = reduce_hints(derive_hints(keys_for_variable), global_equation_system.assignments) # pytype: disable=attribute-error + # Attempt to find assignments that satisfy the equation system. - solution = find_assignments_for(variables, global_equation_system, all_hints) + solution = find_assignments_for( + keys_for_variable.keys(), global_equation_system, hints + ) if isinstance(solution, eqns.Unsatisfiable): raise ValueError( "Failed to infer a possible set of layouts. This should never happen." ) + layout_for_key = {k: solution[v] for v, ks in keys_for_variable.items() for k in ks} + # Assigns the layouts that we found to the ops. - assign_layouts(solution) + assign_layouts(layout_for_key) # Sanity check: ensure that all ops have the right number of in/out layouts. for op in module.body: diff --git a/tests/mosaic/gpu_equations_test.py b/tests/mosaic/gpu_equations_test.py index 8369aff74ecb..901bb41dc8e8 100644 --- a/tests/mosaic/gpu_equations_test.py +++ b/tests/mosaic/gpu_equations_test.py @@ -35,23 +35,23 @@ def test_equation_system_is_unsatisfiable_if_assignments_are_incompatible(self): system = equations.EquationSystem( equations=[Eq(v0, layout0), Eq(v0, layout1)], ) - self.assertIsInstance(equations.simplify(system), equations.Unsatisfiable) + self.assertIsInstance(equations.reduce(system), equations.Unsatisfiable) - def test_simplify_equation_system_removes_tautological_equations(self): + def test_reduce_equation_system_removes_tautological_equations(self): v0, v1 = V(0), V(1) system = equations.EquationSystem( equations=[Eq(v0, v1), Eq(v0, v0)], ) - self.assertLen(equations.simplify(system).equations, 1) + self.assertLen(equations.reduce(system).equations, 1) - def test_simplify_equation_system_of_simplified_system_is_noop(self): + def test_reduce_equation_system_of_simplified_system_is_noop(self): v0, v1 = V(0), V(1) system = equations.EquationSystem( equations=[Eq(v0, v1)], ) - self.assertEqual(equations.simplify(system), system) + self.assertEqual(equations.reduce(system), system) - def test_simplify_equation_system_assigns_variables_with_known_equations(self): + def test_reduce_equation_system_assigns_variables_with_known_equations(self): v0, v1 = V(0), V(1) layout = C(mgpu.WGSplatFragLayout((1, 1))) @@ -60,7 +60,7 @@ def test_simplify_equation_system_assigns_variables_with_known_equations(self): equations=[Eq(v0, layout), Eq(v0, v1)], ) self.assertEqual( - equations.simplify(system), + equations.reduce(system), equations.EquationSystem(assignments={v0: layout, v1: layout}) ) @@ -69,7 +69,7 @@ def test_simplify_equation_system_assigns_variables_with_known_equations(self): equations=[Eq(v1, layout), Eq(v0, v1)], ) self.assertEqual( - equations.simplify(system), + equations.reduce(system), equations.EquationSystem(assignments={v0: layout, v1: layout}) ) @@ -109,7 +109,7 @@ def test_intersection_of_compatible_systems_is_union_of_fields(self): self.assertSequenceEqual(system1.unknowns(), [v1]) self.assertSequenceEqual(system_intersection.unknowns(), [v0, v1]) - def test_simplify_extracts_most_replicated_expression_correctly(self): + def test_reduce_extracts_most_replicated_expression_correctly(self): v0 = V(0) shape = (1, 128) layout0 = C(mgpu.WGSplatFragLayout(shape)) @@ -119,7 +119,7 @@ def test_simplify_extracts_most_replicated_expression_correctly(self): equations=[Eq(v0, equations.MostReplicatedExpression((layout0, layout1)))], ) self.assertEqual( - equations.simplify(system), + equations.reduce(system), equations.EquationSystem(assignments={v0: layout0}) ) @@ -128,7 +128,7 @@ def test_simplify_extracts_most_replicated_expression_correctly(self): equations=[Eq(v0, equations.MostReplicatedExpression((layout0,)))], ) self.assertEqual( - equations.simplify(system), + equations.reduce(system), equations.EquationSystem(assignments={v0: layout0}) ) @@ -139,9 +139,9 @@ def test_simplify_extracts_most_replicated_expression_correctly(self): system = equations.EquationSystem( equations=[Eq(v0, equations.MostReplicatedExpression((v1, layout1)))], ) - self.assertEqual(equations.simplify(system), system) + self.assertEqual(equations.reduce(system), system) - def test_simplify_extracts_least_replicated_expression_correctly(self): + def test_reduce_extracts_least_replicated_expression_correctly(self): v0 = V(0) shape = (1, 128) layout0 = C(mgpu.WGSplatFragLayout(shape)) @@ -151,7 +151,7 @@ def test_simplify_extracts_least_replicated_expression_correctly(self): equations=[Eq(v0, equations.LeastReplicatedExpression([layout0, layout1]))], ) self.assertEqual( - equations.simplify(system), + equations.reduce(system), equations.EquationSystem(assignments={v0: layout1}) ) @@ -160,7 +160,7 @@ def test_simplify_extracts_least_replicated_expression_correctly(self): equations=[Eq(v0, equations.LeastReplicatedExpression((layout0,)))], ) self.assertEqual( - equations.simplify(system), + equations.reduce(system), equations.EquationSystem(assignments={v0: layout0}) ) @@ -171,7 +171,7 @@ def test_simplify_extracts_least_replicated_expression_correctly(self): system = equations.EquationSystem( equations=[Eq(v0, equations.LeastReplicatedExpression((v1, layout0)))], ) - self.assertEqual(equations.simplify(system), system) + self.assertEqual(equations.reduce(system), system) if __name__ == "__main__": diff --git a/tests/mosaic/gpu_layout_inference_test.py b/tests/mosaic/gpu_layout_inference_test.py index 64111d3d729d..8b16f11f6eda 100644 --- a/tests/mosaic/gpu_layout_inference_test.py +++ b/tests/mosaic/gpu_layout_inference_test.py @@ -636,7 +636,7 @@ def body(lhs, rhs): class LayoutInferenceTestEquations(LayoutInferenceTest, inference_impl=InferenceImplementation.EQUATIONS): ... - def test_hint_extraction_for_op_works_correctly(self): + def test_hint_extraction_works_correctly(self): shape = (64,) bf16 = ir.BF16Type.get() layout = mgpu.WGMMA_ROW_LAYOUT @@ -647,19 +647,14 @@ def test_hint_extraction_for_op_works_correctly(self): cst = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attrs, type=ty)) lc = layout_cast(cst, layouts.to_layout_attr(layout)).owner.opview - equation_system, hints = layout_inference2.equation_system_and_hints_for_op( - lc, layout_inference2._layout_cast_equation_system - ) - - in_variable, out_variable = layout_inference2.op_variables(lc) - [cst_out_variable] = layout_inference2.op_variables(cst) + cst_system, cst_mapping = layout_inference2._constant_equation_system(cst) + lc_system, lc_mapping = layout_inference2._layout_cast_equation_system(lc) + assignments = cst_system.assignments | lc_system.assignments + [hint_cst] = layout_inference2.reduce_hints( + layout_inference2.derive_hints(cst_mapping | lc_mapping), assignments) - assignments = {v: C(layout) for v in [in_variable, out_variable]} - - self.assertEqual( - equation_system, equations.EquationSystem(assignments=assignments) - ) - self.assertEqual(hints, [H(in_variable, cst_out_variable)]) + self.assertEqual(hint_cst.variable.key.operation, cst) + self.assertEqual(hint_cst.expression, C(layout)) def test_unambiguous_hints_are_used_to_assign_variables_correctly(self): v0 = V(0) @@ -680,7 +675,8 @@ def test_cannot_find_assignments_for_unsatisfiable_equation_system(self): attrs = [ir.FloatAttr.get(bf16, i) for i in range(shape[0])] cst = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attrs, type=ty)) - [variable] = layout_inference2.op_variables(cst) + [key] = layout_inference2.op_variable_keys(cst) + variable = equations.Variable(key) assignments = layout_inference2.find_assignments_for( {variable}, equations.EquationSystem(