Skip to main content
2 of 5
added 5 characters in body
Daniel
  • 4.6k
  • 2
  • 18
  • 40

Compute mean and variance, incrementally

I have previously reviewed code that computes standard deviation from the variance formula \$ E(x^2) - E(x)^2 \$, and warned against the use of this formula because floating-point precision is severely compromised by subtracting almost-equal numbers. However, I'm not an expert on numerical methods, so I'd like to know of any weaknesses in my version.

I got slightly carried away with generality, so I've included a version that works with complex numbers, and I've implemented a trailing-N-values rolling mean and variance. I also present the unit-tests I used to write the classes - they are written for Google Test, but should be easy to convert if you prefer a different test runner.

As usual, I'd like feedback on any aspect that could be improved. Although I wrote this only as an exercise for fun and practice, I do want to make all my code the best it can be.

#include <complex>
#include <deque>
#include <stdexcept>
#include <limits>

// Runtime error describing a container underflow.
// Both constructors are explicit; when no description is supplied the
// message reported by what() is "empty container".
struct container_underflow_error : public std::runtime_error
{
    // Build from a C string (also serves as the default constructor).
    explicit container_underflow_error(const char* desc = "empty container")
        : std::runtime_error{desc} {}

    // Build from a std::string description.
    explicit container_underflow_error(const std::string& desc)
        : std::runtime_error{desc} {}
};

namespace impl {
    // Empty tag type used to select a constructor overload.
    struct raw_tag {};
    // Tag value; it shares the type's name, so the type itself must be
    // spelled `struct impl::raw_tag` wherever it is needed.
    static constexpr struct raw_tag raw_tag{};
}

// Primary template: only specializations are usable, so the sole
// constructor declared here is deleted.
template<typename Unsupported>
class SimpleStatsBag
{
    SimpleStatsBag() = delete;
};

// Running count, mean and variance of values of a real type T (any type
// whose numeric_limits reports a quiet NaN).
//
// The bag does not store its elements.  It keeps only the element count,
// the current mean, and `count` times the population variance, and updates
// them with Welford's incremental recurrence so that the variance is never
// obtained by subtracting two nearly-equal sums.
//
// Values and whole bags can be removed again; the caller is responsible
// for only removing values that were actually added.
template<typename T>
    requires std::numeric_limits<T>::has_quiet_NaN
class SimpleStatsBag<T>
{
    // Returned by the accessors when a statistic is undefined.
    static constexpr auto nan = std::numeric_limits<T>::quiet_NaN();

public:
    using value_type = T;
    using variance_type = T;

private:
    std::size_t count = 0;
    value_type current_mean = 0;
    variance_type current_nvar = 0;     // count times the current variance

public:
    // An empty bag.
    SimpleStatsBag() noexcept = default;

    // A bag holding every value in `items`.
    SimpleStatsBag(std::initializer_list<T> items) noexcept
        : SimpleStatsBag(items.begin(), items.end())
    {}

    // A bag holding every value in [first, last).
    template<typename It>       // InputIterator It
        requires requires(It i) { *++i; }
    SimpleStatsBag(It first, It last) noexcept
    {
        for (; first != last; ++first)
            *this += *first;
    }

    // tagged constructor (for internal use only): adopts the raw state
    // without any validation.  `nvar` is size times the population variance.
    SimpleStatsBag(struct impl::raw_tag,
                   std::size_t size, value_type mean, variance_type nvar)
        : count(size), current_mean(mean), current_nvar(nvar)
    {}


    // Accessors for the statistical properties

    // Number of values currently in the bag.
    std::size_t size() const noexcept { return count; }

    // Arithmetic mean, or NaN for an empty bag.
    value_type mean() const noexcept { return count ? current_mean : nan; }

    // Variance with divisor N, or NaN for an empty bag.
    variance_type population_variance() const noexcept
    {
        return count ? current_nvar / count : nan;
    }

    // Variance with divisor N-1, or NaN when fewer than two values are held.
    variance_type sample_variance() const noexcept
    {
        // Divide the stored sum directly; going through
        // population_variance() would add two needless roundings.
        return count > 1 ? current_nvar / (count - 1) : nan;
    }


    // Mutators

    // add and remove values

