Skip to content

Conversation

@ashermancinelli
Copy link
Contributor

@ashermancinelli ashermancinelli commented Nov 11, 2025

The C++ index switch op has utilities for getCaseBlock(int i) and getDefaultBlock(), so these have been added.
Optional body builder args have been added for the default case and each switch case.

The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents, but we can address this before merging this patch. The same paradigm is used for get_case_block(i: int), but this is unavoidable.

The C++ index switch op has utilies for getCaseBlock(int i)
and getDefaultBlock(), so these have been added.

Optional builder args have been added for the default case
and each switch case.

The list comprehensions for accessing case regions are due
to what appears to be a bug in RegionSequence; using a comprehension
with explicit indices circumvents this.
The same paradigm is used for get_case_block(i: int), but this
is unavoidable.
@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2025

@llvm/pr-subscribers-mlir

Author: Asher Mancinelli (ashermancinelli)

Changes

The C++ index switch op has utilities for getCaseBlock(int i) and getDefaultBlock(), so these have been added.
Optional body builder args have been added for the default case and each switch case.

The list comprehensions for accessing case regions are due to what appears to be a bug in RegionSequence; using a comprehension with explicit indices circumvents this. The same paradigm is used for get_case_block(i: int), but this is unavoidable.


Full diff: https://github.com/llvm/llvm-project/pull/167458.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/scf.py (+76)
  • (modified) mlir/test/python/dialects/scf.py (+122-4)
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 678ceeebac204..59ccbce147be3 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -6,12 +6,14 @@
 from ._scf_ops_gen import *
 from ._scf_ops_gen import _Dialect
 from .arith import constant
+import builtins
 
 try:
     from ..ir import *
     from ._ods_common import (
         get_op_result_or_value as _get_op_result_or_value,
         get_op_results_or_values as _get_op_results_or_values,
+        get_op_result_or_op_results as _get_op_result_or_op_results,
         _cext as _ods_cext,
     )
 except ImportError as e:
