Skip to main content
added 49 characters in body
Source Link
Toby Speight
  • 88.7k
  • 14
  • 104
  • 327
  • Find other percentiles, not just median.
  • Support "wide" bins as well as single values.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
  • Find other percentiles, not just median.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
  • Find other percentiles, not just median.
  • Support "wide" bins as well as single values.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
Source Link
Toby Speight
  • 88.7k
  • 14
  • 104
  • 327

Compute median of a histogram

Another look at finding median value. This time the input is a histogram represented as an ordered range of (value, count) pairs, or simply as a range of counts, with value inferred from position. I've taken care to avoid integer overflow in the computation, which unfortunately complicates the code more than I would like.

There's a simpler implementation for bidirectional ranges, and a more general version for when the total count is known (this is useful for image histograms, where the total is width ✕ height).

I have provided an adapter to ensure counts are in the range [0, +∞). This is to be explicitly applied by user, rather than built into the functions, so that it imposes no penalty on code that doesn't use it. Other features I might add in future, but not provided here, are:

  • Find other percentiles, not just median.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.
#include <cmath>
#include <concepts>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <ranges>
#include <tuple>

namespace median
{
    // Median result types - allow indication of integer-and-a-half

    template<typename T>
    struct result {};

    template<std::floating_point T>
    struct result<T> {
        T value;
        result(T mid) : value{mid} {}
        result(T left, T right) : value{std::midpoint(left, right)} {}
        operator T() const { return value; }
        auto as_double() const { return value; }
    };

    template<std::integral T>
    struct result<T> {
        T whole_part;
        bool plus_half;
        result(T mid)
            : whole_part{mid},
              plus_half{false}
        {}
        result(T left, T right)
            : whole_part{std::midpoint(left, right)},
              plus_half{left % 2 != right % 2}
        {
            if (plus_half && right < left) { --whole_part; }
        }
        auto as_double() const { return static_cast<double>(whole_part) + 0.5 * plus_half; }
        explicit operator T() const { return whole_part; }
        explicit operator double() const { return as_double(); }
    };


    // Range adapters for verifying range values

    auto const checked_histogram = std::views::transform([](auto&& val) {
        if constexpr (requires{ std::get<1>(val); }) {
            auto const& [i,count] = val;
            if (count < 0 || !std::isfinite(count)) {
                throw std::domain_error("invalid histogram entry");
            }
        } else {
            if (val < 0 || !std::isfinite(val)) {
                throw std::domain_error("invalid histogram entry");
            }
        }
        return val;
    });

    // A histogram is an ordered range of {value, count} pairs.
    // Alternatively, a range of counts can be provided and the
    // values 0, 1, 2, ... will be inferred.

    // Median of a bidirectional histogram

    template<std::ranges::bidirectional_range R>
    auto from_histogram(R const& input)
        -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
        requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
    {
        auto left = input.cbegin();
        auto right = input.cend();
        if (left == right) {
            throw std::domain_error("empty histogram");
        }

        auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
        auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };
        auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };

        auto left_sum = count(left);
        auto right_sum = count(--right);

        while (left != right) {
            // Reduce sums so that at least one of them is zero
            if (left_sum > right_sum) {
                left_sum -= right_sum;
                right_sum = 0;
            } else {
                right_sum -= left_sum;
                left_sum = 0;
            }
            // advance one of the iterators
            if (left_sum) {
                right_sum += count(--right);
            } else if (right_sum) {
                left_sum += count(++left);
            } else {
                // left and right sums both zero
                auto const it = std::find_if(std::next(left), right, has_positive_count);
                if (it == right) {
                    return {value(left), value(right)};
                }
                left_sum += count(left = it);
            }
        }

        return value(left);
    }


    // Median of a forward-only histogram

    template<std::ranges::forward_range R>
    auto from_histogram(R const& input)
        -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
        requires (not std::ranges::bidirectional_range<R>)
        and requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
    {
        auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };

        auto left = input.cbegin();
        auto right = left;
        auto const last = input.cend();
        if (right == last) {
            throw std::domain_error("empty histogram");
        }

        auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
        auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };

        auto left_sum = count(left) + 0; // addition promotes smaller types to signed/unsigned int
        auto right_sum = count(right) + 0;

        auto constexpr addition_would_overflow = [](auto augend, auto addend)
        {
            // Neither argument is negative.
            return addend > std::numeric_limits<decltype(augend)>::max() - augend;
        };


        auto constexpr reduce = [](auto& left_sum, auto &right_sum) {
            // // Reduce sums
            if (left_sum > 1 && right_sum > 2) {
                auto subtrahend = std::min(left_sum - 1, (right_sum - 1) / 2);
                left_sum -= subtrahend;
                right_sum -= subtrahend * 2;
            }
        };

        using std::next;

        while (next(right) != last) {
            reduce(left_sum, right_sum);
            {
                // advance right
                auto right_addend = count(++right);
                reduce(left_sum, right_addend);
                while (addition_would_overflow(right_sum, right_addend)) {
                    auto left_addend = count(++left);
                    reduce(left_addend, right_addend);
                    left_sum += left_addend;
                    if (left == right) { break; }
                }
                right_sum += right_addend;
            }
            while (!addition_would_overflow(left_sum, left_sum)  &&  left_sum + left_sum < right_sum) {
                // advance left until it reaches right/2
                auto left_addend = count(++left);
                reduce(left_addend, right_sum);
                while (addition_would_overflow(left_sum, left_addend)) {
                    auto right_addend = count(++right);
                    reduce(left_addend, right_addend);
                    right_sum += right_addend;
                    if (next(right) == last) {
                        break;
                    }
                }
                left_sum += left_addend;
            }
        }

        if (left_sum * 2 == right_sum) {
            // tie break
            if (left == right) [[unlikely]] {
                // only happens with {0} as input
                return value(left);
            }
            auto const it = std::ranges::find_if(std::next(left), right, has_positive_count);
            return {value(left), value(it)};
        }

        return value(left);
    }

    template<std::ranges::forward_range R>
    auto from_histogram(R const& input)
        requires std::assignable_from<double&, std::ranges::range_value_t<R>>
    {
        return median::from_histogram(input | std::views::enumerate);
    }


    // Median of any histogram, when total population is known

    template<typename T, std::ranges::forward_range R>
    auto from_histogram(T const& total, R const& input)
        -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
        requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
    {
        auto iter = input.cbegin();
        auto const last = input.cend();
        if (iter == last) {
            throw std::domain_error("empty histogram");
        }

        auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
        auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };

        if (total == 0) [[unlikely]] {
            // Try to return something sensible
            auto const left_val = value(iter++);
            auto right_val = left_val;
            while (iter != last) {
                right_val = value(iter++);
            }
            return {left_val, right_val};
        }

        T sum = 0;

        do {
            auto addend = count(iter);
            if (addend > total - sum) {
                throw std::domain_error("overpopulated histogram");
            }
            sum += count(iter);
            if (sum > total / 2) {
                return value(iter);
            }
            if (2 * sum == total) {
                // find midpoint
                auto const left_val = value(iter);
                iter = std::ranges::find_if(++iter, last, [](auto val){ return std::get<1>(val) != 0; });
                if (iter == last) { break; }
                return {left_val, value(iter)};
            }
        } while (++iter != last);

        // ran off the end
        throw std::domain_error("underpopulated histogram");
    }

    template<std::ranges::input_range R, std::convertible_to<std::ranges::range_value_t<R>> T>
    auto from_histogram(T total, R const& input)
    {
        return median::from_histogram(total, input | std::views::enumerate);
    }

}

