I have implemented a Python class, using Matplotlib, that displays a grid of subplots together with navigation buttons for moving between successive subsets of the data. Each subplot shows a contour plot with its own colorbar.
The code works, but I am looking for advice on making it faster and more efficient. In particular, I want to improve how the subplots are updated and cleared when navigating between subsets of the data. I am open to suggestions, and any insights or alternative approaches that achieve the same functionality would be greatly appreciated.
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import numpy as np
class MultipleFigures:
    """Page through a stack of 2-D arrays, drawing one contour plot per subplot.

    The figure holds an ``nrows`` x ``ncols`` grid of axes.  "Next" and
    "Previous" buttons step through ``data`` one page (``nrows * ncols``
    items) at a time.  Each item is drawn with ``contourf`` plus its own
    colorbar.  Axes left without data on a short final page are hidden.

    Parameters
    ----------
    data : sequence of 2-D array-likes
        Must support ``len()`` and slicing, e.g. a 3-D NumPy array whose
        first axis indexes the individual plots.
    nrows, ncols : int, optional
        Shape of the subplot grid (default 3 x 3, i.e. 9 plots per page).
    figsize : tuple of float, optional
        Figure size in inches, forwarded to ``plt.subplots``.

    Notes
    -----
    Keep a reference to the instance for as long as the window is open.
    The button callbacks are bound methods of this object.
    """

    def __init__(self, data, nrows=3, ncols=3, figsize=(15, 7)):
        self.data = data
        self.page_size = nrows * ncols
        # Half-open [start, stop) bounds of the slice currently on screen.
        self.indices = [0, self.page_size]
        # squeeze=False always yields a 2-D array of axes (even for a 1x1
        # grid), so ravel() works for every grid shape.
        self.fig, axes = plt.subplots(
            nrows, ncols, figsize=figsize, squeeze=False
        )
        self.ax = axes.ravel()
        # Operate on the owned figure rather than pyplot's "current figure",
        # so the layout and buttons cannot end up on some other window.
        self.fig.subplots_adjust(wspace=0.5, hspace=0.5)
        self.cb = []  # colorbars currently shown, parallel to self.temp
        self.plot_initial_data()
        nextax = self.fig.add_axes([0.8, 0.02, 0.1, 0.04])
        prevax = self.fig.add_axes([0.1, 0.02, 0.1, 0.04])
        self.button_next = Button(nextax, 'Next')
        self.button_prev = Button(prevax, 'Previous')
        self.button_next.on_clicked(self.next)
        self.button_prev.on_clicked(self.previous)

    def plot_initial_data(self):
        """Draw the first page and hide any axes it leaves empty."""
        self.temp = self.data[slice(*self.indices)]
        for i, frame in enumerate(self.temp):
            _, colorbar = self.contourplot(frame, ax=self.ax[i])
            self.cb.append(colorbar)
        # Keep exactly as many axes as this page has items.
        self.clear_unused_plots(len(self.temp))

    def next(self, event):
        """Button callback: advance one page unless already on the last one."""
        if self.indices[1] >= len(self.data):
            return
        self._show_page(self.indices[0] + self.page_size)

    def previous(self, event):
        """Button callback: go back one page unless already on the first."""
        if self.indices[0] <= 0:
            return
        self._show_page(self.indices[0] - self.page_size)

    def _show_page(self, start):
        """Replace the on-screen page with the one starting at ``data[start]``."""
        self._clear_previous_axes()
        self.indices = [start, start + self.page_size]
        self.temp = self.data[slice(*self.indices)]
        self.update_plots(self.temp)
        # A short final page leaves trailing axes empty; hide them.
        self.clear_unused_plots(len(self.temp))

    def update_plots(self, temp_df):
        """Draw each item of ``temp_df`` into consecutive axes and redraw.

        ``temp_df`` must not hold more items than there are axes.
        """
        for i, frame in enumerate(temp_df):
            self.ax[i].set_visible(True)  # may have been hidden on a short page
            _, colorbar = self.contourplot(frame, ax=self.ax[i])
            self.cb.append(colorbar)
        # Schedule a redraw instead of blocking inside the click callback.
        self.fig.canvas.draw_idle()

    def clear_unused_plots(self, num_plots_to_clear):
        """Clear and hide every axes from index ``num_plots_to_clear`` on.

        Despite the parameter name, the argument is the number of axes to
        *keep*: axes ``[0, num_plots_to_clear)`` are left untouched.
        """
        for axis in self.ax[num_plots_to_clear:]:
            axis.clear()
            axis.set_visible(False)

    def _clear_previous_axes(self):
        """Remove the current page's colorbars and clear their axes."""
        # Each colorbar is removed before its parent axes is cleared, while
        # the contour set it was built from is still attached to that axes.
        for colorbar, axis in zip(self.cb, self.ax):
            colorbar.remove()
            axis.clear()
        self.cb.clear()

    @staticmethod
    def contourplot(data, ax):
        """Draw a filled contour of ``data`` on ``ax`` with a colorbar beside it.

        Returns
        -------
        tuple
            ``(ax, colorbar)``.
        """
        contour = ax.contourf(data)
        # Use the axes' own figure rather than pyplot's current figure.
        cb = ax.figure.colorbar(contour, ax=ax)
        return ax, cb
if __name__ == "__main__":
    # The original "%matplotlib qt" line is IPython-only syntax and is a
    # SyntaxError in a plain .py file.  Under IPython/Jupyter this call is
    # its exact equivalent (selects an interactive Qt window, needed for the
    # buttons).  In plain Python, get_ipython is undefined and the default
    # backend is used.
    try:
        get_ipython().run_line_magic("matplotlib", "qt")  # noqa: F821
    except NameError:
        pass  # not running under IPython; nothing to configure

    # 20 random 10x10 fields -> 3 pages of a 3x3 grid (9 + 9 + 2 plots).
    data = np.random.default_rng(0).random((20, 10, 10))
    # Keep a reference: the buttons' callbacks are methods of this object.
    fl_ = MultipleFigures(data)
    plt.show()
