diff --git a/include/boost/histogram/accumulators.hpp b/include/boost/histogram/accumulators.hpp index af449d32..2f0e62b4 100644 --- a/include/boost/histogram/accumulators.hpp +++ b/include/boost/histogram/accumulators.hpp @@ -23,5 +23,6 @@ #include #include #include +#include #endif diff --git a/include/boost/histogram/accumulators/fraction.hpp b/include/boost/histogram/accumulators/fraction.hpp index 81fd41c5..bcb967a3 100644 --- a/include/boost/histogram/accumulators/fraction.hpp +++ b/include/boost/histogram/accumulators/fraction.hpp @@ -9,7 +9,9 @@ #include #include // for fraction<> +#include #include +#include #include // for std::common_type namespace boost { @@ -36,7 +38,8 @@ class fraction { using const_reference = const value_type&; using real_type = typename std::conditional::value, value_type, double>::type; - using interval_type = typename utility::wilson_interval::interval_type; + using interval_type = + typename utility::binomial_proportion_interval::interval_type; fraction() noexcept = default; @@ -51,11 +54,14 @@ class fraction { static_cast(e.failures())} {} /// Insert boolean sample x. - void operator()(bool x) noexcept { + void operator()(bool x) noexcept { operator()(weight(1), x); } + + /// Insert boolean sample x with weight w. + void operator()(const weight_type& w, bool x) noexcept { if (x) - ++succ_; + succ_ += w.value; else - ++fail_; + fail_ += w.value; } /// Add another accumulator. @@ -79,18 +85,12 @@ class fraction { /// Return variance of the success fraction. real_type variance() const noexcept { - // We want to compute Var(p) for p = X / n with Var(X) = n p (1 - p) - // For Var(X) see - // https://en.wikipedia.org/wiki/Binomial_distribution#Expected_value_and_variance - // Error propagation: Var(p) = p'(X)^2 Var(X) = p (1 - p) / n - const real_type p = value(); - return p * (1 - p) / count(); + return variance_for_p_and_n_eff(value(), count()); } /// Return standard interval with 68.3 % confidence level (Wilson score interval). interval_type confidence_interval() const noexcept { - return utility::wilson_interval()(static_cast(successes()), - static_cast(failures())); + return confidence_interval(utility::wilson_interval()); } bool operator==(const fraction& rhs) const noexcept { @@ -106,6 +106,24 @@ class fraction { } private: + friend class weighted_fraction; + + // Calculate the variance for a given success fraction and effective number of samples. + template + static real_type variance_for_p_and_n_eff(const real_type& p, const T& n_eff) noexcept { + // We want to compute Var(p) for p = X / n with Var(X) = n p (1 - p) + // For Var(X) see + // https://en.wikipedia.org/wiki/Binomial_distribution#Expected_value_and_variance + // Error propagation: Var(p) = p'(X)^2 Var(X) = p (1 - p) / n + return p * (1 - p) / n_eff; + } + + // Return interval for the given binomial proportion interval computer. + interval_type confidence_interval( + const utility::binomial_proportion_interval& b) const noexcept { + return b(static_cast(successes()), static_cast(failures())); + } + value_type succ_{}; value_type fail_{}; }; diff --git a/include/boost/histogram/accumulators/ostream.hpp b/include/boost/histogram/accumulators/ostream.hpp index 46b914c9..3fbf3e45 100644 --- a/include/boost/histogram/accumulators/ostream.hpp +++ b/include/boost/histogram/accumulators/ostream.hpp @@ -101,6 +101,14 @@ std::basic_ostream& operator<<(std::basic_ostream& return detail::handle_nonzero_width(os, x); } +template +std::basic_ostream& operator<<(std::basic_ostream& os, + const weighted_fraction& x) { + if (os.width() == 0) + return os << "weighted_fraction(" << x.get_fraction() << ", " << x.sum_w2() << ")"; + return detail::handle_nonzero_width(os, x); +} + } // namespace accumulators } // namespace histogram } // namespace boost diff --git a/include/boost/histogram/experimental/weighted_fraction.hpp b/include/boost/histogram/experimental/weighted_fraction.hpp new file mode 100644 index 00000000..d35dc934 --- /dev/null +++ b/include/boost/histogram/experimental/weighted_fraction.hpp @@ -0,0 +1,198 @@ +// Copyright 2022 Jay Gohil, Hans Dembinski +// +// Distributed under the Boost Software License, version 1.0. +// (See accompanying file LICENSE_1_0.txt +// or copy at http://www.boost.org/LICENSE_1_0.txt) + +#ifndef BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_FRACTION_HPP +#define BOOST_HISTOGRAM_ACCUMULATORS_WEIGHTED_FRACTION_HPP + +#include +#include +#include +#include // for weighted_fraction<> +#include +#include // for std::common_type + +namespace boost { +namespace histogram { +namespace accumulators { + +namespace internal { + +// Accumulates the sum of weights squared. +template +class sum_of_weights_squared { +public: + using value_type = ValueType; + using const_reference = const value_type&; + + sum_of_weights_squared() = default; + + // Allow implicit conversion from sum_of_weights_squared + template + sum_of_weights_squared(const sum_of_weights_squared& o) noexcept + : sum_of_weights_squared(o.sum_of_weights_squared_) {} + + // Initialize to external sum of weights squared. + sum_of_weights_squared(const_reference sum_w2) noexcept + : sum_of_weights_squared_(sum_w2) {} + + // Increment by one. + sum_of_weights_squared& operator++() { + ++sum_of_weights_squared_; + return *this; + } + + // Increment by weight. + sum_of_weights_squared& operator+=(const weight_type& w) { + sum_of_weights_squared_ += detail::square(w.value); + return *this; + } + + // Added another sum_of_weights_squared. + sum_of_weights_squared& operator+=(const sum_of_weights_squared& rhs) { + sum_of_weights_squared_ += rhs.sum_of_weights_squared_; + return *this; + } + + bool operator==(const sum_of_weights_squared& rhs) const noexcept { + return sum_of_weights_squared_ == rhs.sum_of_weights_squared_; + } + + bool operator!=(const sum_of_weights_squared& rhs) const noexcept { + return !operator==(rhs); + } + + // Return sum of weights squared. + const_reference value() const noexcept { return sum_of_weights_squared_; } + + template + void serialize(Archive& ar, unsigned /* version */) { + ar& make_nvp("sum_of_weights_squared", sum_of_weights_squared_); + } + +private: + ValueType sum_of_weights_squared_{}; +}; + +} // namespace internal + +/// Accumulates weighted boolean samples and computes the fraction of true samples. +template +class weighted_fraction { +public: + using value_type = ValueType; + using const_reference = const value_type&; + using fraction_type = fraction; + using real_type = typename fraction_type::real_type; + using interval_type = typename fraction_type::interval_type; + + weighted_fraction() noexcept = default; + + /// Initialize to external fraction and sum of weights squared. + weighted_fraction(const fraction_type& f, const_reference sum_w2) noexcept + : f_(f), sum_w2_(sum_w2) {} + + /// Convert the weighted_fraction class to a different type T. + template + operator weighted_fraction() const noexcept { + return weighted_fraction(static_cast>(f_), + static_cast(sum_w2_.value())); + } + + /// Insert boolean sample x with weight 1. + void operator()(bool x) noexcept { operator()(weight(1), x); } + + /// Insert boolean sample x with weight w. + void operator()(const weight_type& w, bool x) noexcept { + f_(w, x); + sum_w2_ += w; + } + + /// Add another weighted_fraction. + weighted_fraction& operator+=(const weighted_fraction& rhs) noexcept { + f_ += rhs.f_; + sum_w2_ += rhs.sum_w2_; + return *this; + } + + bool operator==(const weighted_fraction& rhs) const noexcept { + return f_ == rhs.f_ && sum_w2_ == rhs.sum_w2_; + } + + bool operator!=(const weighted_fraction& rhs) const noexcept { + return !operator==(rhs); + } + + /// Return number of boolean samples that were true. + const_reference successes() const noexcept { return f_.successes(); } + + /// Return number of boolean samples that were false. + const_reference failures() const noexcept { return f_.failures(); } + + /// Return effective number of boolean samples. + real_type count() const noexcept { + return static_cast(detail::square(f_.count())) / sum_w2_.value(); + } + + /// Return success weighted_fraction of boolean samples. + real_type value() const noexcept { return f_.value(); } + + /// Return variance of the success weighted_fraction. + real_type variance() const noexcept { + return fraction_type::variance_for_p_and_n_eff(value(), count()); + } + + /// Return the sum of weights squared. + value_type sum_of_weights_squared() const noexcept { return sum_w2_.value(); } + + /// Return standard interval with 68.3 % confidence level (Wilson score interval). + interval_type confidence_interval() const noexcept { + return confidence_interval(utility::wilson_interval()); + } + + /// Return the Wilson score interval. + interval_type confidence_interval( + const utility::wilson_interval& w) const noexcept { + const real_type n_eff = count(); + const real_type p_hat = value(); + const real_type correction = w.third_order_correction(n_eff); + return w.solve_for_neff_phat_correction(n_eff, p_hat, correction); + } + + /// Return the fraction. + const fraction_type& get_fraction() const noexcept { return f_; } + + /// Return the sum of weights squared. + const value_type& sum_w2() const noexcept { return sum_w2_.value(); } + + template + void serialize(Archive& ar, unsigned /* version */) { + ar& make_nvp("fraction", f_); + ar& make_nvp("sum_of_weights_squared", sum_w2_); + } + +private: + fraction_type f_; + internal::sum_of_weights_squared sum_w2_; +}; + +} // namespace accumulators +} // namespace histogram +} // namespace boost + +#ifndef BOOST_HISTOGRAM_DOXYGEN_INVOKED + +namespace std { +template +/// Specialization for boost::histogram::accumulators::weighted_fraction. +struct common_type, + boost::histogram::accumulators::weighted_fraction> { + using type = boost::histogram::accumulators::weighted_fraction>; +}; +} // namespace std + +#endif + +#endif diff --git a/include/boost/histogram/fwd.hpp b/include/boost/histogram/fwd.hpp index 3c48f784..edc88e55 100644 --- a/include/boost/histogram/fwd.hpp +++ b/include/boost/histogram/fwd.hpp @@ -97,6 +97,16 @@ class count; template class fraction; +namespace internal { + +template +class sum_of_weights_squared; + +} // namespace internal + +template +class weighted_fraction; + template class sum; diff --git a/include/boost/histogram/utility/wilson_interval.hpp b/include/boost/histogram/utility/wilson_interval.hpp index 98fbf123..ae62d9b3 100644 --- a/include/boost/histogram/utility/wilson_interval.hpp +++ b/include/boost/histogram/utility/wilson_interval.hpp @@ -75,7 +75,72 @@ class wilson_interval : public binomial_proportion_interval { return {t1 - t2, t1 + t2}; } + /// Returns the third order correction for n_eff. + static value_type third_order_correction(value_type n_eff) noexcept { + // The approximate formula reads: + // f(n) = (n³ + n² + 2n + 6) / n³ + // + // Applying the substitution x = 1 / n gives: + // f(n) = 1 + x + 2x² + 6x³ + // + // Using Horner's method to evaluate this polynomial gives: + // f(n) = 1 + x (1 + x (2 + 6x)) + if (n_eff == 0) return 1; + const value_type x = 1 / n_eff; + return 1 + x * (1 + x * (2 + 6 * x)); + } + + /** Computer the confidence interval for the provided problem. + + @param p The problem to solve. + */ + interval_type solve_for_neff_phat_correction( + const value_type& n_eff, const value_type& p_hat, + const value_type& correction) const noexcept { + // Equation 41 from this paper: https://arxiv.org/abs/2110.00294 + // (p̂ - p)² = p (1 - p) (z² f(n) / n) + // Multiply by n to avoid floating point error when n = 0. + // n (p̂ - p)² = p (1 - p) z² f(n) + // Expand. + // np² - 2np̂p + np̂² = pz²f(n) - p²z²f(n) + // Collect terms of p. + // p²(n + z²f(n)) + p(-2np̂ - z²f(n)) + (np̂²) = 0 + // + // This is a quadratic equation ap² + bp + c = 0 where + // a = n + z²f(n) + // b = -2np̂ - z²f(n) + // c = np̂² + + const value_type zz_correction = (z_ * z_) * correction; + + const value_type a = n_eff + zz_correction; + const value_type b = -2 * n_eff * p_hat - zz_correction; + const value_type c = n_eff * (p_hat * p_hat); + + return quadratic_roots(a, b, c); + } + private: + // Finds the roots of the quadratic equation ax² + bx + c = 0. + static interval_type quadratic_roots(const value_type& a, const value_type& b, + const value_type& c) noexcept { + // https://people.csail.mit.edu/bkph/articles/Quadratics.pdf + + const value_type two_a = 2 * a; + const value_type two_c = 2 * c; + const value_type sqrt_bb_4ac = std::sqrt(b * b - two_a * two_c); + + if (b >= 0) { + const value_type root1 = (-b - sqrt_bb_4ac) / two_a; + const value_type root2 = two_c / (-b - sqrt_bb_4ac); + return {root1, root2}; + } else { + const value_type root1 = two_c / (-b + sqrt_bb_4ac); + const value_type root2 = (-b + sqrt_bb_4ac) / two_a; + return {root1, root2}; + } + } + value_type z_; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 337da698..103ccdfa 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -41,6 +41,8 @@ boost_test(TYPE compile-fail SOURCES histogram_fail4.cpp) set(BOOST_TEST_LINK_LIBRARIES Boost::histogram Boost::core) boost_test(TYPE run SOURCES accumulators_count_test.cpp) +boost_test(TYPE run SOURCES accumulators_fraction_test.cpp) +boost_test(TYPE run SOURCES accumulators_weighted_fraction_test.cpp) boost_test(TYPE run SOURCES accumulators_mean_test.cpp) boost_test(TYPE run SOURCES accumulators_sum_test.cpp) boost_test(TYPE run SOURCES accumulators_weighted_mean_test.cpp) @@ -96,7 +98,6 @@ boost_test(TYPE run SOURCES unlimited_storage_test.cpp) boost_test(TYPE run SOURCES tools_test.cpp) boost_test(TYPE run SOURCES issue_327_test.cpp) boost_test(TYPE run SOURCES issue_353_test.cpp) -boost_test(TYPE run SOURCES accumulators_fraction_test.cpp) boost_test(TYPE run SOURCES utility_binomial_proportion_interval_test.cpp) boost_test(TYPE run SOURCES utility_wald_interval_test.cpp) boost_test(TYPE run SOURCES utility_wilson_interval_test.cpp) diff --git a/test/Jamfile b/test/Jamfile index d38b5398..2f6e6551 100644 --- a/test/Jamfile +++ b/test/Jamfile @@ -41,6 +41,7 @@ alias odr : alias cxx14 : [ run accumulators_count_test.cpp ] [ run accumulators_fraction_test.cpp ] + [ run accumulators_weighted_fraction_test.cpp ] [ run accumulators_mean_test.cpp ] [ run accumulators_sum_test.cpp : : : # make sure sum accumulator works even with -ffast-math and optimizations diff --git a/test/accumulators_serialization_test.cpp b/test/accumulators_serialization_test.cpp index b09132a8..c42797e0 100644 --- a/test/accumulators_serialization_test.cpp +++ b/test/accumulators_serialization_test.cpp @@ -102,5 +102,34 @@ int main(int argc, char** argv) { BOOST_TEST(a == b); } + // sum_of_weights_squared + { + const auto filename = + join(argv[1], "accumulators_serialization_test_sum_of_weights_squared.xml"); + accumulators::internal::sum_of_weights_squared<> a; + ++a; + print_xml(filename, a); + + accumulators::internal::sum_of_weights_squared<> b; + BOOST_TEST_NOT(a == b); + load_xml(filename, b); + BOOST_TEST(a == b); + } + + // weighted_fraction + { + const auto filename = + join(argv[1], "accumulators_serialization_test_weighted_fraction.xml"); + accumulators::weighted_fraction<> a; + a(true); + a(weight(6), false); + print_xml(filename, a); + + accumulators::weighted_fraction<> b; + BOOST_TEST_NOT(a == b); + load_xml(filename, b); + BOOST_TEST(a == b); + } + return boost::report_errors(); } diff --git a/test/accumulators_serialization_test_sum_of_weights_squared.xml b/test/accumulators_serialization_test_sum_of_weights_squared.xml new file mode 100644 index 00000000..dd0c854b --- /dev/null +++ b/test/accumulators_serialization_test_sum_of_weights_squared.xml @@ -0,0 +1,17 @@ + + + + + + + 1.00000000000000000e+00 + + + + diff --git a/test/accumulators_serialization_test_weighted_fraction.xml b/test/accumulators_serialization_test_weighted_fraction.xml new file mode 100644 index 00000000..4a5ff2f4 --- /dev/null +++ b/test/accumulators_serialization_test_weighted_fraction.xml @@ -0,0 +1,24 @@ + + + + + + + + 1.00000000000000000e+00 + 6.00000000000000000e+00 + + + 3.70000000000000000e+01 + + + + + + diff --git a/test/accumulators_weighted_fraction_test.cpp b/test/accumulators_weighted_fraction_test.cpp new file mode 100644 index 00000000..28d1606d --- /dev/null +++ b/test/accumulators_weighted_fraction_test.cpp @@ -0,0 +1,151 @@ +// Copyright 2015-2018 Hans Dembinski +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt +// or copy at http://www.boost.org/LICENSE_1_0.txt) + +#include +#include +#include +#include +#include +#include +#include "is_close.hpp" +#include "str.hpp" +#include "throw_exception.hpp" + +using namespace boost::histogram; +using namespace std::literals; + +template +void run_tests() { + using type_fw_t = accumulators::weighted_fraction; + using type_f_t = accumulators::fraction; + + const double eps = std::numeric_limits::epsilon(); + + { + type_fw_t f; + BOOST_TEST_EQ(f.successes(), 0); + BOOST_TEST_EQ(f.failures(), 0); + BOOST_TEST_EQ(f.sum_w2(), 0); + BOOST_TEST(std::isnan(f.value())); + BOOST_TEST(std::isnan(f.variance())); + + const auto ci = f.confidence_interval(); + BOOST_TEST(std::isnan(ci.first)); + BOOST_TEST(std::isnan(ci.second)); + } + + { + type_fw_t a(type_f_t(1, 0), 1); + type_fw_t b(type_f_t(0, 1), 1); + a += b; + BOOST_TEST_EQ(a, type_fw_t(type_f_t(1, 1), 2)); + + a(weight(2), true); // adds 2 trues and 2^2 to sum_of_weights_squared + a(weight(3), false); // adds 3 falses and 3^2 to sum_of_weights_squared + BOOST_TEST_EQ(a, type_fw_t(type_f_t(3, 4), 15)); + } + + { + type_fw_t f; + BOOST_TEST_EQ(f.successes(), 0); + BOOST_TEST_EQ(f.failures(), 0); + BOOST_TEST_EQ(f.sum_w2(), 0); + + f(true); + BOOST_TEST_EQ(f.successes(), 1); + BOOST_TEST_EQ(f.failures(), 0); + BOOST_TEST_EQ(f.sum_w2(), 1); + BOOST_TEST_EQ(str(f), "weighted_fraction(fraction(1, 0), 1)"s); + f(false); + BOOST_TEST_EQ(f.successes(), 1); + BOOST_TEST_EQ(f.failures(), 1); + BOOST_TEST_EQ(f.sum_w2(), 2); + BOOST_TEST_EQ(str(f), "weighted_fraction(fraction(1, 1), 2)"s); + BOOST_TEST_EQ(str(f, 41, false), " weighted_fraction(fraction(1, 1), 2)"s); + BOOST_TEST_EQ(str(f, 41, true), "weighted_fraction(fraction(1, 1), 2) "s); + } + + { + type_fw_t f(type_f_t(3, 1), 4); + BOOST_TEST_EQ(f.successes(), 3); + BOOST_TEST_EQ(f.failures(), 1); + BOOST_TEST_EQ(f.value(), 0.75); + BOOST_TEST_IS_CLOSE(f.variance(), 0.75 * (1 - 0.75) / 4, eps); + const auto ci = f.confidence_interval(); + + BOOST_TEST_EQ(f.count(), 4); + + // const auto expected = utility::wilson_interval()(3, 1); + const auto wilson = utility::wilson_interval(); + const auto expected = wilson.solve_for_neff_phat_correction( + 4, // n_eff = 4 + 0.75, // p_hat = 0.75 + 1.46875 // f(n) = (n³ + n² + 2n + 6) / n³ evaluated at n=4 + ); + + BOOST_TEST_IS_CLOSE(ci.first, expected.first, eps); + BOOST_TEST_IS_CLOSE(ci.second, expected.second, eps); + } + + { + type_fw_t f(type_f_t(0, 1), 1); + BOOST_TEST_EQ(f.successes(), 0); + BOOST_TEST_EQ(f.failures(), 1); + BOOST_TEST_EQ(f.value(), 0); + BOOST_TEST_EQ(f.variance(), 0); + const auto ci = f.confidence_interval(); + + const auto wilson = utility::wilson_interval(); + const auto expected = wilson.solve_for_neff_phat_correction( + 1, // n_eff = 1 + 0, // p_hat = 0 + 10 // f(n) = (n³ + n² + 2n + 6) / n³ evaluated at n=1 + ); + + BOOST_TEST_IS_CLOSE(ci.first, expected.first, eps); + BOOST_TEST_IS_CLOSE(ci.second, expected.second, eps); + } + + { + type_fw_t f(type_f_t(1, 0), 1); + BOOST_TEST_EQ(f.successes(), 1); + BOOST_TEST_EQ(f.failures(), 0); + BOOST_TEST_EQ(f.value(), 1); + BOOST_TEST_EQ(f.variance(), 0); + const auto ci = f.confidence_interval(); + + const auto wilson = utility::wilson_interval(); + const auto expected = wilson.solve_for_neff_phat_correction( + 1, // n_eff = 1 + 1, // p_hat = 1 + 10 // f(n) = (n³ + n² + 2n + 6) / n³ evaluated at n=1 + ); + + BOOST_TEST_IS_CLOSE(ci.first, expected.first, eps); + BOOST_TEST_IS_CLOSE(ci.second, expected.second, eps); + } +} + +int main() { + run_tests(); + run_tests(); + run_tests(); + + { + using type_f_double = accumulators::fraction; + using type_fw_double = accumulators::weighted_fraction; + using type_fw_int = accumulators::weighted_fraction; + + type_fw_double fw_double(type_f_double(5, 3), 88); + type_fw_int fw_int(fw_double); + + BOOST_TEST_EQ(fw_int.successes(), 5); + BOOST_TEST_EQ(fw_int.failures(), 3); + BOOST_TEST_EQ(fw_int.sum_w2(), 88); + } + + return boost::report_errors(); +}