    // Copy of this bag with `value` added.
    SimpleStatsBag operator+(value_type value) const noexcept
    {
        return SimpleStatsBag(*this) += value;
    }

    // Add `value` to this bag.
    SimpleStatsBag& operator+=(value_type value) noexcept
    {
        ++count;
        auto const delta = value - current_mean;    // distance from old mean
        current_mean += delta / static_cast<T>(count);
        current_nvar += delta * (value - current_mean);
        return *this;
    }

    // Copy of this bag with `value` removed.
    // Throws container_underflow_error if the bag is empty (which is why
    // this function cannot be noexcept).
    SimpleStatsBag operator-(value_type value) const
    {
        return SimpleStatsBag(*this) -= value;
    }

    // Remove `value` (which must have been added earlier) from this bag.
    // Throws container_underflow_error if the bag is empty.
    SimpleStatsBag& operator-=(value_type value)
    {
        if (!count)
            throw container_underflow_error();
        if (--count == 0) {
            // Last value gone: restore the pristine empty state instead of
            // dividing by zero, which would leave NaN/inf in the stored
            // mean and poison every value added afterwards.
            current_mean = 0;
            current_nvar = 0;
            return *this;
        }
        auto const old_mean = current_mean;
        current_mean -= (value - old_mean) / static_cast<T>(count);
        current_nvar -= (value - current_mean) * (value - old_mean);
        // Rounding can push the sum marginally below zero; a variance is
        // never negative.
        if (current_nvar < 0)
            current_nvar = 0;
        return *this;
    }

    // add/subtract bags

    // Bag equivalent to the union of this bag and `other`.
    SimpleStatsBag operator+(const SimpleStatsBag& other) const noexcept
    {
        // An empty operand contributes nothing; returning the other one
        // unchanged also avoids 0/0 when both are empty.
        if (other.count == 0)
            return *this;
        if (count == 0)
            return other;

        auto const new_count = count + other.count;
        auto const new_mean = (current_mean * count + other.current_mean * other.count) / new_count;
        auto const new_nvar = current_nvar + other.current_nvar
            + count * (current_mean - new_mean) * (current_mean - new_mean)
            + other.count * (other.current_mean - new_mean) * (other.current_mean - new_mean);

        return SimpleStatsBag(impl::raw_tag, new_count, new_mean, new_nvar);
    }

    // Merge `other` into this bag.
    SimpleStatsBag& operator+=(const SimpleStatsBag& other) noexcept
    {
        return *this = *this + other;
    }

    // Bag that remains after taking the values of `other` (which must be a
    // sub-collection of this bag) out of this bag.
    // Throws container_underflow_error if `other` holds more values than
    // this bag does.
    SimpleStatsBag operator-(const SimpleStatsBag& other) const
    {
        if (other.count > count)
            throw container_underflow_error("cannot remove more values than the bag holds");
        if (other.count == 0)
            return *this;
        if (other.count == count)
            return SimpleStatsBag();    // everything removed; avoids 0/0

        auto const new_count = count - other.count;
        auto const new_mean = (current_mean * count - other.current_mean * other.count) / new_count;
        // Sum of squares about new_mean over this bag, minus the same sum
        // over `other`, leaves the sum over the remaining values.
        auto new_nvar = current_nvar - other.current_nvar
            + count * (current_mean - new_mean) * (current_mean - new_mean)
            - other.count * (other.current_mean - new_mean) * (other.current_mean - new_mean);
        if (new_nvar < 0)
            new_nvar = 0;               // guard against rounding below zero

        return SimpleStatsBag(impl::raw_tag, new_count, new_mean, new_nvar);
    }

    // Remove the values of `other` from this bag.
    // Throws container_underflow_error if `other` holds more values than
    // this bag does; the bag is left unchanged in that case.
    SimpleStatsBag& operator-=(const SimpleStatsBag& other)
    {
        return *this = *this - other;
    }
};

// specialize for complex numbers

// Running statistics for complex values.
//
// The real and imaginary components are tracked in two independent real
// bags that always hold the same number of values.  The mean is the
// complex number formed from the two component means, and the variance is
// the (real) sum of the two component variances.
template<typename T>
    requires std::numeric_limits<T>::has_quiet_NaN
