Implement and train a neural network from scratch on the MNIST dataset in Rust, without using high-level libraries like TensorFlow or PyTorch.
You can find the code at: https://github.com/dhruvkjain/mnist-nn-rs
It demonstrates:
- Manual forward and backward propagation
- Use of ReLU and softmax activation functions
- One-hot encoding
- Gradient descent for training
- Accuracy evaluation
- Model parameter export to CSV using polars
🔧 Dependencies
- ndarray (stores the 2D arrays of data)
- ndarray-rand (generates the initial random weights (w) and biases (b))
- polars (to read/write CSV data)
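A plausible Cargo.toml sketch for these dependencies (the version numbers and polars feature flags here are assumptions; check the repo's Cargo.toml for the exact, mutually compatible set):

[dependencies]
ndarray = "0.16"
ndarray-rand = "0.15"
polars = { version = "0.41", features = ["lazy", "csv", "ndarray", "streaming"] }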
🧠 Model Overview
- One input layer, one hidden layer, one output layer
- Input: 784-dimensional MNIST images
- Hidden layer: 10 neurons with ReLU as activation function
- Output layer: 10 neurons with softmax as activation function for multi-class classification
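In matrix form, with all m training images stored as the columns of X, the two layers compute the following (the standard two-layer formulation, which the code below implements):

$$Z^{[1]} = W^{[1]}X + b^{[1]}, \qquad A^{[1]} = \mathrm{ReLU}(Z^{[1]})$$

$$Z^{[2]} = W^{[2]}A^{[1]} + b^{[2]}, \qquad A^{[2]} = \mathrm{softmax}(Z^{[2]})$$

Here $W^{[1]}$ is 10×784, $W^{[2]}$ is 10×10, and each bias is a 10×1 column broadcast across all m samples.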
📂 Structure
- main.rs: Training loop and evaluation
- lib.rs: Core model logic (forward, backward, update, softmax, etc.)
- mnistdata/: Contains the input dataset
📦 Dataset
Make sure the MNIST CSV files are placed in mnistdata/; the loader below reads ./mnistdata/mnist_train.csv.
Prerequisite:
Initialization of the data using the polars crate:
// Imports used by the lib.rs snippets shown in this post:
use std::error::Error;
use ndarray::{Array1, Array2, Axis};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use polars::prelude::*;

pub fn load_training_data() -> Result<(Array2<f32>, Array2<f32>), Box<dyn Error>> {
    let q = LazyCsvReader::new("./mnistdata/mnist_train.csv")
        .with_has_header(true)
        .finish()?;

    // The "label" column holds the digits; the remaining columns are pixels.
    let training_labels = q
        .clone()
        .with_streaming(true)
        .select([col("label")])
        .collect()?;
    let training_data = q
        .clone()
        .with_streaming(true)
        .drop([col("label")])
        .collect()?;

    let mut training_data_ndarray = training_data
        .to_ndarray::<Float32Type>(IndexOrder::Fortran)
        .unwrap();
    let mut training_labels_ndarray = training_labels
        .to_ndarray::<Float32Type>(IndexOrder::Fortran)
        .unwrap();

    // Transpose so each column is one sample, and scale pixels to [0, 1].
    training_data_ndarray = training_data_ndarray.reversed_axes() / 255.0;
    training_labels_ndarray = training_labels_ndarray.reversed_axes();

    let data_dimensions: &[usize] = training_data_ndarray.shape();
    let labels_dimensions: &[usize] = training_labels_ndarray.shape();
    println!("DATA: {}, {}", data_dimensions[0], data_dimensions[1]);
    println!("LABELS: {}, {}", labels_dimensions[0], labels_dimensions[1]);

    Ok((training_data_ndarray, training_labels_ndarray))
}
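As a quick sanity check, the loader can be exercised on its own (a minimal sketch, assuming load_training_data above is in scope):

use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
    let (x, y) = load_training_data()?;
    // Each of the 60,000 columns of `x` is one 28x28 image flattened to
    // 784 pixels scaled to [0, 1]; `y` is the matching 1 x 60000 row of labels.
    assert_eq!(x.shape(), &[784, 60000]);
    assert_eq!(y.shape(), &[1, 60000]);
    Ok(())
}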
What is a Neural Network and how do we use it to recognize digits?
Our Approach for the Neural Network:
Declaring the initial weights and biases using the ndarray_rand crate:
pub fn init_params() -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
    // Weights and biases start as small random values in [-0.5, 0.5).
    let w1 = Array2::random((10, 784), Uniform::new(-0.5, 0.5));
    let b1 = Array2::random((10, 1), Uniform::new(-0.5, 0.5));
    let w2 = Array2::random((10, 10), Uniform::new(-0.5, 0.5));
    let b2 = Array2::random((10, 1), Uniform::new(-0.5, 0.5));
    (w1, b1, w2, b2)
}
Forward Propagation:
// ReLU: clamp negative pre-activations to zero, in place.
pub fn relu(z: &mut Array2<f32>) {
    z.mapv_inplace(|x| x.max(0.0))
}

// Softmax: turn each column of logits into a probability distribution.
pub fn softmax(z: &mut Array2<f32>) {
    for mut col in z.axis_iter_mut(Axis(1)) {
        // Subtract the column max for numerical stability.
        let max = col.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        col.mapv_inplace(|x| (x - max).exp());
        let sum = col.sum();
        col.mapv_inplace(|x| x / sum);
    }
}
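A quick way to sanity-check softmax (a minimal sketch, assuming the function above is in scope): every column must come out as a probability distribution, i.e. sum to 1.

use ndarray::{array, Axis};

fn main() {
    // Three classes (rows) for two samples (columns); column 2 is all zeros.
    let mut z = array![[1.0_f32, 0.0], [2.0, 0.0], [3.0, 0.0]];
    softmax(&mut z);
    // All-zero logits become the uniform distribution 1/3; every column sums to 1.
    for col in z.axis_iter(Axis(1)) {
        assert!((col.sum() - 1.0).abs() < 1e-6);
    }
}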
pub fn forward_propagation(
    w1: &mut Array2<f32>,
    b1: &mut Array2<f32>,
    w2: &mut Array2<f32>,
    b2: &mut Array2<f32>,
    x: &mut Array2<f32>,
) -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
    // Hidden layer: z1 = w1·x + b1 (b1 broadcasts across columns), a1 = ReLU(z1).
    let z1 = w1.dot(x) + &*b1;
    let mut a1 = z1.clone();
    relu(&mut a1);
    // Output layer: z2 = w2·a1 + b2, a2 = softmax(z2).
    let z2 = w2.dot(&a1) + &*b2;
    let mut a2 = z2.clone();
    softmax(&mut a2);
    (z1, a1, z2, a2)
}
Back Propagation:
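For a softmax output layer trained with cross-entropy loss, the gradients take the standard textbook form (averaged over the m training samples), which backward_propagation below mirrors line for line:

$$dZ^{[2]} = A^{[2]} - Y, \qquad dW^{[2]} = \tfrac{1}{m}\, dZ^{[2]} A^{[1]\top}, \qquad db^{[2]} = \tfrac{1}{m} \textstyle\sum_{\text{cols}} dZ^{[2]}$$

$$dZ^{[1]} = \left(W^{[2]\top} dZ^{[2]}\right) \odot \mathrm{ReLU}'(Z^{[1]}), \qquad dW^{[1]} = \tfrac{1}{m}\, dZ^{[1]} X^{\top}, \qquad db^{[1]} = \tfrac{1}{m} \textstyle\sum_{\text{cols}} dZ^{[1]}$$

Here $Y$ is the one-hot label matrix and $\odot$ is element-wise multiplication.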
// One-hot encode labels: a (1 x m) row of digits becomes a (num_classes x m) matrix.
pub fn one_hot_encoded(y: &mut Array2<f32>, num_classes: usize) -> Array2<f32> {
    let ydash = y.flatten();
    let label_dimensions: &[usize] = ydash.shape();
    let mut one_hot_y = Array2::<f32>::zeros((label_dimensions[0], num_classes));
    for (row, &label) in ydash.iter().enumerate() {
        let class_index = label as usize;
        one_hot_y[(row, class_index)] = 1.0;
    }
    // Transpose so each column holds the one-hot vector of one sample.
    one_hot_y.reversed_axes()
}
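To see what one_hot_encoded produces, here is a tiny worked example (a sketch, assuming the function above is in scope):

use ndarray::array;

fn main() {
    // Labels 2, 0, 1 for three samples, stored as a 1 x 3 row like the training labels.
    let mut y = array![[2.0_f32, 0.0, 1.0]];
    let one_hot = one_hot_encoded(&mut y, 3);
    // Output is (num_classes x samples): column j holds sample j's one-hot vector.
    assert_eq!(one_hot, array![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]);
}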
// Derivative of ReLU: 1 where the pre-activation was positive, 0 elsewhere.
pub fn deriv_relu(z: &mut Array2<f32>) {
    z.mapv_inplace(|x| if x > 0.0 { 1.0 } else { 0.0 });
}
pub fn backward_propagation(
    z1: &mut Array2<f32>,
    a1: &mut Array2<f32>,
    a2: &mut Array2<f32>,
    w2: &mut Array2<f32>,
    x: &mut Array2<f32>,
    y: &mut Array2<f32>,
) -> (Array2<f32>, Array2<f32>, Array2<f32>, Array2<f32>) {
    let m = y.len() as f32; // number of training samples
    let a1t = a1.view().reversed_axes();
    let w2t = w2.view().reversed_axes();
    let xt = x.view().reversed_axes();
    let one_hot_y = one_hot_encoded(y, 10);

    // Output layer: for softmax + cross-entropy, dz2 = a2 - y_one_hot.
    let dz2 = &*a2 - &one_hot_y;
    let dw2 = (1.0 / m) * (dz2.dot(&a1t));
    let db2 = dz2.sum_axis(Axis(1)).insert_axis(Axis(1)) * (1.0 / m);

    // Hidden layer: propagate through w2, gated by ReLU's derivative.
    let mut z1_deriv = z1.clone();
    deriv_relu(&mut z1_deriv);
    let dz1 = w2t.dot(&dz2) * z1_deriv;
    let dw1 = (1.0 / m) * (dz1.dot(&xt));
    let db1 = dz1.sum_axis(Axis(1)).insert_axis(Axis(1)) * (1.0 / m);

    (dw1, db1, dw2, db2)
}
Update weights and biases:
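Each parameter then takes one plain gradient-descent step, scaled by the learning rate $\alpha$:

$$\theta \leftarrow \theta - \alpha\, d\theta \qquad \text{for } \theta \in \{W^{[1]}, b^{[1]}, W^{[2]}, b^{[2]}\}$$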
pub fn update_params(
    w1: &mut Array2<f32>,
    b1: &mut Array2<f32>,
    w2: &mut Array2<f32>,
    b2: &mut Array2<f32>,
    dw1: &Array2<f32>,
    db1: &Array2<f32>,
    dw2: &Array2<f32>,
    db2: &Array2<f32>,
    alpha: f32,
) {
    // One gradient-descent step: move each parameter against its gradient.
    *w1 -= &(alpha * dw1);
    *b1 -= &(alpha * db1);
    *w2 -= &(alpha * dw2);
    *b2 -= &(alpha * db2);
}
Function to compute the accuracy of our model:
pub fn get_accuracy(predictions: &Array2<f32>, labels: &Array2<f32>) -> f32 {
    // Argmax over each column: the predicted class for each sample.
    let pred_classes: Array1<usize> = predictions
        .axis_iter(Axis(1))
        .map(|col| {
            col.iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0
        })
        .collect();
    let true_classes: Array1<usize> = labels.iter().map(|x| *x as usize).collect();
    // Count the samples where the prediction matches the true label.
    let correct = pred_classes
        .iter()
        .zip(true_classes.iter())
        .filter(|(pred, truth)| pred == truth)
        .count();
    correct as f32 / labels.len() as f32
}
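And a tiny worked example of get_accuracy (a sketch, assuming the function above is in scope):

use ndarray::array;

fn main() {
    // Softmax outputs for three samples (columns) over three classes (rows).
    let predictions = array![
        [0.7_f32, 0.1, 0.2],
        [0.2, 0.8, 0.3],
        [0.1, 0.1, 0.5]
    ];
    // True labels: sample 0 -> class 0, sample 1 -> class 1, sample 2 -> class 0.
    let labels = array![[0.0_f32, 1.0, 0.0]];
    // Argmax per column gives classes 0, 1, 2, so two of three are correct.
    assert!((get_accuracy(&predictions, &labels) - 2.0 / 3.0).abs() < 1e-6);
}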
Finally, using all the functions in main:
use std::error::Error;
// Functions from lib.rs; the crate name here is assumed from the repo name.
use mnist_nn_rs::*;

fn main() -> Result<(), Box<dyn Error>> {
    let (mut training_data, mut training_label) = load_training_data()?;
    let (mut w1, mut b1, mut w2, mut b2) = init_params();
    let iterations = 501; // one extra so the `i % 50 == 0` check also prints at 500
    let alpha = 0.1; // learning rate
    println!("{}", training_label);
    for i in 0..iterations {
        let (mut z1, mut a1, _z2, mut a2) =
            forward_propagation(&mut w1, &mut b1, &mut w2, &mut b2, &mut training_data);
        let (dw1, db1, dw2, db2) = backward_propagation(
            &mut z1,
            &mut a1,
            &mut a2,
            &mut w2,
            &mut training_data,
            &mut training_label,
        );
        update_params(&mut w1, &mut b1, &mut w2, &mut b2, &dw1, &db1, &dw2, &db2, alpha);
        if i % 50 == 0 {
            println!("Iteration: {}", i);
            let acc = get_accuracy(&a2, &training_label);
            println!("Accuracy: {:.2}%", acc * 100.0);
        }
    }
    Ok(())
}
Results for 200 iterations and learning rate = 0.1
DATA: 784, 60000
LABELS: 1, 60000
[[5, 0, 4, 1, 9, ..., 8, 3, 5, 6, 8]]
Iteration: 0
Accuracy: 10.86%
Iteration: 50
Accuracy: 56.97%
Iteration: 100
Accuracy: 69.91%
Iteration: 150
Accuracy: 75.45%
Iteration: 200
Accuracy: 78.56%
Results for 500 iterations and learning rate = 0.1
DATA: 784, 60000
LABELS: 1, 60000
[[5, 0, 4, 1, 9, ..., 8, 3, 5, 6, 8]]
Iteration: 0
Accuracy: 12.46%
Iteration: 50
Accuracy: 47.05%
Iteration: 100
Accuracy: 61.53%
Iteration: 150
Accuracy: 69.01%
Iteration: 200
Accuracy: 73.28%
Iteration: 250
Accuracy: 76.48%
Iteration: 300
Accuracy: 78.93%
Iteration: 350
Accuracy: 80.81%
Iteration: 400
Accuracy: 82.38%
Iteration: 450
Accuracy: 83.53%
Iteration: 500
Accuracy: 84.48%
Medium Blog: https://medium.com/@dkjain2005co/neural-network-in-rust-on-mnist-dataset-from-scratch-f42971eaead3