Dhruv Jain

Neural Network in Rust on MNIST dataset from scratch

Implement and train a neural network from scratch on the MNIST dataset in Rust, without high-level libraries like TensorFlow or PyTorch.

You can find the code at: https://github.com/dhruvkjain/mnist-nn-rs

It demonstrates:

  • Manual forward and backward propagation
  • Use of ReLU and softmax activation functions
  • One-hot encoding
  • Gradient descent for training
  • Accuracy evaluation
  • Model parameter export to CSV using polars

🔧 Dependencies

  • ndarray (stores the 2D arrays of data)
  • ndarray-rand (generates the initial random weights (w) and biases (b))
  • polars (reads and writes CSV data) — a Cargo.toml sketch follows this list
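
For reference, a minimal Cargo.toml sketch. The version numbers and polars feature flags here are assumptions, not taken from the repo; match them to whatever versions you actually build against:

[dependencies]
# ndarray-rand 0.15 pairs with ndarray 0.16
ndarray = "0.16"
ndarray-rand = "0.15"
# "lazy" + "csv" for LazyCsvReader, "ndarray" for to_ndarray, "streaming" for with_streaming
polars = { version = "0.44", features = ["lazy", "csv", "ndarray", "streaming"] }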

🧠 Model Overview

  • Architecture: input layer, one hidden layer, output layer (shape check below)
  • Input: 784-dimensional MNIST images (28×28 pixels, flattened)
  • Hidden layer: 10 neurons with ReLU activation
  • Output layer: 10 neurons with softmax activation for multi-class classification
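
To keep the matrix algebra straight, here is a quick shape check (a minimal sketch; m stands for the number of training examples, 3 in this toy run):

use ndarray::Array2;

fn main() {
    let x = Array2::<f32>::zeros((784, 3));   // input: one column per image
    let w1 = Array2::<f32>::zeros((10, 784)); // hidden layer weights
    let w2 = Array2::<f32>::zeros((10, 10));  // output layer weights
    assert_eq!(w1.dot(&x).dim(), (10, 3));          // hidden pre-activations: (10, m)
    assert_eq!(w2.dot(&w1.dot(&x)).dim(), (10, 3)); // output scores: (10, m)
}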

📂 Structure

  • main.rs: Training loop and evaluation
  • lib.rs: Core model logic — forward, backward, update, softmax, etc.
  • mnistdata/: Contains input dataset

📦 Dataset

Make sure the MNIST dataset is placed in mnistdata/.

Prerequisite: the dataset as CSV (e.g. mnist_train.csv) with a header row, a label column, and 784 pixel columns, one row per image.

Initialization of the data using the polars crate:

pub fn load_training_data() -> Result<(Array2<f32>, Array2<f32>), Box<dyn Error>> {
    let q = LazyCsvReader::new("./mnistdata/mnist_train.csv")
        .with_has_header(true)
        .finish()?;

    // The "label" column holds the digit; the remaining 784 columns hold the pixels.
    let training_labels = q
        .clone()
        .with_streaming(true)
        .select([col("label")])
        .collect()?;

    let training_data = q
        .clone()
        .with_streaming(true)
        .drop([col("label")])
        .collect()?;

    let mut training_data_ndarray = training_data
        .to_ndarray::<Float32Type>(IndexOrder::Fortran)
        .unwrap();
    let mut training_labels_ndarray = training_labels
        .to_ndarray::<Float32Type>(IndexOrder::Fortran)
        .unwrap();

    // Transpose so each column is one example, and scale pixels from [0, 255] to [0, 1].
    training_data_ndarray = training_data_ndarray.reversed_axes() / 255.0;
    training_labels_ndarray = training_labels_ndarray.reversed_axes();

    let data_dimensions: &[usize] = training_data_ndarray.shape();
    let labels_dimensions: &[usize] = training_labels_ndarray.shape();

    println!("DATA: {}, {}", data_dimensions[0], data_dimensions[1]);
    println!("LABELS: {}, {}", labels_dimensions[0], labels_dimensions[1]);
    Ok((training_data_ndarray, training_labels_ndarray))
}

What is a neural network, and how does it recognize digits? In short: each 28×28 image is flattened into 784 pixel values, pushed through weighted layers, and turned into 10 output scores, one per digit; the index of the largest score is the prediction.

Our approach: a 784 → 10 → 10 network, ReLU on the hidden layer, softmax on the output layer, trained with full-batch gradient descent.

Declaring the initial weights and biases using the ndarray_rand crate:

pub fn init_params() -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
    // All weights and biases start as uniform random values in [-0.5, 0.5).
    let w1 = Array2::random((10, 784), Uniform::new(-0.5, 0.5)); // hidden layer weights
    let b1 = Array2::random((10, 1), Uniform::new(-0.5, 0.5));   // hidden layer biases
    let w2 = Array2::random((10, 10), Uniform::new(-0.5, 0.5));  // output layer weights
    let b2 = Array2::random((10, 1), Uniform::new(-0.5, 0.5));   // output layer biases

    (w1, b1, w2, b2)
}

Forward Propagation:

Z1 = W1·X + b1, A1 = ReLU(Z1); then Z2 = W2·A1 + b2, A2 = softmax(Z2)

// ReLU: clamp negative pre-activations to zero, in place.
pub fn relu(z: &mut Array2<f32>) {
    z.mapv_inplace(|x| x.max(0.0))
}

// Softmax: normalize each column into a probability distribution.
pub fn softmax(z: &mut Array2<f32>) {
    for mut col in z.axis_iter_mut(Axis(1)) {
        // Subtract the column max for numerical stability
        let max = col.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        col.mapv_inplace(|x| (x - max).exp());

        let sum = col.sum();
        col.mapv_inplace(|x| x / sum);
    }
}

pub fn forward_propagation(
    w1: &mut Array2<f32>,
    b1: &mut Array2<f32>,
    w2: &mut Array2<f32>,
    b2: &mut Array2<f32>,
    x: &mut Array2<f32>,
) -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
    // Hidden layer: Z1 = W1·X + b1, A1 = ReLU(Z1)
    let z1 = w1.dot(x) + &*b1;
    let mut a1 = z1.clone();
    relu(&mut a1);

    // Output layer: Z2 = W2·A1 + b2, A2 = softmax(Z2)
    let z2 = w2.dot(&a1) + &*b2;
    let mut a2 = z2.clone();
    softmax(&mut a2);

    (z1, a1, z2, a2)
}
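A quick sanity check for the forward pass (a minimal sketch, assuming init_params and forward_propagation from lib.rs are in scope): every softmax output column should sum to approximately 1.

use ndarray::{Array2, Axis};

fn main() {
    let (mut w1, mut b1, mut w2, mut b2) = init_params();
    let mut x = Array2::<f32>::zeros((784, 4)); // 4 blank "images"
    let (_z1, _a1, _z2, a2) = forward_propagation(&mut w1, &mut b1, &mut w2, &mut b2, &mut x);
    for col in a2.axis_iter(Axis(1)) {
        assert!((col.sum() - 1.0).abs() < 1e-5); // each column is a probability distribution
    }
}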

Back Propagation:

With m examples and one-hot labels Y: dZ2 = A2 − Y, dW2 = (1/m)·dZ2·A1ᵀ, db2 = (1/m)·row-sum(dZ2); dZ1 = W2ᵀ·dZ2 ⊙ ReLU′(Z1), dW1 = (1/m)·dZ1·Xᵀ, db1 = (1/m)·row-sum(dZ1)

// Turn a (1 x m) row of labels into a (num_classes x m) one-hot matrix.
pub fn one_hot_encoded(y: &mut Array2<f32>, num_classes: usize) -> Array2<f32> {
    let ydash = y.flatten();
    let label_dimensions: &[usize] = ydash.shape();
    let mut one_hot_y = Array2::<f32>::zeros((label_dimensions[0], num_classes));

    for (row, &label) in ydash.iter().enumerate() {
        let class_index = label as usize;
        one_hot_y[(row, class_index)] = 1.0;
    }
    one_hot_y.reversed_axes()
}

// ReLU'(z): 1 where z > 0, else 0.
pub fn deriv_relu(z: &mut Array2<f32>) {
    z.mapv_inplace(|x| if x > 0.0 { 1.0 } else { 0.0 });
}

pub fn backward_propagation(
    z1: &mut Array2<f32>,
    a1: &mut Array2<f32>,
    a2: &mut Array2<f32>,
    w2: &mut Array2<f32>,
    x: &mut Array2<f32>,
    y: &mut Array2<f32>,
) -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
    let m = y.len() as f32; // number of training examples
    let a1t = a1.view().reversed_axes();
    let w2t = w2.view().reversed_axes();
    let xt = x.view().reversed_axes();
    let one_hot_y = one_hot_encoded(y, 10);

    // Output layer gradients: dZ2 = A2 - Y
    let dz2 = &*a2 - &one_hot_y;
    let dw2 = (1.0 / m) * (dz2.dot(&a1t));
    let db2 = dz2.sum_axis(Axis(1)).insert_axis(Axis(1)) * (1.0 / m);

    // Hidden layer gradients: dZ1 = W2ᵀ·dZ2 ⊙ ReLU'(Z1)
    let mut z1_deriv = z1.clone();
    deriv_relu(&mut z1_deriv);
    let dz1 = w2t.dot(&dz2) * z1_deriv;
    let dw1 = (1.0 / m) * (dz1.dot(&xt));
    let db1 = dz1.sum_axis(Axis(1)).insert_axis(Axis(1)) * (1.0 / m);

    (dw1, db1, dw2, db2)
}
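To see what one_hot_encoded produces (a minimal sketch, assuming the function above is in scope): labels 5, 0, 4 become a 10×3 matrix with a single 1 per column.

use ndarray::{array, Array2};

fn main() {
    let mut y: Array2<f32> = array![[5.0, 0.0, 4.0]]; // a (1 x 3) row of labels
    let one_hot = one_hot_encoded(&mut y, 10);
    assert_eq!(one_hot.dim(), (10, 3));
    assert_eq!(one_hot[(5, 0)], 1.0); // first example is the digit 5
    assert_eq!(one_hot[(0, 1)], 1.0); // second example is the digit 0
}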

Update the weights and biases:

W := W − α·dW and b := b − α·db for each layer, with learning rate α (alpha).

pub fn update_params(
    w1: &mut Array2<f32>,
    b1: &mut Array2<f32>,
    w2: &mut Array2<f32>,
    b2: &mut Array2<f32>,
    dw1: &Array2<f32>,
    db1: &Array2<f32>,
    dw2: &Array2<f32>,
    db2: &Array2<f32>,
    alpha: f32,
) {
    // One gradient-descent step: move each parameter against its gradient.
    *w1 -= &(alpha * dw1);
    *b1 -= &(alpha * db1);
    *w2 -= &(alpha * dw2);
    *b2 -= &(alpha * db2);
}

A function to compute the accuracy of the model:

pub fn get_accuracy(predictions: &Array2<f32>, labels: &Array2<f32>) -> f32 {
    // Predicted class = argmax of each output column.
    let pred_classes: Array1<usize> = predictions
        .axis_iter(Axis(1))
        .map(|col| {
            col.iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0
        })
        .collect();

    let true_classes: Array1<usize> = labels.iter().map(|x| *x as usize).collect();

    let correct = pred_classes
        .iter()
        .zip(true_classes.iter())
        .filter(|(pred, truth)| pred == truth)
        .count();

    correct as f32 / labels.len() as f32
}
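A tiny usage example (a minimal sketch, assuming get_accuracy is in scope): the argmax of the first column matches its label, the second does not, so the accuracy is 0.5.

use ndarray::{array, Array2};

fn main() {
    let predictions: Array2<f32> = array![
        [0.1, 0.8],
        [0.7, 0.1],
        [0.2, 0.1],
    ]; // columns argmax to classes 1 and 0
    let labels: Array2<f32> = array![[1.0, 2.0]];
    let acc = get_accuracy(&predictions, &labels);
    assert!((acc - 0.5).abs() < 1e-6); // one of two predictions is correct
}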

Finally, using all the functions in main:

fn main() -> Result<(), Box<dyn Error>> {
    let (mut training_data, mut training_label) = load_training_data()?;
    let (mut w1, mut b1, mut w2, mut b2) = init_params();

    let iterations = 501;
    let alpha = 0.1; // learning rate
    println!("{}", training_label);

    for i in 0..iterations {
        // Full-batch gradient descent: forward pass, backward pass, parameter update.
        let (mut z1, mut a1, mut z2, mut a2) = forward_propagation(&mut w1, &mut b1, &mut w2, &mut b2, &mut training_data);
        let (dw1, db1, dw2, db2) = backward_propagation(&mut z1, &mut a1, &mut a2, &mut w2, &mut training_data, &mut training_label);
        update_params(&mut w1, &mut b1, &mut w2, &mut b2, &dw1, &db1, &dw2, &db2, alpha);
        if i % 50 == 0 {
            println!("Iteration: {}", i);
            let acc = get_accuracy(&a2, &training_label);
            println!("Accuracy: {:.2}%", acc * 100.0);
        }
    }

    Ok(())
}
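The repo also exports the trained parameters to CSV with polars; that code is not shown here. A minimal sketch of one way to do it (the helper name is hypothetical, and the exact polars API differs between versions):

use std::error::Error;
use std::fs::File;
use ndarray::Array2;
use polars::prelude::*;

// Hypothetical helper: write a weight matrix to CSV, one DataFrame column per matrix column.
fn save_matrix(path: &str, m: &Array2<f32>) -> Result<(), Box<dyn Error>> {
    let columns: Vec<Column> = m
        .columns()
        .into_iter()
        .enumerate()
        .map(|(i, col)| Column::new(format!("c{i}").into(), col.to_vec()))
        .collect();
    let mut df = DataFrame::new(columns)?;
    let mut file = File::create(path)?;
    CsvWriter::new(&mut file).include_header(true).finish(&mut df)?;
    Ok(())
}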

Results (training-set accuracy) for 200 iterations and learning rate = 0.1:

DATA: 784, 60000
LABELS: 1, 60000
[[5, 0, 4, 1, 9, ..., 8, 3, 5, 6, 8]]

Iteration: 0
Accuracy: 10.86%

Iteration: 50
Accuracy: 56.97%

Iteration: 100
Accuracy: 69.91%

Iteration: 150
Accuracy: 75.45%

Iteration: 200
Accuracy: 78.56%

Results (training-set accuracy) for 500 iterations and learning rate = 0.1:

DATA: 784, 60000
LABELS: 1, 60000
[[5, 0, 4, 1, 9, ..., 8, 3, 5, 6, 8]]

Iteration: 0
Accuracy: 12.46%

Iteration: 50
Accuracy: 47.05%

Iteration: 100
Accuracy: 61.53%

Iteration: 150
Accuracy: 69.01%

Iteration: 200
Accuracy: 73.28%

Iteration: 250
Accuracy: 76.48%

Iteration: 300
Accuracy: 78.93%

Iteration: 350
Accuracy: 80.81%

Iteration: 400
Accuracy: 82.38%

Iteration: 450
Accuracy: 83.53%

Iteration: 500
Accuracy: 84.48%

Medium Blog: https://medium.com/@dkjain2005co/neural-network-in-rust-on-mnist-dataset-from-scratch-f42971eaead3