class SimpleStatsBag<std::complex<T>>
{
    SimpleStatsBag<T> real = {};
    SimpleStatsBag<T> imag = {};

public:
    using value_type = std::complex<T>;
    using variance_type = T;

public:
    // An empty bag.
    SimpleStatsBag() noexcept = default;

    // A bag holding every value in [first, last).
    template<typename It>       // InputIterator It
        requires requires(It i) { *++i; }
    SimpleStatsBag(It first, It last) noexcept
    {
        for (; first != last; ++first)
            *this += *first;
    }

    // A bag holding every value in `items`.
    SimpleStatsBag(const std::initializer_list<value_type> items) noexcept
        : SimpleStatsBag(items.begin(), items.end())
    {}


    // Accessors for the statistical properties

    // Number of values currently in the bag.
    std::size_t size() const noexcept { return real.size(); }

    // Complex mean; both parts are NaN for an empty bag.
    value_type mean() const noexcept { return {real.mean(), imag.mean()}; }

    // Sum of the component population variances (NaN for an empty bag).
    variance_type population_variance() const noexcept {
        return real.population_variance() + imag.population_variance();
    }

    // Sum of the component sample variances (NaN below two values).
    variance_type sample_variance() const noexcept
    {
        return real.sample_variance() + imag.sample_variance();
    }


    // add and remove values

    // Copy of this bag with `value` added.
    SimpleStatsBag operator+(value_type value) const noexcept
    {
        return SimpleStatsBag(*this) += value;
    }

    // Add `value` to this bag.
    SimpleStatsBag& operator+=(value_type value) noexcept
    {
        real += value.real();
        imag += value.imag();
        return *this;
    }

    // Copy of this bag with `value` removed.
    // Not noexcept: removal from an empty bag throws
    // container_underflow_error, which must reach the caller rather than
    // terminate the program.
    SimpleStatsBag operator-(value_type value) const
    {
        return SimpleStatsBag(*this) -= value;
    }

    // Remove `value` from this bag.
    // Throws container_underflow_error if the bag is empty.  The two
    // component bags always have equal sizes, so the exception is raised
    // by the first removal, before either component has been modified.
    SimpleStatsBag& operator-=(value_type value)
    {
        real -= value.real();
        imag -= value.imag();
        return *this;
    }

    // add and subtract bags

    // Bag equivalent to the union of this bag and `other`.
    SimpleStatsBag operator+(const SimpleStatsBag& other) const noexcept
    {
        return SimpleStatsBag(*this) += other;
    }

    // Merge `other` into this bag.
    SimpleStatsBag& operator+=(const SimpleStatsBag& other) noexcept
    {
        real += other.real;
        imag += other.imag;
        return *this;
    }

    // Bag that remains after removing the values of `other`.
    SimpleStatsBag operator-(const SimpleStatsBag& other) const
    {
        return SimpleStatsBag(*this) -= other;
    }

    // Remove the values of `other` from this bag.  Deliberately not
    // noexcept, matching operator-(const SimpleStatsBag&) above, so that
    // any error reported by the component bags propagates to the caller.
    SimpleStatsBag& operator-=(const SimpleStatsBag& other)
    {
        real -= other.real;
        imag -= other.imag;
        return *this;
    }
};


// <complex> doesn't provide these specializations of common_type.
// Technically, specializing these is undefined behaviour, but it is the
// least-pain way to mix and match complex and scalar values.
namespace std {
    // complex on the left, anything else on the right
    template<typename Real, typename Other>
    struct common_type<complex<Real>, Other> {
        using type = complex<common_type_t<Real, Other>>;
    };
    // anything else on the left, complex on the right
    template<typename Other, typename Real>
    struct common_type<Other, complex<Real>> {
        using type = complex<common_type_t<Other, Real>>;
    };
    // complex on both sides (more specialized than the two above)
    template<typename Lhs, typename Rhs>
    struct common_type<complex<Lhs>, complex<Rhs>> {
        using type = complex<common_type_t<Lhs, Rhs>>;
    };
}

// deduction guide - promote to at least double
// The element type is the common type of all arguments and double.
template<typename... Args> SimpleStatsBag(Args...)
    -> SimpleStatsBag<std::common_type_t<Args..., double>>;


