
Another look at finding the median value. This time the input is a histogram, represented either as an ordered range of (value, count) pairs or simply as a range of counts, with each value inferred from its position. I've taken care to avoid integer overflow in the computation, which unfortunately complicates the code more than I would like.

There's a simpler implementation for bidirectional ranges, a more involved one for forward-only ranges, and a more general version for when the total count is known (this is useful for image histograms, where the total is width × height).
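
Knowing the total in advance allows a single pass that stops as soon as the running count passes the middle. The sketch below is illustrative only: histogram_of and median_with_total are hypothetical helpers, pixels are assumed to be 8-bit grey levels, and the caller is assumed to pass a total greater than zero that equals the sum of the bins (here width × height).

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Build a 256-bin histogram of 8-bit pixel values.
template<std::size_t N>
constexpr std::array<std::uint32_t, 256> histogram_of(std::array<std::uint8_t, N> const& pixels)
{
    std::array<std::uint32_t, 256> bins{};
    for (std::uint8_t p : pixels) { ++bins[p]; }
    return bins;
}

// Single pass using a known total; returns as soon as it is past the middle.
// Preconditions (assumed, not checked): total > 0 and total == sum of bins.
constexpr double median_with_total(std::array<std::uint32_t, 256> const& bins,
                                   std::uint64_t total)
{
    std::uint64_t const lo_rank = (total - 1) / 2;
    std::uint64_t const hi_rank = total / 2;
    std::uint64_t cumulative = 0;
    std::size_t lo_value = 0;
    bool found_lo = false;
    for (std::size_t v = 0; v < bins.size(); ++v) {
        cumulative += bins[v];
        if (!found_lo && cumulative > lo_rank) {
            lo_value = v;
            found_lo = true;
        }
        if (cumulative > hi_rank) {
            return (static_cast<double>(lo_value) + static_cast<double>(v)) / 2;
        }
    }
    return -1.0; // only reached if total exceeds the sum of the bins
}

// A 3 x 2 image: the total is simply width * height.
constexpr std::size_t width = 3, height = 2;
constexpr std::array<std::uint8_t, width * height> pixels{10, 10, 20, 30, 30, 30};
static_assert(median_with_total(histogram_of(pixels), width * height) == 25.0);
```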

I have provided an adapter to ensure counts are in the range [0, +∞). The user must apply it explicitly, rather than it being built into the functions, so that it imposes no penalty on code that doesn't use it. Other features I might add in the future, but which are not provided here, are:

  • Find other percentiles, not just median.
  • Support "wide" bins as well as single values.
  • Validation adapter to check that the values are ordered.
  • Validation of known-total computation to ensure that the total is correct.

#include <algorithm>
#include <cmath>
#include <concepts>
#include <iterator>
#include <limits>
#include <numeric>
#include <ranges>
#include <stdexcept>
#include <tuple>
#include <utility>

namespace median
{
    // Median result types - allow indication of integer-and-a-half

    template<typename T>
    struct result {};

    template<std::floating_point T>
    struct result<T> {
        T value;
        result(T mid) : value{mid} {}
        result(T left, T right) : value{std::midpoint(left, right)} {}
        operator T() const { return value; }
        auto as_double() const { return value; }
    };

    template<std::integral T>
    struct result<T> {
        T whole_part;
        bool plus_half;
        result(T mid)
            : whole_part{mid},
              plus_half{false}
        {}
        result(T left, T right)
            : whole_part{std::midpoint(left, right)},
              plus_half{(left % 2 == 0) != (right % 2 == 0)} // parity test; % yields -1 for negative odd values
        {
            if (plus_half && right < left) { --whole_part; }
        }
        auto as_double() const { return static_cast<double>(whole_part) + 0.5 * plus_half; }
        explicit operator T() const { return whole_part; }
        explicit operator double() const { return as_double(); }
    };


    // Range adapters for verifying range values

    auto const checked_histogram = std::views::transform([](auto&& val) {
        if constexpr (requires{ std::get<1>(val); }) {
            auto const& [i,count] = val;
            if (count < 0 || !std::isfinite(count)) {
                throw std::domain_error("invalid histogram entry");
            }
        } else {
            if (val < 0 || !std::isfinite(val)) {
                throw std::domain_error("invalid histogram entry");
            }
        }
        return val;
    });

    // A histogram is an ordered range of {value, count} pairs.
    // Alternatively, a range of counts can be provided and the
    // values 0, 1, 2, ... will be inferred.

    // Median of a bidirectional histogram

    template<std::ranges::bidirectional_range R>
    auto from_histogram(R const& input)
        -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
        requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
    {
        auto left = input.cbegin();
        auto right = input.cend();
        if (left == right) {
            throw std::domain_error("empty histogram");
        }

        auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
        auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };
        auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };

        auto left_sum = count(left);
        auto right_sum = count(--right);

        while (left != right) {
            // Reduce sums so that at least one of them is zero
            if (left_sum > right_sum) {
                left_sum -= right_sum;
                right_sum = 0;
            } else {
                right_sum -= left_sum;
                left_sum = 0;
            }
            // advance one of the iterators
            if (left_sum) {
                right_sum += count(--right);
            } else if (right_sum) {
                left_sum += count(++left);
            } else {
                // left and right sums both zero
                auto const it = std::find_if(std::next(left), right, has_positive_count);
                if (it == right) {
                    return {value(left), value(right)};
                }
                left_sum += count(left = it);
            }
        }

        return value(left);
    }


    // Median of a forward-only histogram

    template<std::ranges::forward_range R>
    auto from_histogram(R const& input)
        -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
        requires (not std::ranges::bidirectional_range<R>)
        and requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
    {
        auto constexpr has_positive_count = [](auto const& pair){ return std::get<1>(pair) > 0; };

        auto left = input.cbegin();
        auto right = left;
        auto const last = input.cend();
        if (right == last) {
            throw std::domain_error("empty histogram");
        }

        auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
        auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };

        auto left_sum = count(left) + 0; // addition promotes smaller types to signed/unsigned int
        auto right_sum = count(right) + 0;

        auto constexpr addition_would_overflow = [](auto augend, auto addend)
        {
            // Neither argument is negative.
            return addend > std::numeric_limits<decltype(augend)>::max() - augend;
        };


        auto constexpr reduce = [](auto& left_sum, auto &right_sum) {
            // Reduce both sums in a 1:2 ratio, preserving 2 * left_sum - right_sum
            if (left_sum > 1 && right_sum > 2) {
                auto subtrahend = std::min(left_sum - 1, (right_sum - 1) / 2);
                left_sum -= subtrahend;
                right_sum -= subtrahend * 2;
            }
        };

        using std::next;

        while (next(right) != last) {
            reduce(left_sum, right_sum);
            {
                // advance right
                auto right_addend = count(++right);
                reduce(left_sum, right_addend);
                while (addition_would_overflow(right_sum, right_addend)) {
                    auto left_addend = count(++left);
                    reduce(left_addend, right_addend);
                    left_sum += left_addend;
                    if (left == right) { break; }
                }
                right_sum += right_addend;
            }
            while (!addition_would_overflow(left_sum, left_sum)  &&  left_sum + left_sum < right_sum) {
                // advance left until it reaches right/2
                auto left_addend = count(++left);
                reduce(left_addend, right_sum);
                while (addition_would_overflow(left_sum, left_addend)) {
                    auto right_addend = count(++right);
                    reduce(left_addend, right_addend);
                    right_sum += right_addend;
                    if (next(right) == last) {
                        break;
                    }
                }
                left_sum += left_addend;
            }
        }

        if (left_sum * 2 == right_sum) {
            // tie break
            if (left == right) [[unlikely]] {
                // only happens with {0} as input
                return value(left);
            }
            auto const it = std::ranges::find_if(std::next(left), right, has_positive_count);
            return {value(left), value(it)};
        }

        return value(left);
    }

    template<std::ranges::forward_range R>
    auto from_histogram(R const& input)
        requires std::assignable_from<double&, std::ranges::range_value_t<R>>
    {
        return median::from_histogram(input | std::views::enumerate);
    }


    // Median of any histogram, when total population is known

    template<typename T, std::ranges::forward_range R>
    auto from_histogram(T const& total, R const& input)
        -> result<std::remove_cv_t<std::tuple_element_t<0,std::ranges::range_value_t<R>>>>
        requires requires(std::ranges::range_value_t<R> value) { std::get<1>(value); }
    {
        auto iter = input.cbegin();
        auto const last = input.cend();
        if (iter == last) {
            throw std::domain_error("empty histogram");
        }

        auto constexpr value = [](std::indirectly_readable auto iter){ return std::get<0>(*iter); };
        auto constexpr count = [](std::indirectly_readable auto iter){ return std::get<1>(*iter); };

        if (total == 0) [[unlikely]] {
            // Try to return something sensible
            auto const left_val = value(iter++);
            auto right_val = left_val;
            while (iter != last) {
                right_val = value(iter++);
            }
            return {left_val, right_val};
        }

        T sum = 0;

        do {
            auto addend = count(iter);
            if (addend > total - sum) {
                throw std::domain_error("overpopulated histogram");
            }
            sum += addend;
            if (sum > total / 2) {
                return value(iter);
            }
            if (2 * sum == total) {
                // find midpoint
                auto const left_val = value(iter);
                iter = std::ranges::find_if(++iter, last, [](auto val){ return std::get<1>(val) != 0; });
                if (iter == last) { break; }
                return {left_val, value(iter)};
            }
        } while (++iter != last);

        // ran off the end
        throw std::domain_error("underpopulated histogram");
    }

    template<std::ranges::forward_range R, std::convertible_to<std::ranges::range_value_t<R>> T>
    auto from_histogram(T total, R const& input)
    {
        return median::from_histogram(total, input | std::views::enumerate);
    }

}

I used these tests to write the preceding functions:

#include <gtest/gtest.h>

#include <climits>
#include <forward_list>
#include <map>
#include <vector>

TEST(vector_input, empty)
{
    std::vector<int> empty;
    EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}

TEST(vector_input, one_element)
{
    std::vector<int> hist{0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}

TEST(vector_input, negative_element)
{
    std::vector<int> hist{-1};
    EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}

TEST(vector_input, one_TWO)
{
    std::vector<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 0);
}

TEST(vector_input, all_zero)
{
    std::vector<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(vector_input, all_one)
{
    std::vector<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(vector_input, ones_zero_TWO_one)
{
    std::vector<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 1);
}

TEST(vector_input, ones_ZEROS_ones)
{
    std::vector<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
    EXPECT_EQ(median::from_histogram(hist | std::views::reverse).as_double(), 3);
}

TEST(vector_input, max_MAX_max)
{
    std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}

TEST(vector_input, max_MAX_MAX_max)
{
    std::vector<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}


TEST(map_input, TEN)
{
    std::map<double,unsigned> hist{{10, 1}};
    EXPECT_EQ(median::from_histogram(hist), 10);
}

TEST(map_input, zero_TEN_FIFTEEN_twenty)
{
    std::map<double,unsigned> hist{{0, 4}, {10.5, 1}, {15, 2}, {20, 3}};
    EXPECT_EQ(median::from_histogram(hist), 12.75);
}


TEST(list_input, empty)
{
    std::forward_list<int> empty;
    EXPECT_THROW(median::from_histogram(empty), std::domain_error);
}

TEST(list_input, one_element)
{
    std::forward_list<int> hist{0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 0);
}

TEST(list_input, negative_element)
{
    std::forward_list<int> hist{-1};
    EXPECT_THROW(median::from_histogram(hist | median::checked_histogram), std::domain_error);
}

TEST(list_input, one_TWO)
{
    std::forward_list<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1);
}

TEST(list_input, all_zero)
{
    std::forward_list<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, all_one)
{
    std::forward_list<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, zero_ONES_zero)
{
    std::forward_list<int> hist{0, 1, 1, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 1.5);
}

TEST(list_input, ones_zero_TWO_one)
{
    std::forward_list<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}

TEST(list_input, ones_ZEROS_ones)
{
    std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 3);
}

TEST(list_input, max_MAX_max)
{
    std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2);
}

TEST(list_input, max_MAX_MAX_max)
{
    std::forward_list<unsigned> hist{0, UINT_MAX, UINT_MAX, UINT_MAX, UINT_MAX, 0};
    EXPECT_EQ(median::from_histogram(hist).as_double(), 2.5);
}


TEST(list_input_with_total, empty)
{
    std::forward_list<int> empty;
    EXPECT_THROW(median::from_histogram(0, empty), std::domain_error);
}

TEST(list_input_with_total, one_element)
{
    std::forward_list<int> hist{0};
    EXPECT_EQ(median::from_histogram(0, hist).as_double(), 0);
}

TEST(list_input_with_total, negative_element)
{
    std::forward_list<int> hist{-1};
    EXPECT_THROW(median::from_histogram(1, hist | median::checked_histogram), std::domain_error);
}

TEST(list_input_with_total, one_TWO)
{
    std::forward_list<int> hist{1,2};
    EXPECT_EQ(median::from_histogram(3, hist).as_double(), 1);
}

TEST(list_input_with_total, all_zero)
{
    std::forward_list<int> hist{0, 0, 0, 0};
    EXPECT_EQ(median::from_histogram(0, hist).as_double(), 1.5);
}

TEST(list_input_with_total, all_one)
{
    std::forward_list<int> hist{1, 1, 1, 1};
    EXPECT_EQ(median::from_histogram(4, hist).as_double(), 1.5);
}

TEST(list_input_with_total, ones_zero_TWO_one)
{
    std::forward_list<int> hist{1, 1, 0, 2, 1};
    EXPECT_EQ(median::from_histogram(5, hist).as_double(), 3);
}

TEST(list_input_with_total, ones_ZEROS_ones)
{
    std::forward_list<int> hist{1, 1, 0, 0, 0, 1, 1};
    EXPECT_EQ(median::from_histogram(4, hist).as_double(), 3);
}

TEST(list_input_with_total, MAX_one)
{
    std::forward_list<unsigned int> hist{UINT_MAX-1, 1};
    EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 0);
}

TEST(list_input_with_total, one_MAX)
{
    std::forward_list<unsigned int> hist{1, UINT_MAX-1};
    EXPECT_EQ(median::from_histogram(UINT_MAX, hist).as_double(), 1);
}

TEST(list_input_with_total, max_max)
{
    // Misuse - total is not correct
    std::forward_list<unsigned int> hist{UINT_MAX/2, UINT_MAX};
    EXPECT_THROW(median::from_histogram(UINT_MAX, hist).as_double(), std::domain_error);
}

Answer

It looks very complicated, and it does some weird things:

Weird result types

I would expect a median function to return a plain integer or floating-point value, not a result<T>. For floating point, I see no point in the wrapper at all; you could just return std::midpoint(left, right) unconditionally. For integers, things are indeed a bit more complicated: I understand the rounding issue when one wants an integer result. But why check for right < left? Wouldn't the histograms be sorted? If not, does it even make sense to ask for a median?

It looks like a lot of work just for the case where one wants a double result from a histogram of integral values. Wouldn't it be easier to add a template parameter to from_histogram() that sets the result type, so that you could write this:

std::vector<int> hist{1, 1, 1, 1};
EXPECT_EQ(median::from_histogram<double>(hist), 1.5);

Furthermore, you only allow std::integral and std::floating_point types. What about other ordered types for which a midpoint could in principle be calculated, such as a custom fraction or bignum type?

Consider using a projection function

You have five overloads of from_histogram(): two handle ranges with an implicit value, and three handle tuples of at least two elements, where the first tuple element is taken to be the value and the second the count. What if I have a range of 3-tuples and the count I'm interested in is in the third element? What if my value/count pairs are in a regular struct instead of a std::tuple?

Consider allowing a projection function to be passed to from_histogram() to customize how the value and count are extracted from each element of the input range. You can provide a default that does the equivalent of your current overloads:

template<std::ranges::forward_range R, class Proj = std::identity>
requires (/* Proj valid for R */)
auto from_histogram(R const& input, Proj proj = {}) {
    if constexpr (/* projection returns a single value */) {
        /* enumerate */
        …
    } else if constexpr (/* projection returns a 2-tuple */) {
        /* split return value of projection into value and count */
        …
    } else {
        /* return a compile error */
    }
}

Overflow handling

I question any code that would produce a histogram whose counts sum to more than the count type can hold; such code would likely be buggy itself, since the same data distributed differently (for example, all in one bin) would already overflow a single count. But your algorithm can handle that case, which is good.

I wonder whether left_sum * 2 in the forward-only version could wrap. If so, your algorithm could return an incorrect result.
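
To make the concern concrete: unsigned arithmetic is modular, so doubling a count above half the type's range silently wraps. The sketch below assumes a 32-bit unsigned, and is_exact_half is a hypothetical helper showing one overflow-free way to express the tie test left_sum * 2 == right_sum for non-negative operands.

```cpp
#include <climits>

static_assert(sizeof(unsigned) * CHAR_BIT == 32, "this sketch assumes a 32-bit unsigned");

// 2^31 + 1 doubled is 2^32 + 2, which wraps to 2 in 32-bit unsigned arithmetic,
// so a naive tie test would wrongly report left_sum * 2 == right_sum for right_sum == 2.
constexpr unsigned big_left_sum = UINT_MAX / 2 + 2; // 2^31 + 1
static_assert(big_left_sum * 2 == 2);

// Overflow-free equivalent of left * 2 == right for unsigned operands:
// the subtraction cannot wrap because it only runs when left <= right.
constexpr bool is_exact_half(unsigned left, unsigned right)
{
    return left <= right && right - left == left;
}
static_assert(!is_exact_half(big_left_sum, 2));
static_assert(is_exact_half(3, 6));
```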
