-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][python] Wrappers for scf.index_switch #167458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The C++ index switch op has utilies for getCaseBlock(int i) and getDefaultBlock(), so these have been added. Optional builder args have been added for the default case and each switch case. The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents this. The same paradigm is used for get_case_block(i: int), but this is unavoidable.
|
@llvm/pr-subscribers-mlir Author: Asher Mancinelli (ashermancinelli) ChangesThe C++ index switch op has utilities for The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents this. The same paradigm is used for get_case_block(i: int), but this is unavoidable. Full diff: https://github.com/llvm/llvm-project/pull/167458.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 678ceeebac204..59ccbce147be3 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -6,12 +6,14 @@
from ._scf_ops_gen import *
from ._scf_ops_gen import _Dialect
from .arith import constant
+import builtins
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
+ get_op_result_or_op_results as _get_op_result_or_op_results,
_cext as _ods_cext,
)
except ImportError as e:
@@ -254,3 +256,77 @@ def for_(
yield iv, iter_args[0], for_op.results[0]
else:
yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class IndexSwitchOp(IndexSwitchOp):
+ __doc__ = IndexSwitchOp.__doc__
+
+ def __init__(
+ self,
+ results_,
+ arg,
+ cases,
+ case_body_builder=None,
+ default_body_builder=None,
+ loc=None,
+ ip=None,
+ ):
+ cases = DenseI64ArrayAttr.get(cases)
+ super().__init__(
+ results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
+ )
+ for region in self.regions:
+ region.blocks.append()
+
+ if default_body_builder is not None:
+ with InsertionPoint(self.default_block):
+ default_body_builder(self)
+
+ if case_body_builder is not None:
+ for i, case in enumerate(cases):
+ with InsertionPoint(self.case_block(i)):
+ case_body_builder(self, i, self.cases[i])
+
+ @builtins.property
+ def default_region(self) -> Region:
+ return self.regions[0]
+
+ @builtins.property
+ def default_block(self) -> Block:
+ return self.default_region.blocks[0]
+
+ @builtins.property
+ def case_regions(self) -> Sequence[Region]:
+ return [self.regions[1 + i] for i in range(len(self.cases))]
+
+ def case_region(self, i: int) -> Region:
+ return self.case_regions[i]
+
+ @builtins.property
+ def case_blocks(self) -> Sequence[Block]:
+ return [region.blocks[0] for region in self.case_regions]
+
+ def case_block(self, i: int) -> Block:
+ return self.case_regions[i].blocks[0]
+
+
+def index_switch(
+ results_,
+ arg,
+ cases,
+ case_body_builder=None,
+ default_body_builder=None,
+ loc=None,
+ ip=None,
+) -> Union[OpResult, OpResultList, IndexSwitchOp]:
+ op = IndexSwitchOp(
+ results_=results_,
+ arg=arg,
+ cases=cases,
+ case_body_builder=case_body_builder,
+ default_body_builder=default_body_builder,
+ loc=loc,
+ ip=ip,
+ )
+ return _get_op_result_or_op_results(op)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 62d11d5e189c8..11d207b4a5e07 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -1,10 +1,14 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
-from mlir.dialects import arith
-from mlir.dialects import func
-from mlir.dialects import memref
-from mlir.dialects import scf
+from mlir.extras import types as T
+from mlir.dialects import (
+ arith,
+ func,
+ memref,
+ scf,
+ cf,
+)
from mlir.passmanager import PassManager
@@ -355,3 +359,117 @@ def simple_if_else(cond):
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return
+
+
+@constructAndPrintInModule
+def testIndexSwitch():
+ i32 = T.i32()
+
+ @func.FuncOp.from_py_func(T.index(), results=[i32])
+ def index_switch(index):
+ c1 = arith.constant(i32, 1)
+ c0 = arith.constant(i32, 0)
+ value = arith.constant(i32, 5)
+ switch_op = scf.IndexSwitchOp([i32], index, range(3))
+
+ assert switch_op.regions[0] == switch_op.default_region
+ assert switch_op.regions[1] == switch_op.case_regions[0]
+ assert switch_op.regions[1] == switch_op.case_region(0)
+ assert len(switch_op.case_regions) == 3
+ assert len(switch_op.regions) == 4
+
+ with InsertionPoint(switch_op.default_block):
+ cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+ scf.yield_([c1])
+
+ for i, block in enumerate(switch_op.case_blocks):
+ with InsertionPoint(block):
+ scf.yield_([arith.constant(i32, i)])
+
+ func.return_([switch_op.results[0]])
+
+ return index_switch
+
+
+# CHECK-LABEL: func.func @index_switch(
+# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK: case 0 {
+# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK: scf.yield %[[CONSTANT_3]] : i32
+# CHECK: }
+# CHECK: case 1 {
+# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK: scf.yield %[[CONSTANT_4]] : i32
+# CHECK: }
+# CHECK: case 2 {
+# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK: scf.yield %[[CONSTANT_5]] : i32
+# CHECK: }
+# CHECK: default {
+# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK: scf.yield %[[CONSTANT_0]] : i32
+# CHECK: }
+# CHECK: return %[[INDEX_SWITCH_0]] : i32
+# CHECK: }
+
+
+@constructAndPrintInModule
+def testIndexSwitchWithBodyBuilders():
+ i32 = T.i32()
+
+ @func.FuncOp.from_py_func(T.index(), results=[i32])
+ def index_switch(index):
+ c1 = arith.constant(i32, 1)
+ c0 = arith.constant(i32, 0)
+ value = arith.constant(i32, 5)
+
+ def default_body_builder(switch_op):
+ cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+ scf.yield_([c1])
+
+ def case_body_builder(switch_op, case_index: int, case_value: int):
+ scf.yield_([arith.constant(i32, case_value)])
+
+ result = scf.index_switch(
+ results_=[i32],
+ arg=index,
+ cases=range(3),
+ case_body_builder=case_body_builder,
+ default_body_builder=default_body_builder,
+ )
+
+ func.return_([result])
+
+ return index_switch
+
+
+# CHECK-LABEL: func.func @index_switch(
+# CHECK-SAME: %[[ARG0:.*]]: index) -> i32 {
+# CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK: %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK: %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK: %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK: case 0 {
+# CHECK: %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK: scf.yield %[[CONSTANT_3]] : i32
+# CHECK: }
+# CHECK: case 1 {
+# CHECK: %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK: scf.yield %[[CONSTANT_4]] : i32
+# CHECK: }
+# CHECK: case 2 {
+# CHECK: %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK: scf.yield %[[CONSTANT_5]] : i32
+# CHECK: }
+# CHECK: default {
+# CHECK: %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK: cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK: scf.yield %[[CONSTANT_0]] : i32
+# CHECK: }
+# CHECK: return %[[INDEX_SWITCH_0]] : i32
+# CHECK: }
|
lemme fix that bug so we don't have to land a workaround in our own codebase lol |
The C++ index switch op has utilities for
getCaseBlock(int i)andgetDefaultBlock(), so these have been added.Optional body builder args have been added for the default case and each switch case.
The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents, but we can address this before merging this patch. The same paradigm is used for get_case_block(i: int), but this is unavoidable.