// Rolling statistics
// Statistics over the most recent `capacity` values added.
//
// The values themselves are kept in a deque so that, once the window is
// full, each new value pushes the oldest one back out of the statistics.
template<typename T = double>
class RollingStatsBag : SimpleStatsBag<T>
{
    std::size_t capacity;           // maximum number of values retained
    std::deque<T> recent = {};      // retained values, oldest at the front

public:
    // A rolling bag that retains at most `capacity` values.
    RollingStatsBag(std::size_t capacity)
        : capacity{capacity}
    {}

    using typename SimpleStatsBag<T>::value_type;

    // The base class is private; re-export only the read-only accessors.
    using SimpleStatsBag<T>::size;
    using SimpleStatsBag<T>::mean;
    using SimpleStatsBag<T>::population_variance;
    using SimpleStatsBag<T>::sample_variance;

    // add value

    // Copy of this bag with `value` added.
    RollingStatsBag operator+(value_type value) const noexcept
    {
        RollingStatsBag result(*this);
        result += value;
        return result;
    }

    // Add `value`, dropping the oldest retained value if the window
    // would otherwise exceed its capacity.
    RollingStatsBag& operator+=(value_type value) noexcept
    {
        SimpleStatsBag<T>& stats = *this;

        recent.push_back(value);
        stats += value;
        if (stats.size() <= capacity)
            return *this;

        // Over capacity: take the oldest value back out.
        stats -= recent.front();
        recent.pop_front();
        return *this;
    }
};
// Test suite

#include <gtest/gtest.h>
#include <cmath>                // std::isnan

TEST(SimpleStatsBag, empty)
{
    // With no arguments the element type is deduced as double.
    SimpleStatsBag bag;
    static_assert(std::is_same_v<decltype(bag.mean()), double>);

    // No values: every statistic is undefined.
    EXPECT_EQ(0, bag.size());
    EXPECT_TRUE(std::isnan(bag.mean()));
    EXPECT_TRUE(std::isnan(bag.population_variance()));
    EXPECT_TRUE(std::isnan(bag.sample_variance()));
}

TEST(SimpleStatsBag, one_element)
{
    SimpleStatsBag bag{100};

    EXPECT_EQ(1, bag.size());
    EXPECT_EQ(100, bag.mean());
    // A single value has zero spread, but no sample variance.
    EXPECT_EQ(0, bag.population_variance());
    EXPECT_TRUE(std::isnan(bag.sample_variance()));
}

TEST(SimpleStatsBag, single_precision)
{
    // A float argument is promoted: the bag works in double.
    SimpleStatsBag bag{100.f};
    static_assert(std::is_same_v<decltype(bag.mean()), double>);

    EXPECT_EQ(100, bag.mean());
}

TEST(SimpleStatsBag, long_double)
{
    // A long double argument keeps its full precision.
    SimpleStatsBag bag{100.L};
    static_assert(std::is_same_v<decltype(bag.mean()), long double>);

    EXPECT_EQ(100, bag.mean());
}

TEST(SimpleStatsBag, complex)
{
    // complex<float> is promoted to complex<double>.
    SimpleStatsBag bag{std::complex{100.f, -100.f}};
    static_assert(std::is_same_v<decltype(bag.mean()), std::complex<double>>);

    const auto centre = bag.mean();
    EXPECT_DOUBLE_EQ(100, centre.real());
    EXPECT_DOUBLE_EQ(-100, centre.imag());
    EXPECT_DOUBLE_EQ(0, bag.population_variance());
    EXPECT_TRUE(std::isnan(bag.sample_variance()));
}



TEST(SimpleStatsBag, two_double)
{
    SimpleStatsBag bag{0, 200};

    EXPECT_EQ(2, bag.size());
    EXPECT_DOUBLE_EQ(100, bag.mean());
    // Deviations are +/-100: sum of squares 20000, divided by 2 or by 1.
    EXPECT_DOUBLE_EQ(10000, bag.population_variance());
    EXPECT_DOUBLE_EQ(20000, bag.sample_variance());
}

TEST(SimpleStatsBag, two_complex)
{
    // Two values that differ only in their imaginary part.
    SimpleStatsBag<std::complex<double>> bag{ {100, -100}, {100, 100} };

    const auto centre = bag.mean();
    EXPECT_DOUBLE_EQ(100, centre.real());
    EXPECT_DOUBLE_EQ(0, centre.imag());
    EXPECT_DOUBLE_EQ(10000, bag.population_variance());
    EXPECT_DOUBLE_EQ(20000, bag.sample_variance());
}

