DEV Community

Cover image for 🚀 LeetCode 3197: Covering All Ones with 3 Rectangles (C++, Python & Java)🚀
Om Shree
Om Shree

Posted on

🚀 LeetCode 3197: Covering All Ones with 3 Rectangles (C++, Python & Java)🚀

LeetCode’s Problem 3197 (Hard) is one of those grid-based problems that force us to think deeply about partitioning, optimization, and efficient computation. Many developers struggle to find an efficient solution because it combines geometric reasoning, partition enumeration, and prefix sums.

Problem Statement (LeetCode 3197)

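We are given a 2D binary grid. We need to find 3 non-overlapping rectangles (axis-aligned, each with non-zero area, allowed to touch) that together cover all the 1s, and return the minimum possible sum of their areas. Constraints: 1 ≤ n, m ≤ 30, and the grid is guaranteed to contain at least three 1s.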

👉 Example 1:

Input: grid = [[1,0,1],
               [1,1,1]]

Output: 5
Enter fullscreen mode Exit fullscreen mode

Explanation:

  • Cover (0,0) and (1,0) → rectangle of area 2
  • Cover (0,2) and (1,2) → rectangle of area 2
  • Cover (1,1) alone → rectangle of area 1

Total area = 2 + 2 + 1 = 5

💡 Key Insights

  1. Bounding Rectangles: To cover all 1s in a region with one rectangle, we can just compute the smallest bounding rectangle of the 1’s.
  2. Recursive Splits: To cover with 2 rectangles, we try partitioning by row or column.
  3. Recursive Partitioning + Prefix (Suffix) Sums:
  • Use row and column suffix sums to check in O(1) whether a row segment or a column segment contains any 1s.
  • Build helper functions:

    • f1(...) → Minimum area with 1 rectangle
    • f2(...) → Minimum area with 2 rectangles
  • Finally combine for 3 rectangles using partitions.

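This hierarchical breakdown (1 → 2 → 3 rectangles) avoids brute force and gives a much faster solution. It is also exhaustive. Any three non-overlapping axis-aligned rectangles can be separated by a single horizontal or vertical cut into a group of one and a group of two. The group of two can then be separated by one more cut. So every optimal layout matches one of the six cut patterns the code enumerates.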


🧩 Step-by-Step Solution Strategy

  1. Precompute Suffix Sums:
  • row_sfx[i][j]: Number of 1s in row i from column j to the end.
  • col_sfx[j][i]: Number of 1s in column j from row i to the bottom.
  2. Function f1:
  • Shrinks the boundaries until they tightly enclose all 1s → returns the area of that bounding box (a region with no 1s collapses to a single cell, so f1 returns 1).
  3. Function f2:
  • Try splitting the region into 2 parts (horizontal or vertical).
  • Minimize the sum of f1 over the two parts (the unsplit f1 value is also kept as a candidate).
  4. Final Combination:
  • For 3 rectangles, try splitting first into (1 rect + 2 rect) or (2 rect + 1 rect) partitions.
  • Iterate over all row and column cuts.

🖥️ C++ Implementation (Efficient Solution)

Here is the full C++ solution, formatted for readability:

class Solution {
public:
    // Suffix counts of 1s, each with a trailing zero sentinel so that
    // (end + 1) is always a valid index:
    //   row_sfx[i][j] = number of 1s in row i at columns j..m-1
    //   col_sfx[j][i] = number of 1s in column j at rows i..n-1
    std::vector<std::vector<int>> row_sfx, col_sfx;

    // Number of 1s in row i between columns l and r (inclusive).
    int get_row_count(int i, int l, int r) {
        const std::vector<int>& suffix = row_sfx[i];
        return suffix[l] - suffix[r + 1];
    }

    // Number of 1s in column j between rows t and b (inclusive).
    int get_col_count(int j, int t, int b) {
        const std::vector<int>& suffix = col_sfx[j];
        return suffix[t] - suffix[b + 1];
    }

    // Minimum area of ONE rectangle covering every 1 in rows [si, ei] x
    // columns [sj, ej], i.e. the area of the tight bounding box. Empty rows
    // are trimmed first, then empty columns (checked only within the trimmed
    // rows). A region without any 1s collapses to a single cell and yields 1.
    int f1(int si, int ei, int sj, int ej) {
        int top = si;
        int bottom = ei;
        int left = sj;
        int right = ej;

        while (top < bottom && get_row_count(top, left, right) == 0) ++top;
        while (bottom > top && get_row_count(bottom, left, right) == 0) --bottom;
        while (left < right && get_col_count(left, top, bottom) == 0) ++left;
        while (right > left && get_col_count(right, top, bottom) == 0) --right;

        const int height = std::max(0, bottom - top + 1);
        const int width = std::max(0, right - left + 1);
        return height * width;
    }

    // Minimum total area of TWO rectangles covering every 1 in the region:
    // tries every single horizontal and vertical cut, keeping the
    // one-rectangle bounding box as the starting candidate.
    int f2(int si, int ei, int sj, int ej) {
        int best = f1(si, ei, sj, ej);

        for (int cut = si; cut < ei; ++cut) {
            const int above = f1(si, cut, sj, ej);
            const int below = f1(cut + 1, ei, sj, ej);
            best = std::min(best, above + below);
        }

        for (int cut = sj; cut < ej; ++cut) {
            const int leftPart = f1(si, ei, sj, cut);
            const int rightPart = f1(si, ei, cut + 1, ej);
            best = std::min(best, leftPart + rightPart);
        }

        return best;
    }

    // Minimum total area of three non-overlapping rectangles covering all 1s.
    // A first horizontal or vertical cut puts one rectangle on one side and
    // two on the other; f2 explores the second cut.
    int minimumSum(std::vector<std::vector<int>>& grid) {
        const int rows = static_cast<int>(grid.size());
        const int cols = static_cast<int>(grid[0].size());
        row_sfx.assign(rows, std::vector<int>(cols + 1, 0));
        col_sfx.assign(cols, std::vector<int>(rows + 1, 0));

        // A single bottom-up, right-to-left sweep fills both tables: the
        // entries to the right (same row) and below (same column) are already
        // final, or are the zero sentinel.
        for (int r = rows - 1; r >= 0; --r) {
            for (int c = cols - 1; c >= 0; --c) {
                row_sfx[r][c] = grid[r][c] + row_sfx[r][c + 1];
                col_sfx[c][r] = grid[r][c] + col_sfx[c][r + 1];
            }
        }

        const int lastRow = rows - 1;
        const int lastCol = cols - 1;
        int best = rows * cols;

        // Horizontal first cut: 1 above + 2 below, or 2 above + 1 below.
        for (int cut = 0; cut < lastRow; ++cut) {
            best = std::min(best, f1(0, cut, 0, lastCol) + f2(cut + 1, lastRow, 0, lastCol));
            best = std::min(best, f2(0, cut, 0, lastCol) + f1(cut + 1, lastRow, 0, lastCol));
        }

        // Vertical first cut: 1 left + 2 right, or 2 left + 1 right.
        for (int cut = 0; cut < lastCol; ++cut) {
            best = std::min(best, f1(0, lastRow, 0, cut) + f2(0, lastRow, cut + 1, lastCol));
            best = std::min(best, f2(0, lastRow, 0, cut) + f1(0, lastRow, cut + 1, lastCol));
        }

        return best;
    }
};
Enter fullscreen mode Exit fullscreen mode

🐍 Python Implementation

class Solution:
    def minimumSum(self, grid):
        """Return the minimum total area of three non-overlapping, axis-aligned
        rectangles (non-zero area, allowed to touch) that cover every 1 in ``grid``.

        Args:
            grid: non-empty 2D list of 0/1 values.

        Returns:
            The minimum sum of the three rectangle areas.
        """
        n, m = len(grid), len(grid[0])

        # row_sfx[i][j]: number of 1s in row i at columns j..m-1.
        # col_sfx[j][i]: number of 1s in column j at rows i..n-1.
        # The trailing 0 in each list is a sentinel so index (end + 1) is valid.
        row_sfx = [[0] * (m + 1) for _ in range(n)]
        col_sfx = [[0] * (n + 1) for _ in range(m)]
        for i in range(n):
            for j in range(m - 1, -1, -1):
                row_sfx[i][j] = grid[i][j] + row_sfx[i][j + 1]
        for j in range(m):
            for i in range(n - 1, -1, -1):
                col_sfx[j][i] = grid[i][j] + col_sfx[j][i + 1]

        def get_row_count(i, l, r):
            """Number of 1s in row i between columns l and r (inclusive)."""
            return row_sfx[i][l] - row_sfx[i][r + 1]

        def get_col_count(j, t, b):
            """Number of 1s in column j between rows t and b (inclusive)."""
            return col_sfx[j][t] - col_sfx[j][b + 1]

        def f1(si, ei, sj, ej):
            """Area of the tight bounding box of the 1s in rows [si, ei] x
            columns [sj, ej]. A region with no 1s collapses to one cell (area 1)."""
            # Each count is a plain int, so compare it directly. Wrapping a
            # single bool in all() raises TypeError (a bool is not iterable).
            while si < ei and get_row_count(si, sj, ej) == 0:
                si += 1
            while ei > si and get_row_count(ei, sj, ej) == 0:
                ei -= 1
            while sj < ej and get_col_count(sj, si, ei) == 0:
                sj += 1
            while ej > sj and get_col_count(ej, si, ei) == 0:
                ej -= 1
            return max(0, ei - si + 1) * max(0, ej - sj + 1)

        def f2(si, ei, sj, ej):
            """Minimum total area of two rectangles covering the 1s in the region,
            trying every single horizontal and vertical cut."""
            res = f1(si, ei, sj, ej)
            for i in range(si, ei):
                res = min(res, f1(si, i, sj, ej) + f1(i + 1, ei, sj, ej))
            for j in range(sj, ej):
                res = min(res, f1(si, ei, sj, j) + f1(si, ei, j + 1, ej))
            return res

        # First cut (horizontal, then vertical): one rectangle on one side and
        # two on the other, in both orders.
        ans = n * m
        for i in range(n - 1):
            ans = min(ans, f1(0, i, 0, m - 1) + f2(i + 1, n - 1, 0, m - 1))
            ans = min(ans, f2(0, i, 0, m - 1) + f1(i + 1, n - 1, 0, m - 1))
        for j in range(m - 1):
            ans = min(ans, f1(0, n - 1, 0, j) + f2(0, n - 1, j + 1, m - 1))
            ans = min(ans, f2(0, n - 1, 0, j) + f1(0, n - 1, j + 1, m - 1))
        return ans
Enter fullscreen mode Exit fullscreen mode

☕ Java Implementation

class Solution {
    // Suffix counts of 1s, each with a trailing zero sentinel:
    //   row_sfx[i][j] = number of 1s in row i at columns j..m-1
    //   col_sfx[j][i] = number of 1s in column j at rows i..n-1
    int[][] row_sfx, col_sfx;

    /** Number of 1s in row {@code i} between columns {@code l} and {@code r}, inclusive. */
    int getRowCount(int i, int l, int r) {
        int[] suffix = row_sfx[i];
        return suffix[l] - suffix[r + 1];
    }

    /** Number of 1s in column {@code j} between rows {@code t} and {@code b}, inclusive. */
    int getColCount(int j, int t, int b) {
        int[] suffix = col_sfx[j];
        return suffix[t] - suffix[b + 1];
    }

    /**
     * Area of the tight bounding box of the 1s in rows [si, ei] x columns [sj, ej].
     * Empty rows are trimmed first, then empty columns within the trimmed rows.
     * A region with no 1s collapses to a single cell and yields 1.
     */
    int f1(int si, int ei, int sj, int ej) {
        int top = si;
        int bottom = ei;
        int left = sj;
        int right = ej;

        while (top < bottom && getRowCount(top, left, right) == 0) top++;
        while (bottom > top && getRowCount(bottom, left, right) == 0) bottom--;
        while (left < right && getColCount(left, top, bottom) == 0) left++;
        while (right > left && getColCount(right, top, bottom) == 0) right--;

        int height = Math.max(0, bottom - top + 1);
        int width = Math.max(0, right - left + 1);
        return height * width;
    }

    /**
     * Minimum total area of two rectangles covering the 1s in the region, trying
     * every single horizontal and vertical cut (the unsplit box is the starting candidate).
     */
    int f2(int si, int ei, int sj, int ej) {
        int best = f1(si, ei, sj, ej);
        for (int cut = si; cut < ei; cut++) {
            int above = f1(si, cut, sj, ej);
            int below = f1(cut + 1, ei, sj, ej);
            best = Math.min(best, above + below);
        }
        for (int cut = sj; cut < ej; cut++) {
            int leftPart = f1(si, ei, sj, cut);
            int rightPart = f1(si, ei, cut + 1, ej);
            best = Math.min(best, leftPart + rightPart);
        }
        return best;
    }

    /** Minimum total area of three non-overlapping rectangles covering every 1 in {@code grid}. */
    public int minimumSum(int[][] grid) {
        int rows = grid.length;
        int cols = grid[0].length;
        row_sfx = new int[rows][cols + 1];
        col_sfx = new int[cols][rows + 1];

        // One bottom-up, right-to-left sweep fills both tables; the entries to the
        // right (same row) and below (same column) are already final or the sentinel.
        for (int r = rows - 1; r >= 0; r--) {
            for (int c = cols - 1; c >= 0; c--) {
                row_sfx[r][c] = grid[r][c] + row_sfx[r][c + 1];
                col_sfx[c][r] = grid[r][c] + col_sfx[c][r + 1];
            }
        }

        int lastRow = rows - 1;
        int lastCol = cols - 1;
        int best = rows * cols;

        // Horizontal first cut: 1 above + 2 below, or 2 above + 1 below.
        for (int cut = 0; cut < lastRow; cut++) {
            best = Math.min(best, f1(0, cut, 0, lastCol) + f2(cut + 1, lastRow, 0, lastCol));
            best = Math.min(best, f2(0, cut, 0, lastCol) + f1(cut + 1, lastRow, 0, lastCol));
        }
        // Vertical first cut: 1 left + 2 right, or 2 left + 1 right.
        for (int cut = 0; cut < lastCol; cut++) {
            best = Math.min(best, f1(0, lastRow, 0, cut) + f2(0, lastRow, cut + 1, lastCol));
            best = Math.min(best, f2(0, lastRow, 0, cut) + f1(0, lastRow, cut + 1, lastCol));
        }
        return best;
    }
}
Enter fullscreen mode Exit fullscreen mode

⚡ Complexity Analysis

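  • Precomputation: O(n·m)
  • Each f1: O(n + m) in the worst case (the boundary-shrinking loops), with each step O(1) thanks to the suffix sums
  • Each f2: O(n + m) calls to f1 → O((n + m)²)
  • Final loops: O(n + m) partitions, each calling f1 and f2 → O((n + m)³)
  • ✅ Overall complexity: O(n·m + (n + m)³) time and O(n·m) extra space, easily fast enough for n, m ≤ 30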

🎯 Takeaways

  • This problem is a great example of combining geometry, partition enumeration, and prefix sums.
  • It teaches us how to reduce a seemingly exponential partition problem into manageable recursive functions (f1, f2, and final combo).
  • Having multiple language implementations makes the logic accessible to more developers.

Top comments (0)