Skip to main content
Notice removed Reward existing answer by ShapeOfMatter
Bounty Ended with Peter Blackson's answer chosen by ShapeOfMatter
edited tags
Link
Notice added Reward existing answer by ShapeOfMatter
Bounty Started worth 100 reputation by ShapeOfMatter
Notice removed Draw attention by チーズパン
Bounty Ended with ShapeOfMatter's answer chosen by チーズパン
Notice added Draw attention by チーズパン
Bounty Started worth 200 reputation by チーズパン
Notice removed Draw attention by CommunityBot
Bounty Ended with no winning answer by CommunityBot
Tweeted twitter.com/StackCodeReview/status/1497360764176412679
Notice added Draw attention by チーズパン
Bounty Started worth 100 reputation by チーズパン
Source Link

Median cut algorithm

I have implemented a simple version of the median cut algorithm. It takes a vector of Color structs representing pixel in an image. I also use the ColorChanel enum representing RGBA channels.

pub struct Color {
    pub r: u8,
    pub g: u8,
    pub b: u8,
    pub a: u8,
}

#[derive(Debug)]
pub enum ColorChannel {
    R,
    G,
    B,
    A,
}

impl Color {
    pub fn channel_val(&self, channel: &ColorChannel) -> u8 {
        match channel {
            ColorChannel::R => self.r,
            ColorChannel::G => self.g,
            ColorChannel::B => self.b,
            ColorChannel::A => self.a,
        }
    }
}

impl PartialEq for Color {
    fn eq(&self, other: &Self) -> bool {
        self.r == other.r && self.g == other.g && self.b == other.b && self.a == other.a
    }
}

For each Color/Pixel vector I look for the channel with the highest range:

/// Returns the color channel with the highest range.
/// IMPORTANT: Ignores alpha channel!
///
/// # Arguments
///
/// * `colors` - Color vector from which the highest range is evaluated.
///
fn highest_range_channel(colors: &Vec<Color>) -> Option<ColorChannel> {
    if let Some(ranges) = color_ranges(colors) {
        let mut highest_range_channel = ColorChannel::R;
        let mut highest_value = ranges.r;

        if ranges.g > highest_value {
            highest_range_channel = ColorChannel::G;
            highest_value = ranges.g;
        }

        if ranges.b > highest_value {
            highest_range_channel = ColorChannel::B;
        }

        return Some(highest_range_channel);
    }

    None
}

The color ranges are calculated in the according function:

/// Returns the ranges for each color channel
///
/// # Arguments
///
/// * `colors` - Color vector from which the ranges are calculated.
///
/// # Examples
///
/// ```
/// let colors = Vec::<Color>::new();
/// let color_ranges_data = color_ranges(colors);
/// ```
///
fn color_ranges(colors: &Vec<Color>) -> Option<Color> {
    if colors.is_empty() {
        return None;
    }

    // Unwrap is ok here, because `max_by_key` only returns `None` for empty vectors
    let r_range = colors.iter().max_by_key(|c| c.r).unwrap().r;
    let g_range = colors.iter().max_by_key(|c| c.g).unwrap().g;
    let b_range = colors.iter().max_by_key(|c| c.b).unwrap().b;
    let a_range = colors.iter().max_by_key(|c| c.a).unwrap().a;

    Some(Color {
        r: r_range,
        g: g_range,
        b: b_range,
        a: a_range,
    })
}

After that I calculate the median value for the channel with the highest range. I am doing this by sorting the vector based on the desired channel and finding the value in the "middle" of the vector.

/// Sort a color vector for a specific channel.
///
/// # Arguments
///
/// * `colors` - Color data which will be sorted.
/// * `channel` - Target channel. The sorting is performed based on this value.
///
/// # Examples
///
/// ```
/// let mut colors = Vec::<Color>::new();
/// sort_colors(&mut colors, &ColorChannel::R);
/// ```
///
fn sort_colors(colors: &mut Vec<Color>, channel: &ColorChannel) {
    if colors.is_empty() {
        return;
    }

    match channel {
        ColorChannel::R => colors.sort_by(|a, b| a.r.cmp(&b.r)),
        ColorChannel::G => colors.sort_by(|a, b| a.g.cmp(&b.g)),
        ColorChannel::B => colors.sort_by(|a, b| a.b.cmp(&b.b)),
        ColorChannel::A => colors.sort_by(|a, b| a.a.cmp(&b.a)),
    }
}

/// Returns median value for a specific `ColorChannel`.
///
/// # Arguments
///
/// * `colors` - Color vector from which the median value is calculated.
/// * `channel` - Target channel for which the median is calculated.
///
/// # Examples
/// ```
/// let mut colors = Vec::<Color>::new();
/// let mut result = color_median(&mut colors, &ColorChannel::R);
/// ```
///
fn color_median(colors: &mut Vec<Color>, channel: &ColorChannel) -> Option<u8> {
    if colors.is_empty() {
        return None;
    }

    sort_colors(colors, channel);

    let mid = colors.len() / 2;
    if colors.len() % 2 == 0 {
        channel_mean(&vec![colors[mid - 1], colors[mid]], channel)
    } else {
        channel_value_by_index(colors, mid, channel)
    }
}

/// Returns a color value based on the provided channel and index parameters.
///
/// # Arguments
///
/// * `colors` - Color vector from which the value is retreived.
/// * `index` - Index of the target color in the vector.
/// * `channel` - Color channel of the searched value.
///
/// # Examples
///
/// ```
/// let mut colors: Vec<Color> = Vec::new();
/// colors.push(Color { r: 100, g: 22, b: 12, a: 0 });
/// assert_eq!(Some(100), channel_value_by_index(&colors, 0, &ColorChannel::R));
/// ```
///
fn channel_value_by_index(colors: &Vec<Color>, index: usize, channel: &ColorChannel) -> Option<u8> {
    if colors.is_empty() || index >= colors.len() {
        return None;
    }
    match channel {
        ColorChannel::R => Some(colors[index].r),
        ColorChannel::G => Some(colors[index].g),
        ColorChannel::B => Some(colors[index].b),
        ColorChannel::A => Some(colors[index].a),
    }
}

