I am practicing Rust by creating data structures, such as this N-dimensional array.
The purpose of this structure is to make it easy to define and index into arbitrarily deeply nested arrays without having to write out `Vec<Vec<Vec<T>>>` and without incurring the performance penalty that comes with it (although performance is not a priority).
I am looking for feedback on how idiomatic the interface and implementation are, as well as on how to improve my tests, since I don't think they're very robust.
use std::ops::{Index, IndexMut};
/// A dense N-dimensional array stored in a single contiguous `Vec`.
///
/// Elements are addressed with an `N`-element index array, e.g.
/// `array[&[i, j, k]]` for `N == 3`. Storage is row-major: the last index
/// varies fastest.
///
/// Two arrays compare equal only if they have the same shape and the same
/// elements in the same positions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NdArray<T, const N: usize> {
    /// Extent of each dimension (the array's shape).
    dim: [usize; N],
    /// Elements in row-major order; its length is the product of `dim`.
    data: Vec<T>,
}
impl<T, const N: usize> NdArray<T, N> {
fn calculate_capacity(dim: &[usize; N]) -> usize {
let mut cap = 1;
for &x in dim {
cap = usize::checked_mul(cap, x).expect("vector capacity overflowed usize");
}
cap
}
pub fn new(dim: [usize; N], default: T) -> Self
where
T: Clone,
{
let cap = Self::calculate_capacity(&dim);
NdArray {
dim,
data: vec![default; cap],
}
}
pub fn new_with(dim: [usize; N], generator: impl FnMut() -> T) -> Self {
let cap = Self::calculate_capacity(&dim);
NdArray {
dim,
data: {
let mut v = Vec::new();
v.resize_with(cap, generator);
v
},
}
}
fn get_flat_idx(&self, idx: &[usize; N]) -> usize {
let mut i = 0;
for d in 0..self.dim.len() {
assert!(
idx[d] < self.dim[d],
"index {} is out of bounds for dimension {} with size {}",
idx[d],
d,
self.dim[d],
);
// This cannot overflow since we already checked product of all dimensions fits usize
// and idx < self.dim.
i = i * self.dim[d] + idx[d];
}
i
}
}
/// Borrowed multi-index access: `array[&[i, j, k]]`.
///
/// # Panics
///
/// Panics if any component of the index is out of bounds for its dimension.
impl<T, const N: usize> Index<&[usize; N]> for NdArray<T, N> {
    type Output = T;

    fn index(&self, idx: &[usize; N]) -> &T {
        &self.data[self.get_flat_idx(idx)]
    }
}

/// Mutable borrowed multi-index access: `array[&[i, j, k]] = value`.
///
/// # Panics
///
/// Panics if any component of the index is out of bounds for its dimension.
impl<T, const N: usize> IndexMut<&[usize; N]> for NdArray<T, N> {
    fn index_mut(&mut self, idx: &[usize; N]) -> &mut T {
        // Compute the offset before taking the mutable borrow of `self.data`;
        // `get_flat_idx` needs a shared borrow of all of `self`.
        let i = self.get_flat_idx(idx);
        &mut self.data[i]
    }
}

/// By-value multi-index access, so callers can write `array[[i, j, k]]`
/// without the extra `&`. Delegates to the borrowed form.
///
/// # Panics
///
/// Same as indexing with `&[usize; N]`.
impl<T, const N: usize> Index<[usize; N]> for NdArray<T, N> {
    type Output = T;

    fn index(&self, idx: [usize; N]) -> &T {
        &self[&idx]
    }
}

