Skip to main content
Bounty Awarded with 100 reputation awarded by ShapeOfMatter
problem with code
Source Link

EDIT: a range is maximum minus minimum, not just maximum.

EDIT: a range is maximum minus minimum, not just maximum.

Source Link

/*

Hello and welcome to the Rust community!

First, a couple of ideas after briefly looking through the code.

  • The algorithm can be optimized with SIMD operations. Read up on SIMD in Rust if you wish to learn more.
  • You have a struct with fields r, g, b, and a. Unfortunately you cannot do a struct of arrays instead, because you sort colors. Read up on ArrayOfStructs/StructOfArrays if you wish to learn more.
  • You guard against the emptiness of buckets in many places -- this invariant could be baked into a new type for holding buckets. This avoids almost all use of Options.
  • You have many freestanding functions. They do work on a common type. You can make a new type for holding buckets and impl these fns on that type. This could improve readability and clarity of code.
  • The code uses recursion. It is a good practice when working with algorithms to avoid recursion and work with a Vec and loop instead. Reasoning about recursion is sometimes difficult. I've even seen an amazing programmer (Niko Matsakis) make mistakes with recursion. But your recursion should be fine here, since it is a simple algorithm.
  • Use clippy! It provides some small suggestions for making this code in line with best practices.
  • There is fold, which could be refactored as map and sum instead.
  • You do impl PartialEq for Color, which could be automatically derived.
  • The clone in recurse is useless. Also, use extend instead.
  • If you have more than 2^24 pixels, for example in an image larger than 4096x4096, your color_mean may overflow. Use a u64 type there to avoid overflow.
  • recurse is not a good name for that fn. A better name is make_palette as suggested by @ShapeOfMatter.
  • sort_by can become sort_by_key.

Are sure there is any problem with your implementation? I do not see any purple color results on pixels of red and blue.

Here is how the code looks like after a refactor

I added in changes suggested by @ShapeOfMatter.

*/


// ---------
// I have implemented a simple version of the median cut algorithm.
// It takes a vector of Color structs representing pixel in an image.
// I also use the ColorChanel enum representing RGBA channels.
// ---------

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Color {
    pub r: u8,
    pub g: u8,
    pub b: u8,
    pub a: u8,
}

#[derive(Debug, Copy, Clone)]
pub enum ColorChannel {
    R,
    G,
    B,
    A,
}

/// A list of colors.
///
/// INVARIANT: is nonempty.
struct ColorBucket {
    colors: Vec<Color>,
}

impl ::std::ops::Index<ColorChannel> for Color {
    type Output = u8;
    fn index(&self, index: ColorChannel) -> &Self::Output {
        match index {
            ColorChannel::R => &self.r,
            ColorChannel::G => &self.g,
            ColorChannel::B => &self.b,
            ColorChannel::A => &self.a,
        }
    }
}

/// Helper function for calculating the mean.
fn mean(iter: impl Iterator<Item=u8> + Clone) -> u8 {
    (iter.clone().map(|x| x as u64).sum::<u64>() / iter.count() as u64) as u8
}

impl ColorBucket {
    fn from_pixels(pixels: Vec<Color>) -> Option<Self> {
        if pixels.is_empty() {
            None
        } else {
            Some(Self {
                colors: pixels,
            })
        }
    }

    // ---------
    // For each Color/Pixel vector I look for the channel with the highest range:
    // ---------

    /// Returns the color channel with the highest range.
    /// IMPORTANT: Ignores alpha channel!
    ///
    /// # Arguments
    ///
    /// * `colors` - Color vector from which the highest range is evaluated.
    ///
    fn highest_range_channel(&self) -> ColorChannel {
        let ranges = self.color_ranges();

        let mut highest_range_channel = ColorChannel::R;
        let mut highest_value = ranges.r;

        if ranges.g > highest_value {
            highest_range_channel = ColorChannel::G;
            highest_value = ranges.g;
        }

        if ranges.b > highest_value {
            highest_range_channel = ColorChannel::B;
        }

        highest_range_channel
    }