I used these tests to write the preceding functions:

#include <gtest/gtest.h>

#include <climits>
#include <forward_list>
#include <map>
#include <vector>

TEST(vector_input, empty)
{
    std::vector<int> empty;
    EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}

TEST(vector_input, one_element)
{
    std::vector<int> hist{0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}

TEST(vector_input, negative_element)
{
    std::vector<int> hist{-1};
    EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}

TEST(vector_input, one_TWO)
{
    std::vector<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 0);
}

TEST(vector_input, all_zero)
{
    std::vector<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(vector_input, all_one)
{
    std::vector<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(vector_input, ones_zero_TWO_one)
{
    std::vector<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 1);
}

TEST(vector_input, ones_ZEROS_ones)
{
    std::vector<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 3);
}

TEST(vector_input, max_MAX_max)
{
    std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}

TEST(vector_input, max_MAX_MAX_max)
{
    std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}


TEST(map_input, TEN)
{
    std::map<double,unsigned> hist{{10, 1}};
    EXPECT_EQ(median::from_histogram(hist), 10);
}

TEST(map_input, zero_TEN_FIFTEEN_twenty)
{
    std::map<double,unsigned> hist{{0, 4}, {10.5, 1}, {15, 2}, {20, 3}};
    EXPECT_EQ(median::from_histogram(hist), 12.75);
}


TEST(list_input, empty)
{
    std::forward_list<int> empty;
    EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}

TEST(list_input, one_element)
{
    std::forward_list<int> hist{0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}

TEST(list_input, negative_element)
{
    std::forward_list<int> hist{-1};
    EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}

TEST(list_input, one_TWO)
{
    std::forward_list<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
}

TEST(list_input, all_zero)
{
    std::forward_list<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, all_one)
{
    std::forward_list<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, zero_ONES_zero)
{
    std::forward_list<int> hist{0, 1, 1, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, ones_zero_TWO_one)
{
    std::forward_list<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}

TEST(list_input, ones_ZEROS_ones)
{
    std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}

TEST(list_input, max_MAX_max)
{
    std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}

TEST(list_input, max_MAX_MAX_max)
{
    std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}


TEST(list_input_with_total, empty)
{
    std::forward_list<int> empty;
    EXPECT_THROW(median::from_histogram(0, empty), std::domain_error);
}

TEST(list_input_with_total, one_element)
{
    std::forward_list<int> hist{0};
    EXPECT_EQ(median::from_histogram(0, hist).as_double(), 0);
}

TEST(list_input_with_total, negative_element)
{
    std::forward_list<int> hist{-1};
    EXPECT_THROW(median::from_histogram(1, hist | median::checked_histogram), std::domain_error);
}

TEST(list_input_with_total, one_TWO)
{
    std::forward_list<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(3, hist).as_double(), 1);
}

TEST(list_input_with_total, all_zero)
{
    std::forward_list<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(0, hist).as_double(), 1.5);
}

TEST(list_input_with_total, all_one)
{
    std::forward_list<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(4, hist).as_double(), 1.5);
}

TEST(list_input_with_total, ones_zero_TWO_one)
{
    std::forward_list<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(5, hist).as_double(), 3);
}

TEST(list_input_with_total, ones_ZEROS_ones)
{
    std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(4, hist).as_double(), 3);
}

TEST(list_input_with_total, MAX_one)
{
    std::forward_list<unsigned int> hist{UINT_MAX-1, 1};
    EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 0);
}

TEST(list_input_with_total, one_MAX)
{
    std::forward_list<unsigned int> hist{1, UINT_MAX-1};
    EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 1);
}

TEST(list_input_with_total, max_max)
{
    // Misuse - total is not correct
    std::forward_list<unsigned int> hist{UINT_MAX/2, UINT_MAX};
    EXPECT_THROW(median::from_histogram(UINT_MAX, hist).as_double(), std::domain_error);
}
```