diff --git a/Sources/IntegerUtilities/Midpoint.swift b/Sources/IntegerUtilities/Midpoint.swift new file mode 100644 index 00000000..fa3b9759 --- /dev/null +++ b/Sources/IntegerUtilities/Midpoint.swift @@ -0,0 +1,64 @@ +//===--- Midpoint.swift ---------------------------------------*- swift -*-===// +// +// This source file is part of the Swift Numerics open source project +// +// Copyright (c) 2024 Apple Inc. and the Swift Numerics project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// The average of `a` and `b`, rounded to an integer according to `rule`. +/// +/// Unlike commonly seen expressions such as `(a+b)/2` or `(a+b) >> 1` or +/// `a + (b-a)/2` (all of which may overflow), this function never overflows, +/// and the result is guaranteed to be representable in the result type. +/// +/// The default rounding rule is `.down`, which matches the behavior of +/// `(a + b) >> 1` when that expression does not overflow. Rounding +/// `.towardZero` matches the behavior of `(a + b)/2` when that expression +/// does not overflow. All other rounding modes are supported. +/// +/// Rounding `.down` is generally most efficient; if you do not have a +/// reason to chose a specific other rounding rule, you should use the +/// default. +@inlinable +public func midpoint( + _ a: T, + _ b: T, + rounding rule: RoundingRule = .down +) -> T { + // Isolate bits in a + b with weight 2, and those with weight 1 + let twos = a & b + let ones = a ^ b + let floor = twos &+ ones >> 1 + let frac = ones & 1 + switch rule { + case .toNearestOrDown: + fallthrough + case .down: + return floor + case .toNearestOrUp: + fallthrough + case .up: + return floor &+ frac + case .toNearestOrZero: + fallthrough + case .towardZero: + return floor &+ (floor < 0 ? frac : 0) + case .toNearestOrAway: + fallthrough + case .awayFromZero: + return floor &+ (floor >= 0 ? frac : 0) + case .toNearestOrEven: + return floor &+ (floor & frac) + case .toOdd: + return floor &+ (~floor & frac) + case .stochastically: + return floor &+ (Bool.random() ? frac : 0) + case .requireExact: + precondition(frac == 0) + return floor + } +} diff --git a/Sources/IntegerUtilities/RoundingRule.swift b/Sources/IntegerUtilities/RoundingRule.swift index 0dc9d83f..325c1a02 100644 --- a/Sources/IntegerUtilities/RoundingRule.swift +++ b/Sources/IntegerUtilities/RoundingRule.swift @@ -256,5 +256,5 @@ extension RoundingRule { /// > Deprecated: Use `.toNearestOrAway` instead. @inlinable @available(*, deprecated, renamed: "toNearestOrAway") - static var toNearestOrAwayFromZero: Self { .toNearestOrAway } + public static var toNearestOrAwayFromZero: Self { .toNearestOrAway } } diff --git a/Tests/IntegerUtilitiesTests/MidpointTests.swift b/Tests/IntegerUtilitiesTests/MidpointTests.swift new file mode 100644 index 00000000..2b84cea8 --- /dev/null +++ b/Tests/IntegerUtilitiesTests/MidpointTests.swift @@ -0,0 +1,42 @@ +//===--- MidpointTests.swift ----------------------------------*- swift -*-===// +// +// This source file is part of the Swift Numerics open source project +// +// Copyright (c) 2024 Apple Inc. and the Swift Numerics project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import IntegerUtilities +import XCTest + +final class IntegerUtilitiesMidpointTests: XCTestCase { + func testMidpoint() { + for rule in [ + RoundingRule.down, + .up, + .towardZero, + .awayFromZero, + .toNearestOrDown, + .toNearestOrUp, + .toNearestOrZero, + .toNearestOrAway, + .toNearestOrEven, + .toOdd + ] { + for a in -128 ... 127 { + for b in -128 ... 127 { + let ref = (a + b).shifted(rightBy: 1, rounding: rule) + let tst = midpoint(Int8(a), Int8(b), rounding: rule) + if ref != tst { + print(rule, a, b, ref, tst, separator: "\t") + return + } + } + } + } + } +}