    // ---------
    // The color ranges are calculated in the according function:
    // ---------

    /// Returns the ranges for each color channel
    ///
    /// # Arguments
    ///
    /// * `colors` - Color vector from which the ranges are calculated.
    ///
    /// # Examples
    ///
    /// ```
    /// let colors = Vec::<Color>::new();
    /// let color_ranges_data = color_ranges(colors);
    /// ```
    ///
    fn color_ranges(&self) -> Color {
        // Unwrap is ok here, because `max_by_key` only returns `None` for empty vectors
        let r_range = self.colors.iter().max_by_key(|c| c.r).unwrap().r;
        let g_range = self.colors.iter().max_by_key(|c| c.g).unwrap().g;
        let b_range = self.colors.iter().max_by_key(|c| c.b).unwrap().b;
        let a_range = self.colors.iter().max_by_key(|c| c.a).unwrap().a;

        Color {
            r: r_range,
            g: g_range,
            b: b_range,
            a: a_range,
        }
    }

    // ---------
    // After that I calculate the median value for the channel with the highest range.
    // I am doing this by sorting the vector based on the desired channel and finding
    // the value in the "middle" of the vector.
    // ---------

    /// Sort a color vector for a specific channel.
    ///
    /// # Arguments
    ///
    /// * `colors` - Color data which will be sorted.
    /// * `channel` - Target channel. The sorting is performed based on this value.
    ///
    /// # Examples
    ///
    /// ```
    /// let mut colors = Vec::<Color>::new();
    /// sort_colors(&mut colors, &ColorChannel::R);
    /// ```
    ///
    fn sort_colors(&mut self, channel: ColorChannel) {
        self.colors.sort_by_key(|x| x[channel])
    }

    /// Returns median value for a specific `ColorChannel`.
    ///
    /// # Arguments
    ///
    /// * `colors` - Color vector from which the median value is calculated.
    /// * `channel` - Target channel for which the median is calculated.
    ///
    /// # Examples
    /// ```
    /// let mut colors = Vec::<Color>::new();
    /// let mut result = color_median(&mut colors, &ColorChannel::R);
    /// ```
    ///
    fn color_median(&mut self, channel: ColorChannel) -> u8 {
        self.sort_colors(channel);

        let mid = self.colors.len() / 2;
        if self.colors.len() % 2 == 0 {
            let bucket = ColorBucket::from_pixels(vec![self.colors[mid - 1], self.colors[mid]]).unwrap();
            bucket.channel_mean(channel)
        } else {
            self.channel_value_by_index(mid, channel)
        }
    }

    /// Returns a color value based on the provided channel and index parameters.
    ///
    /// # Arguments
    ///
    /// * `colors` - Color vector from which the value is retreived.
    /// * `index` - Index of the target color in the vector.
    /// * `channel` - Color channel of the searched value.
    ///
    /// # Examples
    ///
    /// ```
    /// let mut colors: Vec<Color> = Vec::new();
    /// colors.push(Color { r: 100, g: 22, b: 12, a: 0 });
    /// assert_eq!(Some(100), channel_value_by_index(&colors, 0, &ColorChannel::R));
    /// ```
    ///
    /// # Panics
    /// 
    /// Panics when index is out of bounds.
    /// 
    fn channel_value_by_index(&self, index: usize, channel: ColorChannel) -> u8 {
        self.colors[index][channel]
    }

    /// Calculate the mean value for a specific color channel on a vector of `Color`.
    ///
    /// # Arguments
    ///
    /// * `colors` - Color vector from which the mean value is calculated.
    /// * `channel` - Target channel for which the mean is calculated.
    ///
    /// # Examples
    ///
    /// ```
    /// let mut colors: Vec<Color> = Vec::new();
    /// let mut result = channel_mean(&colors, &ColorChannel::R);
    /// ```
    ///
    fn channel_mean(&self, channel: ColorChannel) -> u8 {
        mean(self.colors.iter().map(|x| x[channel]))
    }

