Statement:
Create a script that, given an integer \$n\$, creates a square matrix of dimensions \$n \times n\$ containing the numbers from \$1\$ to \$n^2\$, arranged in a snail (clockwise spiral) pattern.
Example:
1 2 3
8 9 4
7 6 5
I'm mainly interested in feedback on code quality and maintainability, as well as any general advice.
My implementation:
import collections.abc
import itertools
import math
type Matrix = list[list[int]]
def update_layer_edges(
    c: collections.abc.Iterator, matrix: Matrix, layer_start: int, layer_end: int
) -> None:
    """Define the current layer tiles."""
    for y, x in itertools.chain(
        # top edge, where x increases and y is fixed to minimum
        ((layer_start, x) for x in range(layer_start, layer_end)),
        # right edge where y increases and x is fixed to maximum
        ((y, layer_end - 1) for y in range(layer_start + 1, layer_end)),
        # bot edge where x decreases and y is fixed to maximum
        ((layer_end - 1, x) for x in range(layer_end - 2, layer_start - 1, -1)),
        # left edge where y decreases and x is fixed to minimum
        ((y, layer_start) for y in range(layer_end - 2, layer_start, -1)),
    ):
        matrix[y][x] = next(c)
def square(size: int) -> list[list[int]]:
    """Return a ``size`` x ``size`` matrix of 1..size**2 laid out as a snail.

    Numbers start in the top-left corner and spiral clockwise inwards, e.g.
    ``square(3) == [[1, 2, 3], [8, 9, 4], [7, 6, 5]]``. A ``size`` of zero or
    less yields an empty list.
    """
    matrix = [[0] * size for _ in range(size)]
    counter = itertools.count(1)
    # An n x n matrix has ceil(n / 2) concentric rings; (size + 1) // 2 is the
    # exact integer ceiling (no float round-trip). Ring k spans rows/columns
    # k .. size - k - 1, so its bounds follow directly from the index.
    for layer in range((size + 1) // 2):
        update_layer_edges(counter, matrix, layer, size - layer)
    return matrix
Test the implementation:
def print_matrix(matrix: list[list[int]]) -> None:
    """Print ``matrix`` row by row with every number right-aligned.

    The column width is that of the widest number in the matrix, so the grid
    stays aligned for any integer matrix (not only snail squares, and
    including negative numbers). An empty matrix prints nothing.
    """
    width = max((len(str(num)) for row in matrix for num in row), default=1)
    for row in matrix:
        print(" ".join(f"{num:>{width}}" for num in row))


if __name__ == "__main__":
    print_matrix(square(10))

    # Expected snail matrices for small sizes. Mismatches raise explicitly
    # instead of using `assert`, which is stripped when running `python -O`.
    expected = {
        0: [],
        1: [[1]],
        2: [[1, 2], [4, 3]],
        3: [[1, 2, 3], [8, 9, 4], [7, 6, 5]],
        4: [
            [1, 2, 3, 4],
            [12, 13, 14, 5],
            [11, 16, 15, 6],
            [10, 9, 8, 7],
        ],
        5: [
            [1, 2, 3, 4, 5],
            [16, 17, 18, 19, 6],
            [15, 24, 25, 20, 7],
            [14, 23, 22, 21, 8],
            [13, 12, 11, 10, 9],
        ],
    }
    for size, want in expected.items():
        got = square(size)
        if got != want:
            raise AssertionError(f"square({size}) returned {got}, expected {want}")
Visual representation of the matrix, with the code that generates the picture (I use matplotlib==3.8.4):
def show_snail(n: int, *, cmap: str = "viridis", size_inches: float = 1.0) -> None:
    """Display the ``n`` x ``n`` snail matrix as a borderless heatmap.

    Args:
        n: Side length passed to ``square``.
        cmap: Name of a matplotlib colormap (e.g. "viridis", "hot", "cool").
        size_inches: Width and height of the square figure, in inches.
    """
    # Imported inside the function so the module can be imported (and
    # ``square`` used) without matplotlib installed.
    import matplotlib.pyplot as plt

    matrix = square(n)
    # Size the figure once, at its final size, instead of creating it at one
    # size and resizing it afterwards.
    _, ax = plt.subplots(figsize=(size_inches, size_inches))
    ax.imshow(matrix, cmap=cmap)
    ax.set_position([0, 0, 1, 1])  # make the axes fill the whole figure
    ax.set_axis_off()  # hide ticks, labels and frame
    plt.show()
if __name__ == "__main__":
    # Guarded so that importing this module does not open a plot window.
    show_snail(5)
    
