I've written a Python script that generates a square spiral matrix whose number of cells is the even perfect square closest to a given number.
I'm new to code reviews and am interested in ways to improve. Please suggest improvements to the code wherever you see fit, particularly with regard to:
- Algorithm: Is there a faster or more elegant way to generate the matrix?
- Style: I've tried to adhere to PEP 8.
 
"""
Prints the square spiral matrix whose size is the even perfect square nearest to an input number
"""
from math import sqrt
def is_even_square(num):
    """Return True if *num* is an even perfect square (0, 4, 16, 36, ...).

    Negative numbers are never squares, so they return False instead of
    letting ``sqrt`` raise ValueError.  The final check uses exact integer
    arithmetic, so large even non-squares (e.g. ``2**60 + 2``) are not
    misreported because of floating-point rounding in ``sqrt``.
    """
    if num < 0 or num % 2:
        return False
    # sqrt only gives a candidate root; squaring it back confirms exactly.
    root = int(round(sqrt(num)))
    return root * root == num
def find_nearest_square(num):
    """Return the even perfect square nearest to *num*.

    Only squares of even positive roots (4, 16, 36, ...) are candidates.
    When *num* lies exactly halfway between two candidates, the smaller one
    is returned.  Inputs below 4 (including 1, 2, zero and negatives) map
    to 4, the smallest candidate, rather than returning None.
    """
    target = max(num, 0)
    root = int(sqrt(target))
    # Correct any float rounding so that root == floor(sqrt(target)).
    while root * root > target:
        root -= 1
    while (root + 1) * (root + 1) <= target:
        root += 1
    # Largest even root whose square does not exceed num, but at least 2.
    lower_root = max(root - root % 2, 2)
    lower = lower_root ** 2
    upper = (lower_root + 2) ** 2
    # '<=' breaks ties toward the smaller square.
    return lower if num - lower <= upper - num else upper
def find_lower_squares(num):
    """Return the even perfect squares from *num* down to 4, largest first.

    *num* itself is included when it is an even perfect square; 0 is never
    included.  Inputs below 4 yield an empty list.

    >>> find_lower_squares(36)
    [36, 16, 4]
    """
    if num < 4:
        return []
    root = int(sqrt(num))
    # Correct any float rounding so that root == floor(sqrt(num)).
    while root * root > num:
        root -= 1
    while (root + 1) * (root + 1) <= num:
        root += 1
    top = root - root % 2  # largest even root whose square is <= num
    # Generate the squares directly instead of testing every integer.
    return [r * r for r in range(top, 0, -2)]
def nth_row(num, n):
    """Return row *n* (0-based, top to bottom) of the square spiral for *num*.

    *num* must be an even perfect square (4, 16, 36, ...).  The top row
    counts down from *num*; the bottom row counts up from left to right.
    Each middle row is the matching row of the next-smaller spiral,
    ``(edge - 2) ** 2``, framed by one value on the left and one on the
    right.

    >>> nth_row(16, 1)
    [5, 4, 3, 12]
    """
    edge = int(sqrt(num))
    if n == 0:
        return list(range(num, num - edge, -1))
    if n >= edge - 1:
        return list(range(num - 3*edge + 3, num - 2*edge + 3))
    # The inner spiral's size is known directly, so there is no need to
    # scan for lower squares on every recursive call.
    inner = (edge - 2) ** 2
    # For an even-square num these border formulas hold for every middle
    # row (the left column counts down from inner + 1, the right column
    # counts up toward num - edge + 1).
    left = inner + n
    right = num - edge - n + 1
    return [left] + nth_row(inner, n - 1) + [right]
def generate_square_spiral(num):
    """Return the square spiral matrix for *num* as a list of rows.

    *num* is expected to be an even perfect square.  The result has
    ``int(sqrt(num))`` rows, each produced by ``nth_row``.
    """
    edge = int(sqrt(num))
    # Build each row directly; a placeholder grid would be overwritten anyway.
    return [nth_row(num, row) for row in range(edge)]
def main():
    """Prompt for a positive integer and print the nearest square spiral.

    Re-prompts (printing 'Invalid Number') until a positive integer is
    entered.  Each cell is right-aligned to the width of the widest value
    (at least three characters) and followed by a space.  Rows are
    separated by a blank line.
    """
    while True:
        try:
            num = int(input('Input number: '))
        except ValueError:
            num = 0
        # Zero and negatives cannot produce a spiral; ask again.
        if num > 0:
            break
        print('Invalid Number')
    nearest_square = find_nearest_square(num)
    matrix = generate_square_spiral(nearest_square)
    # Size cells to the widest value so columns stay aligned past 3 digits.
    width = max([3] + [len(str(value)) for row in matrix for value in row])
    for row in matrix:
        print(''.join(str(value).rjust(width) + ' ' for value in row),
              end='\n\n')
if __name__ == '__main__':
    # Start the interactive prompt only when run as a script, not on import.
    main()