@@ -254,3 +256,77 @@ def for_(
             yield iv, iter_args[0], for_op.results[0]
         else:
             yield iv
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class IndexSwitchOp(IndexSwitchOp):
+    __doc__ = IndexSwitchOp.__doc__
+
+    def __init__(
+        self,
+        results_,
+        arg,
+        cases,
+        case_body_builder=None,
+        default_body_builder=None,
+        loc=None,
+        ip=None,
+    ):
+        cases = DenseI64ArrayAttr.get(cases)
+        super().__init__(
+            results_, arg, cases, num_caseRegions=len(cases), loc=loc, ip=ip
+        )
+        for region in self.regions:
+            region.blocks.append()
+
+        if default_body_builder is not None:
+            with InsertionPoint(self.default_block):
+                default_body_builder(self)
+
+        if case_body_builder is not None:
+            for i, case in enumerate(cases):
+                with InsertionPoint(self.case_block(i)):
+                    case_body_builder(self, i, self.cases[i])
+
+    @builtins.property
+    def default_region(self) -> Region:
+        return self.regions[0]
+
+    @builtins.property
+    def default_block(self) -> Block:
+        return self.default_region.blocks[0]
+
+    @builtins.property
+    def case_regions(self) -> Sequence[Region]:
+        return [self.regions[1 + i] for i in range(len(self.cases))]
+
+    def case_region(self, i: int) -> Region:
+        return self.case_regions[i]
+
+    @builtins.property
+    def case_blocks(self) -> Sequence[Block]:
+        return [region.blocks[0] for region in self.case_regions]
+
+    def case_block(self, i: int) -> Block:
+        return self.case_regions[i].blocks[0]
+
+
+def index_switch(
+    results_,
+    arg,
+    cases,
+    case_body_builder=None,
+    default_body_builder=None,
+    loc=None,
+    ip=None,
+) -> Union[OpResult, OpResultList, IndexSwitchOp]:
+    op = IndexSwitchOp(
+        results_=results_,
+        arg=arg,
+        cases=cases,
+        case_body_builder=case_body_builder,
+        default_body_builder=default_body_builder,
+        loc=loc,
+        ip=ip,
+    )
+    return _get_op_result_or_op_results(op)
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index 62d11d5e189c8..11d207b4a5e07 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -1,10 +1,14 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
-from mlir.dialects import arith
-from mlir.dialects import func
-from mlir.dialects import memref
-from mlir.dialects import scf
+from mlir.extras import types as T
+from mlir.dialects import (
+    arith,
+    func,
+    memref,
+    scf,
+    cf,
+)
 from mlir.passmanager import PassManager
 
 
@@ -355,3 +359,117 @@ def simple_if_else(cond):
 # CHECK:   scf.yield %[[TWO]], %[[THREE]]
 # CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
 # CHECK: return
+
+
+@constructAndPrintInModule
+def testIndexSwitch():
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(T.index(), results=[i32])
+    def index_switch(index):
+        c1 = arith.constant(i32, 1)
+        c0 = arith.constant(i32, 0)
+        value = arith.constant(i32, 5)
+        switch_op = scf.IndexSwitchOp([i32], index, range(3))
+
+        assert switch_op.regions[0] == switch_op.default_region
+        assert switch_op.regions[1] == switch_op.case_regions[0]
+        assert switch_op.regions[1] == switch_op.case_region(0)
+        assert len(switch_op.case_regions) == 3
+        assert len(switch_op.regions) == 4
+
+        with InsertionPoint(switch_op.default_block):
+            cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+            scf.yield_([c1])
+
+        for i, block in enumerate(switch_op.case_blocks):
+            with InsertionPoint(block):
+                scf.yield_([arith.constant(i32, i)])
+
+        func.return_([switch_op.results[0]])
+
+    return index_switch
+
+
+# CHECK-LABEL:   func.func @index_switch(
+# CHECK-SAME:      %[[ARG0:.*]]: index) -> i32 {
+# CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK:           %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK:           %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK:           %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK:           case 0 {
+# CHECK:             %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK:             scf.yield %[[CONSTANT_3]] : i32
+# CHECK:           }
+# CHECK:           case 1 {
+# CHECK:             %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK:             scf.yield %[[CONSTANT_4]] : i32
+# CHECK:           }
+# CHECK:           case 2 {
+# CHECK:             %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK:             scf.yield %[[CONSTANT_5]] : i32
+# CHECK:           }
+# CHECK:           default {
+# CHECK:             %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK:             cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK:             scf.yield %[[CONSTANT_0]] : i32
+# CHECK:           }
+# CHECK:           return %[[INDEX_SWITCH_0]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def testIndexSwitchWithBodyBuilders():
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(T.index(), results=[i32])
+    def index_switch(index):
+        c1 = arith.constant(i32, 1)
+        c0 = arith.constant(i32, 0)
+        value = arith.constant(i32, 5)
+
+        def default_body_builder(switch_op):
+            cf.assert_(arith.constant(T.bool(), 0), "Whoops!")
+            scf.yield_([c1])
+
+        def case_body_builder(switch_op, case_index: int, case_value: int):
+            scf.yield_([arith.constant(i32, case_value)])
+
+        result = scf.index_switch(
+            results_=[i32],
+            arg=index,
+            cases=range(3),
+            case_body_builder=case_body_builder,
+            default_body_builder=default_body_builder,
+        )
+
+        func.return_([result])
+
+    return index_switch
+
+
+# CHECK-LABEL:   func.func @index_switch(
+# CHECK-SAME:      %[[ARG0:.*]]: index) -> i32 {
+# CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : i32
+# CHECK:           %[[CONSTANT_1:.*]] = arith.constant 0 : i32
+# CHECK:           %[[CONSTANT_2:.*]] = arith.constant 5 : i32
+# CHECK:           %[[INDEX_SWITCH_0:.*]] = scf.index_switch %[[ARG0]] -> i32
+# CHECK:           case 0 {
+# CHECK:             %[[CONSTANT_3:.*]] = arith.constant 0 : i32
+# CHECK:             scf.yield %[[CONSTANT_3]] : i32
+# CHECK:           }
+# CHECK:           case 1 {
+# CHECK:             %[[CONSTANT_4:.*]] = arith.constant 1 : i32
+# CHECK:             scf.yield %[[CONSTANT_4]] : i32
+# CHECK:           }
+# CHECK:           case 2 {
+# CHECK:             %[[CONSTANT_5:.*]] = arith.constant 2 : i32
+# CHECK:             scf.yield %[[CONSTANT_5]] : i32
+# CHECK:           }
+# CHECK:           default {
+# CHECK:             %[[CONSTANT_6:.*]] = arith.constant false
+# CHECK:             cf.assert %[[CONSTANT_6]], "Whoops!"
+# CHECK:             scf.yield %[[CONSTANT_0]] : i32
+# CHECK:           }
+# CHECK:           return %[[INDEX_SWITCH_0]] : i32
+# CHECK:         }
@makslevental
Copy link
Contributor

are due to what appears to be a bug in RegionSequence

lemme fix that bug so we don't have to land a workaround in our own codebase lol

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

3 participants