TEST(SimpleStatsBag, mixed_complex)
{
    // Mixing complex<float> with complex<long double> yields the wider type.
    SimpleStatsBag bag{std::complex{100.f, -100.f}, std::complex{100.l, -100.l}};
    static_assert(std::is_same_v<decltype(bag.mean()), std::complex<long double>>);

    const auto centre = bag.mean();
    EXPECT_DOUBLE_EQ(100.l, centre.real());
    EXPECT_DOUBLE_EQ(-100.l, centre.imag());
    // Both values are equal, so there is no spread at all.
    EXPECT_DOUBLE_EQ(0, bag.population_variance());
    EXPECT_DOUBLE_EQ(0, bag.sample_variance());
}


// Removing a value must leave the statistics of the remaining values.
TEST(SimpleStatsBag, remove)
{
    SimpleStatsBag b{0, 200, 4000};
    b -= 4000;
    EXPECT_EQ(2, b.size());
    // The results come out of floating-point arithmetic, so compare with
    // a tolerance rather than demanding bit-exact equality.
    EXPECT_DOUBLE_EQ(100, b.mean());
    EXPECT_DOUBLE_EQ(10000, b.population_variance());
}

TEST(SimpleStatsBag, remove_all)
{
    SimpleStatsBag bag{100};
    bag -= 100;

    // Nothing is left, so the statistics are undefined again.
    EXPECT_TRUE(std::isnan(bag.mean()));
    EXPECT_TRUE(std::isnan(bag.population_variance()));
}

TEST(SimpleStatsBag, remove_more)
{
    // Removing from an empty bag must be reported as an error.
    SimpleStatsBag bag{};
    ASSERT_THROW(bag -= 100, std::runtime_error);
}

TEST(SimpleStatsBag, add_bags)
{
    SimpleStatsBag left{100, 1000};
    SimpleStatsBag right{200, 300};
    auto merged = left + right;

    // Merging must agree with a bag built from all four values directly.
    SimpleStatsBag expected{100, 200, 300, 1000};
    EXPECT_EQ(expected.size(), merged.size());
    EXPECT_DOUBLE_EQ(expected.mean(), merged.mean());
    EXPECT_DOUBLE_EQ(expected.population_variance(), merged.population_variance());
}

TEST(SimpleStatsBag, subtract_bags)
{
    using Bag = SimpleStatsBag<std::complex<float>>;

    Bag whole{100, 200, 300, 1000};
    Bag part{200, 300};
    auto rest = whole - part;

    // The remainder must agree with a bag built from the leftover values.
    Bag expected{100, 1000};
    EXPECT_EQ(expected.size(), rest.size());
    EXPECT_FLOAT_EQ(expected.mean().real(), rest.mean().real());
    EXPECT_FLOAT_EQ(expected.mean().imag(), rest.mean().imag());
    EXPECT_FLOAT_EQ(expected.population_variance(), rest.population_variance());
}


TEST(RollingStatsBag, real)
{
    RollingStatsBag window{3};

    // Fill the window exactly.
    for (auto v : {10, 20, 30})
        window += v;
    EXPECT_EQ(3, window.size());
    EXPECT_DOUBLE_EQ(20, window.mean());

    // One more value displaces the oldest (10), leaving 20, 30, 40.
    window += 40;
    EXPECT_EQ(3, window.size());
    EXPECT_DOUBLE_EQ(30, window.mean());
}


TEST(RollingStatsBag, complex)
{
    RollingStatsBag<std::complex<double>> window{2};

    // Window holds 0 and -100i: imaginary parts deviate by +/-50.
    window += 0;
    window += {0, -100};
    EXPECT_DOUBLE_EQ(2500, window.population_variance());

    // The initial 0 drops out, leaving two identical values.  Compare
    // 1 + variance against 1 so that a tiny rounding residue is tolerated.
    window += {0, -100};
    EXPECT_FLOAT_EQ(1, 1+window.population_variance());
}
Toby Speight
  • 88.4k
  • 14
  • 104
  • 327