    // ---------
    // Now I create two new vectors/buckets with one containing all Colors
    // above the median value and another with Color values below the median.
    // This entire process is implemented in the median_cut function.
    // ---------

    /// Performs the median cut on a single vector (bucket) of `Color`.
    /// Returns two `color` vectors representing the colors above and colors below median value.
    ///
    /// # Arguments
    ///
    /// * `colors` - `Color` vector on which the median cut is performed.
    ///
    fn median_cut(&mut self) -> (Option<ColorBucket>, Option<ColorBucket>) {
        let highest_range_channel = self.highest_range_channel();
        let median = self.color_median(highest_range_channel);
        let mut above_median = vec![];
        let mut below_median = vec![];
        for color in &self.colors {
            if color[highest_range_channel] > median {
                above_median.push(*color);
            } else {
                below_median.push(*color);
            }
        }

        (ColorBucket::from_pixels(above_median), ColorBucket::from_pixels(below_median))
    }

    // --------
    // In order to perform multiple iterations I call the recurse function.
    // It takes a bucket, performs the median_cut on it calculates the mean
    // color for each output bucket and performs the median_cut on the them.
    // The function stops and returns all mean colors when the amount of iterations reaches 0.
    // --------

    fn recurse(&mut self, iter_count: u8, result: &mut Vec<Color>) {
        if iter_count == 0 {
            result.push(self.color_mean());
        } else {
            let new_buckets = self.median_cut();
            if let Some(mut bucket) = new_buckets.0 {
                bucket.recurse(iter_count - 1, result);
            }
            if let Some(mut bucket) = new_buckets.1 {
                bucket.recurse(iter_count - 1, result);
            }
        }
    }

    fn make_palette(&mut self, iter_count: u8) -> Vec<Color> {
        let mut result = vec![];
        self.recurse(iter_count, &mut result);
        result
    }

    /// Returns the mean color value based on the passed colors.
    ///
    /// # Arguments
    ///
    /// * `colors` - Color vector from which the mean color is calculated.
    ///
    /// # Examples
    ///
    /// ```
    /// let colors = Vec::<Color>::new();
    /// let result = color_mean(&colors);
    /// ```
    ///
    fn color_mean(&self) -> Color {
        let r = mean(self.colors.iter().map(|c| c.r));
        let g = mean(self.colors.iter().map(|c| c.g));
        let b = mean(self.colors.iter().map(|c| c.b));
        let a = mean(self.colors.iter().map(|c| c.a));

        Color { r, g, b, a }
    }
}

// --------
// So basically a call to the entire algorithm looks like this:
// --------

#[test]
fn test_colors() {
    let mut pixels = Vec::new();
    // fill pixels with data
    pixels.push(Color { r: 100, g: 120, b: 120, a: 0 });
    pixels.push(Color { r: 150, g: 150, b: 150, a: 0 });
    pixels.push(Color { r: 255, g: 255, b: 255, a: 0 });

    let mut bucket = ColorBucket::from_pixels(pixels).expect("empty list");

    let colors = bucket.make_palette(3);
    // Do something with the output colors
    // println!("{:?}", colors);
    let expected = vec![
        Color { r: 255, g: 255, b: 255, a: 0 },
        Color { r: 150, g: 150, b: 150, a: 0 },
        Color { r: 100, g: 120, b: 120, a: 0 },
    ];
    assert_eq!(colors, expected);
}

#[test]
fn test_red_blue() {
    let mut pixels = Vec::new();
    // fill pixels with data
    pixels.push(Color { r: 255, g: 0, b: 0, a: 0 });
    pixels.push(Color { r: 0, g: 255, b: 0, a: 0 });

    let mut bucket = ColorBucket::from_pixels(pixels).expect("empty list");

    let colors = bucket.make_palette(1);
    // Do something with the output colors
    // println!("{:?}", colors);
    let expected = vec![
        Color { r: 255, g: 0, b: 0, a: 0 },
        Color { r: 0, g: 255, b: 0, a: 0 },
    ];
    assert_eq!(colors, expected);
}