/// Calculate the mean value for a specific color channel on a vector of `Color`.
///
/// # Arguments
///
/// * `colors` - Color vector from which the mean value is calculated.
/// * `channel` - Target channel for which the mean is calculated.
///
/// # Examples
///
/// ```
/// let mut colors: Vec<Color> = Vec::new();
/// let mut result = channel_mean(&colors, &ColorChannel::R);
/// ```
///
fn channel_mean(colors: &Vec<Color>, channel: &ColorChannel) -> Option<u8> {
    let number_colors = colors.len();

    if number_colors == 0 {
        return None;
    }

    match channel {
        ColorChannel::R => Some((colors.iter().fold(0, |acc: u32, x| x.r as u32 + acc) / number_colors as u32) as u8),
        ColorChannel::G => Some((colors.iter().fold(0, |acc: u32, x| x.g as u32 + acc) / number_colors as u32) as u8),
        ColorChannel::B => Some((colors.iter().fold(0, |acc: u32, x| x.b as u32 + acc) / number_colors as u32) as u8),
        ColorChannel::A => Some((colors.iter().fold(0, |acc: u32, x| x.a as u32 + acc) / number_colors as u32) as u8),
    }
}

Now I create two new vectors/buckets with one containing all Colors above the median value and another with Color values below the median. This entire process is implemented in the median_cut function.

/// Performs the median cut on a single vector (bucket) of `Color`.
/// Returns two `color` vectors representing the colors above and colors below median value.
///
/// # Arguments
///
/// * `colors` - `Color` vector on which the median cut is performed.
///
fn median_cut(colors: &mut Vec<Color>) -> (Vec<Color>, Vec<Color>) {
    if colors.is_empty() {
        return (Vec::<Color>::new(), Vec::<Color>::new());
    }

    if let Some(highest_range_channel) = highest_range_channel(&colors) {
        if let Some(median) = color_median(colors, &highest_range_channel) {
            let mut above_median = Vec::<Color>::new();
            let mut below_median = Vec::<Color>::new();
            for color in colors {
                if color.channel_val(&highest_range_channel) > median {
                    above_median.push(*color);
                } else {
                    below_median.push(*color);
                }
            }

            return (above_median, below_median);
        }
    }

    return (Vec::<Color>::new(), Vec::<Color>::new());
}

In order to perform multiple iterations I call the recurse function. It takes a bucket, performs the median_cut on it calculates the mean color for each output bucket and performs the median_cut on the them. The function stops and returns all mean colors when the amount of iterations reaches 0.

pub fn recurse(bucket: &mut Vec<Color>, iter_count: u8) -> Option<Vec<Color>> {
    if iter_count < 1 || bucket.is_empty() {
        return None;
    }
    let mut result = Vec::<Color>::new();

    let mut new_buckets = median_cut(bucket);
    if !new_buckets.0.is_empty() {
        if let Some(c_0) = color_mean(&new_buckets.0) {
            result.push(c_0);
        }

        if let Some(new_colors) = recurse(&mut new_buckets.0, iter_count - 1) {
            result.append(&mut new_colors.clone());
        }
    }

    if !new_buckets.1.is_empty() {
        if let Some(c_1) = color_mean(&new_buckets.1) {
            result.push(c_1);
        }

        if let Some(new_colors) = recurse(&mut new_buckets.1, iter_count - 1) {
            result.append(&mut new_colors.clone());
        }
    }

    Some(result)
}

/// Returns the mean color value based on the passed colors.
///
/// # Arguments
///
/// * `colors` - Color vector from which the mean color is calculated.
///
/// # Examples
///
/// ```
/// let colors = Vec::<Color>::new();
/// let result = color_mean(&colors);
/// ```
///
fn color_mean(colors: &Vec<Color>) -> Option<Color> {
    if colors.is_empty() {
        return None;
    }

    let r_mean = (colors.iter().fold(0, |acc: u32, c| acc + c.r as u32) / colors.len() as u32) as u8;
    let g_mean = (colors.iter().fold(0, |acc: u32, c| acc + c.g as u32) / colors.len() as u32) as u8;
    let b_mean = (colors.iter().fold(0, |acc: u32, c| acc + c.b as u32) / colors.len() as u32) as u8;
    let a_mean = (colors.iter().fold(0, |acc: u32, c| acc + c.a as u32) / colors.len() as u32) as u8;

    Some(Color {
        r: r_mean,
        g: g_mean,
        b: b_mean,
        a: a_mean,
    })
}

So basically a call to the entire algorithm looks like this:

let mut pixels = Vec::new();
// fill pixels with data

if let Some(colors) = recurse(&mut pixels, 3) {
    // Do something with the output colors
}

I am very new to rust so my concerns about the code are:

  • The usage of Option. I second guess if I should have used Result with some meaningful error information. Also I am feeling like the "high level" usage of functions returning Options produces arrow-like code (like in median_cut) which I personally find hard to read.
  • Overall performance, it feels like there is a lot of iteration and cloning going on.
  • Algorithm implementation, There are scenarios which return "interesting" results. E.g. if I use an image which only contains red and blue pixels, the algorithm returns a blue, a red and a purple result. As I understand it the algorithm should perform a color reduction.

How could I optimize the code based on the above topics?