/// Mutable by-value multi-index access: `array[[i, j, k]] = value`.
/// Delegates to the borrowed form.
///
/// # Panics
///
/// Same as indexing with `&[usize; N]`.
impl<T, const N: usize> IndexMut<[usize; N]> for NdArray<T, N> {
    fn index_mut(&mut self, idx: [usize; N]) -> &mut T {
        &mut self[&idx]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes a distinct value to every cell through `IndexMut`, then checks
    /// the backing storage. Seeing exactly `0..len` in order shows three
    /// things:
    /// - every index reaches its own cell (no aliasing);
    /// - every cell is reachable (no gaps);
    /// - the layout is row-major (last index varies fastest).
    #[test]
    fn every_index_maps_to_a_distinct_row_major_cell() {
        let (rows, cols, depth) = (2, 3, 4);
        let mut array = NdArray::new([rows, cols, depth], usize::MAX);
        for i in 0..rows {
            for j in 0..cols {
                for k in 0..depth {
                    array[&[i, j, k]] = (i * cols + j) * depth + k;
                }
            }
        }
        assert_eq!(array.data, (0..rows * cols * depth).collect::<Vec<_>>());
        assert_eq!(array[&[0, 0, 0]], 0);
        assert_eq!(array[&[1, 2, 3]], 23);
    }

    #[test]
    fn one_dimensional_write_touches_only_its_cell() {
        let mut array = NdArray::new_with([10], || 1);
        array[&[5]] = 99;
        assert_eq!(array[&[0]], 1);
        assert_eq!(array[&[4]], 1);
        assert_eq!(array[&[5]], 99);
        assert_eq!(array[&[6]], 1);
        assert_eq!(array[&[9]], 1);
        assert_eq!(array.data.iter().filter(|&&x| x == 99).count(), 1);
    }

    #[test]
    fn new_fills_every_cell_with_clones_of_default() {
        let array = NdArray::new([3, 2], String::from("x"));
        assert_eq!(array.data.len(), 6);
        assert!(array.data.iter().all(|s| s == "x"));
    }

    #[test]
    fn new_with_calls_generator_once_per_cell_in_order() {
        let mut calls = 0;
        let array = NdArray::new_with([2, 3], || {
            calls += 1;
            calls
        });
        assert_eq!(calls, 6);
        assert_eq!(array.data, vec![1, 2, 3, 4, 5, 6]);
        assert_eq!(array[&[1, 0]], 4);
    }

    #[test]
    fn zero_dimensional_array_holds_exactly_one_element() {
        let mut array = NdArray::new([], 7);
        assert_eq!(array[&[]], 7);
        array[&[]] = 8;
        assert_eq!(array.data, vec![8]);
    }

    #[test]
    fn zero_length_dimension_gives_empty_array() {
        let array = NdArray::new([3, 0, 2], 0);
        assert!(array.data.is_empty());
    }

    #[test]
    #[should_panic(expected = "index 0 is out of bounds for dimension 1 with size 0")]
    fn indexing_empty_array_panics() {
        let array = NdArray::new([3, 0, 2], 0);
        let _value = array[&[0, 0, 0]];
    }

    #[test]
    #[should_panic(expected = "index 10 is out of bounds for dimension 0 with size 10")]
    fn out_of_bounds_read_1d_panics() {
        let array = NdArray::new_with([10], || 1);
        let _value = array[&[10]];
    }

    #[test]
    #[should_panic(expected = "index 2 is out of bounds for dimension 0 with size 2")]
    fn out_of_bounds_first_dimension_panics() {
        let array = NdArray::new([2, 3, 4], 0);
        let _value = array[&[2, 0, 0]];
    }

    // The next two indices would silently alias other cells if only the flat
    // offset were range-checked: `[0, 3, 0]` -> `[1, 0, 0]` and
    // `[0, 0, 4]` -> `[0, 1, 0]`. Each dimension must be checked on its own.
    #[test]
    #[should_panic(expected = "index 3 is out of bounds for dimension 1 with size 3")]
    fn out_of_bounds_write_middle_dimension_panics() {
        let mut array = NdArray::new([2, 3, 4], 0);
        array[&[0, 3, 0]] = 1;
    }

    #[test]
    #[should_panic(expected = "index 4 is out of bounds for dimension 2 with size 4")]
    fn out_of_bounds_last_dimension_panics() {
        let array = NdArray::new([2, 3, 4], 0);
        let _value = array[&[0, 0, 4]];
    }

    #[test]
    #[should_panic(expected = "vector capacity overflowed usize")]
    fn capacity_overflow_2d_panics() {
        NdArray::new([usize::MAX / 2 + 1, 2], 0);
    }

    #[test]
    #[should_panic(expected = "vector capacity overflowed usize")]
    fn capacity_overflow_3d_panics() {
        NdArray::new([usize::MAX / 3, usize::MAX / 3, usize::MAX / 3 + 1], 0);
    }

    /// The overflow check must fire before any element is generated.
    #[test]
    #[should_panic(expected = "vector capacity overflowed usize")]
    fn new_with_capacity_overflow_panics_before_generating() {
        NdArray::new_with([usize::MAX, 2], || -> u8 {
            unreachable!("generator must not run")
        });
    }
}