Skip to main content
1 of 5
Famiu
  • 378
  • 1
  • 9

`BigInt` implementation in C++23

Cooked this up in 3 days because I was bored and thought this would be an interesting project to work on. This BigInt implementation does what you'd expect it to do, you can use it like an ordinary integer type except there's no limit to how big the number can be (well, technically, there is a limit which is 2 ^ PTRDIFF_MAX, but that's several exbiytes of data on a 64-bit machine so good luck reaching that).

It uses a sign-magnitude representation for the number, the sign being stored as a boolean value and the magnitude being stored as a std::vector of chunks. Each chunk is an integer of type std::uint_fast32_t. Optimization for small numbers is planned in the future but not implemented yet.

The implementation spans 4 files, bigint.hpp which contains the declarations and some template functions, bigint.cpp which contains most of the implementation minus the logic for parsing strings into BigInt, string_parser.cpp which contains the logic for string parsing, and finally utils.hpp which contains some useful utilities used throughout the project.

The whole project is available on https://github.com/famiu/BigInt

Here's the code posted here for convenience:

include/bigint.hpp:

#pragma once

#include <concepts>
#include <cstdint>
#include <deque>
#include <format>
#include <limits>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "utils.hpp"

namespace BI
{
class BigInt
{
public:
    using ChunkType = std::uint_fast32_t;
    using DataType = std::deque<ChunkType>;

    BigInt();
    BigInt(BigInt const &rhs) = default;
    BigInt(BigInt &&rhs) noexcept = default;

    explicit BigInt(std::integral auto const &num) noexcept : negative{num < 0}
    {
        auto const num_unsigned = detail::to_unsigned(negative ? -num : num);
        auto const num_size = sizeof(num_unsigned) * 8;

        for (size_t i = 0; i < num_size; i += chunk_bits) {
            chunks.push_back(static_cast<ChunkType>((num_unsigned >> i) & chunk_max));
        }

        // Remove leading zeroes.
        remove_leading_zeroes();
    }

    explicit BigInt(std::string_view num);
    ~BigInt() = default;

    auto operator=(BigInt const &rhs) noexcept -> BigInt & = default;
    auto operator=(BigInt &&rhs) noexcept -> BigInt & = default;

    auto operator+() const noexcept -> BigInt;
    auto operator-() const noexcept -> BigInt;

    auto operator+(BigInt const &rhs) const noexcept -> BigInt;
    auto operator-(BigInt const &rhs) const noexcept -> BigInt;
    auto operator*(BigInt const &rhs) const noexcept -> BigInt;
    auto operator/(BigInt const &rhs) const -> BigInt;
    auto operator%(BigInt const &rhs) const -> BigInt;

    auto operator<<(size_t rhs) const noexcept -> BigInt;
    auto operator>>(size_t rhs) const noexcept -> BigInt;

    auto operator+=(BigInt const &rhs) noexcept -> BigInt &;
    auto operator-=(BigInt const &rhs) noexcept -> BigInt &;
    auto operator*=(BigInt const &rhs) noexcept -> BigInt &;
    auto operator/=(BigInt const &rhs) -> BigInt &;
    auto operator%=(BigInt const &rhs) -> BigInt &;

    auto operator<<=(size_t rhs) noexcept -> BigInt &;
    auto operator>>=(size_t rhs) noexcept -> BigInt &;

    auto operator++() noexcept -> BigInt &;
    auto operator--() noexcept -> BigInt &;
    auto operator++(int) noexcept -> BigInt;
    auto operator--(int) noexcept -> BigInt;

    auto operator<=>(BigInt const &rhs) const noexcept -> std::strong_ordering;
    auto operator==(BigInt const &rhs) const noexcept -> bool;

    auto operator<=>(std::integral auto const &rhs) const noexcept -> std::strong_ordering
    {
        try {
            return static_cast<decltype(rhs)>(*this) <=> rhs;
        } catch (std::overflow_error const &) {
            return negative ? std::strong_ordering::less : std::strong_ordering::greater;
        }
    }

    auto operator==(std::integral auto const &rhs) const noexcept -> bool
    {
        return (*this <=> rhs) == std::strong_ordering::equal;
    }

    template<std::integral T>
    explicit operator T() const
    {
        using UnsignedT = std::make_unsigned_t<T>;

        constexpr auto is_signed = std::is_signed_v<T>;
        size_t const num_bits = this->bit_count();

        // Unsigned types cannot store negative numbers.
        if (negative && !is_signed) {
            throw std::underflow_error(std::format("Number can't fit in unsigned type '{}'", detail::type_name<T>()));
        }
        // Signed types can store 1 less bit than their signed counterpart.
        if (num_bits > (sizeof(T) * 8) - static_cast<size_t>(is_signed)) {
            throw std::overflow_error(  // clang-format off
                    std::format("Number is too large to be converted to type '{}'", detail::type_name<T>())
            );  // clang-format on
        }

        UnsignedT result{};

        for (size_t i = 0; i < sizeof(T) * 8; i += chunk_bits) {
            result |= static_cast<T>(this->chunks[i / chunk_bits]) << i;
        }

        return negative ? -static_cast<T>(result) : static_cast<T>(result);
    }

    explicit operator std::string() const;

    /// @brief Get the absolute value of the number.
    [[nodiscard]] auto abs() const noexcept -> BigInt;

    /// @brief Convert the number to the specified type.
    ///
    /// @param[out] output Result of the conversion.
    /// @return Whether the conversion was successful.
    [[nodiscard]] auto convert(std::string &output) const noexcept -> bool
    {
        try {
            output = static_cast<std::string>(*this);
            return true;
        } catch (std::invalid_argument const &) {
            return false;
        }
    }

    /// @brief Convert the number to the specified type.
    ///
    /// @tparam T The type to convert the number to.
    /// @param[out] output Result of the conversion.
    /// @return Whether the conversion was successful.
    [[nodiscard]] auto convert(std::integral auto &output) const noexcept -> bool
    {
        try {
            output = static_cast<decltype(output)>(*this);
            return true;
        } catch (std::overflow_error const &) {
            return false;
        }
    }

    /// @brief Divide two numbers and return the quotient and remainder.
    ///
    /// @param num The dividend.
    /// @param denom The divisor.
    /// @return The quotient and remainder.
    ///
    /// @throw std::domain_error if the divisor is 0.
    [[nodiscard]] static auto div(BigInt const &num, BigInt const &denom) -> std::pair<BigInt, BigInt>;

    /// @brief Raise the number to the specified power.
    ///
    /// @param power The power to raise the number to.
    /// @return The result of the exponentiation.
    ///
    /// @note Only works for non-negative powers.
    /// @note 0^0 returns 1.
    [[nodiscard]] auto pow(size_t power) const noexcept -> BigInt;

    friend std::formatter<BigInt>;
    friend auto operator""_bi(char const *) -> BigInt;

private:
    /// @brief Sign of the number.
    bool negative{false};
    /// @brief Chunks of the number. Stored in little endian.
    DataType chunks;

    /// @brief Supported bases.
    enum class Base : std::uint_fast8_t
    {
        Binary = 2,
        Octal = 8,
        Decimal = 10,
        Hexadecimal = 16
    };

    /// @brief Number of bits in a chunk.
    static constexpr auto chunk_bits = sizeof(ChunkType) * 8;
    /// @brief Maximum value a chunk can store.
    static constexpr ChunkType chunk_max = std::numeric_limits<ChunkType>::max();

    /// @brief Get the number of bits in the number.
    [[nodiscard]] auto bit_count() const -> size_t;

    /// @brief Get the bit at the specified index.
    ///
    /// @param index The index of the bit to get. 0th bit is the least significant bit and the last bit is the most
    ///              significant bit.
    [[nodiscard]] auto get_bit_at(size_t index) const -> bool;

    /// @brief Check if the number is zero.
    [[nodiscard]] auto is_zero() const -> bool;
    /// @brief Remove leading zero chunks from the number.
    void remove_leading_zeroes();

    /// @brief Compare the magnitude of two numbers. Does not evaluate the sign.
    ///
    /// @param rhs The number to compare to.
    /// @return Ordering of the magnitude of the two numbers.
    [[nodiscard]] auto compare_magnitude(BigInt const &rhs) const noexcept -> std::strong_ordering;

    /// @brief Add the magnitude of lhs and rhs.
    ///
    /// @param rhs The number to add, must be smaller than or equal to lhs.
    /// @return The result of the addition.
    [[nodiscard]] auto add_magnitude(BigInt const &rhs) const noexcept -> BigInt;

    /// @brief Subtract the magnitude of rhs from lhs.
    ///
    /// @param rhs The number to subtract, must be smaller than or equal to lhs.
    /// @return The result of the subtraction.
    [[nodiscard]] auto subtract_magnitude(BigInt const &rhs) const noexcept -> BigInt;

    /// @brief Check if character is a valid digit in the given base.
    ///
    /// @param base The base to check the digit in.
    /// @param c The character to check.
    /// @return True if the character is a valid digit in the given base, false otherwise.
    ///
    /// @note Only works for bases 2, 8, 10, and 16.
    [[nodiscard]] static auto is_valid_digit(Base base, char c) -> bool;

    /// @brief Convert character to digit in the given base. The character must be a valid digit in the given base.
    ///
    /// @param base The base to convert the character to.
    /// @param c The character to convert.
    /// @return The digit represented by the character.
    ///
    /// @throws std::invalid_argument if the character is not a valid digit in the given base.
    /// @note Only works for bases 2, 8, 10, and 16.
    [[nodiscard]] static auto char_to_digit(Base base, char c) -> ChunkType;

    /// @brief Long divide string representation of a number by a divisor.
    ///
    /// @param num The number to divide, must be unsigned.
    /// @param[out] quotient Resulting quotient.
    /// @param base The base of the number.
    /// @param divisor The divisor.
    /// @return The remainder.
    ///
    /// @throws std::invalid_argument if num contains invalid digits for the given base.
    /// @note Only works for bases 2, 8, 10, and 16.
    static auto long_divide(std::string_view num, std::string &quotient, Base base, ChunkType divisor) -> ChunkType;

    /// @brief Convert string with power of two base to binary and store it in chunks.
    ///
    /// @param num The number to convert, must be unsigned.
    /// @param base The base of the number.
    ///
    /// @throws std::invalid_argument if num contains invalid digits for the given base.
    /// @note Only works for bases 2, 8, and 16.
    void power_of_two_base_to_binary(std::string_view num, Base base);

    /// @brief Convert decimal base to binary and store it in chunks.
    ///
    /// @param num The number to convert, must be unsigned.
    ///
    /// @throws std::invalid_argument if num contains invalid digits for the given base.
    void decimal_base_to_binary(std::string_view num);

    /// @brief Convert a base to binary and store it in chunks.
    ///
    /// @param num The number to convert, must be unsigned.
    /// @param base The base of the number.
    ///
    /// @throws std::invalid_argument if num contains invalid digits for the given base.
    /// @note Only works for bases 2, 8, 10, and 16.
    void base_to_binary(std::string_view num, Base base);

    /// @brief Format the number to a power of two base.
    ///
    /// @param base The base to format the number to.
    /// @param add_prefix Whether to add a base prefix to the formatted number (e.g. 0b for binary).
    /// @param capitalize Whether to capitalize the base prefix (if any) and the digits (for hexadecimal).
    /// @return The formatted number.
    ///
    /// @note Only works for bases 2, 8, and 16.
    [[nodiscard]] auto
    format_to_power_of_two_base(Base base, bool add_prefix = false, bool capitalize = false) const noexcept
      -> std::string;

    /// @brief Format the number to decimal.
    ///
    /// @return The formatted number.
    [[nodiscard]] auto format_to_decimal() const -> std::string;

    /// @brief Format the number to the specified base.
    ///
    /// @param base The base to format the number to.
    /// @param add_prefix Whether to add a base prefix to the formatted number (e.g. 0b for binary).
    /// @param capitalize Whether to capitalize the base prefix (if any) and the digits (for hexadecimal).
    /// @return The formatted number.
    ///
    /// @note Only works for bases 2, 8, 10, and 16.
    [[nodiscard]] auto format_to_base(Base base, bool add_prefix = false, bool capitalize = false) const -> std::string;
};
}  // namespace BI

auto operator<<(std::ostream &os, BI::BigInt const &num) -> std::ostream &;
auto operator""_bi(char const *) -> BI::BigInt;

template<>
struct std::formatter<BI::BigInt> : std::formatter<std::string>
{
    using BigInt = BI::BigInt;

    bool add_prefix = false;
    bool capitalize = false;
    BigInt::Base base{BigInt::Base::Decimal};

    constexpr auto parse(std::format_parse_context &ctx) -> decltype(ctx.begin())
    {
        auto const *it = ctx.begin();

        if (it == ctx.end() || *it == '}') {
            return it;
        }

        if (*it == '#') {
            add_prefix = true;
            std::advance(it, 1);
        }

        switch (*it) {
        case 'B':
            capitalize = true;
            [[fallthrough]];
        case 'b':
            base = BigInt::Base::Binary;
            break;
        case 'o':
            base = BigInt::Base::Octal;
            break;
        case 'd':
            base = BigInt::Base::Decimal;
            break;
        case 'X':
            capitalize = true;
            [[fallthrough]];
        case 'x':
            base = BigInt::Base::Hexadecimal;
            break;
        default:
            throw std::format_error("Invalid format specifier");
        }

        std::advance(it, 1);

        if (it != ctx.end() && *it != '}') {
            throw std::format_error("Invalid format specifier");
        }

        return it;
    }

    auto format(BigInt const &num, std::format_context &ctx) const
    {
        return std::format_to(ctx.out(), "{}", num.format_to_base(base, add_prefix, capitalize));
    }
};

static_assert(std::is_unsigned_v<BI::BigInt::ChunkType>, "ChunkType must be an unsigned integral type");
static_assert(std::is_same_v<BI::BigInt::DataType::value_type, BI::BigInt::ChunkType>, "DataType must store ChunkType");

src/bigint.cpp:

#include "bigint.hpp"

#include <cassert>
#include <cmath>
#include <utility>

using namespace BI;
using namespace BI::detail;

using ChunkType = BigInt::ChunkType;
using DataType = BigInt::DataType;

// Avoid having to convert the number to a BigInt over and over again.
static auto const one = BigInt(1);

BigInt::BigInt()
{
    chunks.push_back(0);
}

BigInt::BigInt(std::string_view num)
{
    auto throw_invalid_number = [&num]() {
        throw std::invalid_argument(std::format("Invalid number: \"{}\"", num));
    };

    if (num.empty()) {
        throw_invalid_number();
    }

    negative = num[0] == '-';
    size_t index = negative ? 1z : 0z;
    Base base{Base::Decimal};

    if (num.size() > index + 1 && num[index] == '0') {
        if (std::tolower(num[index + 1]) == 'x') {
            base = Base::Hexadecimal;
            index += 2;
        } else if (std::tolower(num[index + 1]) == 'b') {
            base = Base::Binary;
            index += 2;
        } else {
            base = Base::Octal;
            index += 1;
        }
    }

    if (index >= num.size()) {
        throw_invalid_number();
    }

    // Convert the number to binary and store it in chunks.
    try {
        base_to_binary(num.substr(index), base);
    } catch (std::invalid_argument const &e) {
        throw_invalid_number();
    }

    // Remove leading zeroes.
    remove_leading_zeroes();
}

auto BigInt::operator+() const noexcept -> BigInt
{
    return *this;
}

auto BigInt::operator-() const noexcept -> BigInt
{
    BigInt result{*this};
    result.negative = !result.negative;
    return result;
}

auto BigInt::operator+(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero()) {
        return rhs;
    }

    if (rhs.is_zero()) {
        return *this;
    }

    bool magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt result;

    if (negative == rhs.negative) {
        result = magnitude_greater ? add_magnitude(rhs) : rhs.add_magnitude(*this);
    } else {
        result = magnitude_greater ? subtract_magnitude(rhs) : rhs.subtract_magnitude(*this);
    }

    result.negative = magnitude_greater ? negative : rhs.negative;
    return result;
}

auto BigInt::operator-(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero()) {
        return -rhs;
    }

    if (rhs.is_zero()) {
        return *this;
    }

    bool magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt result;

    if (negative == rhs.negative) {
        result = magnitude_greater ? subtract_magnitude(rhs) : rhs.subtract_magnitude(*this);
    } else {
        result = magnitude_greater ? add_magnitude(rhs) : rhs.add_magnitude(*this);
    }

    result.negative = magnitude_greater ? negative : !rhs.negative;
    return result;
}

auto BigInt::operator*(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs.is_zero()) {
        return BigInt{};
    }
    if (*this == one) {
        return rhs;
    }
    if (rhs == one) {
        return *this;
    }

    bool const magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt const &larger = magnitude_greater ? *this : rhs;
    BigInt const &smaller = magnitude_greater ? rhs : *this;

    BigInt result{};

    // Iterate through each bit of the smaller number in reverse order.
    // Shift the result by one bit and add the larger number to the result if the bit is set.
    for (size_t i = smaller.bit_count(); i-- > 0;) {
        result <<= 1;

        if (smaller.get_bit_at(i)) {
            result += larger;
        }
    }

    result.negative = negative != rhs.negative;

    return result;
}

auto BigInt::operator/(BigInt const &rhs) const -> BigInt
{
    return div(*this, rhs).first;
}

auto BigInt::operator%(BigInt const &rhs) const -> BigInt
{
    return div(*this, rhs).second;
}

auto BigInt::operator<<(size_t rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs == 0) {
        return *this;
    }

    BigInt result{*this};
    // Number of whole chunks to shift.
    size_t chunk_shift = rhs / chunk_bits;
    // Number of bits to shift within a chunk.
    size_t bit_shift = rhs % chunk_bits;

    // Add whole chunks of zeroes to the beginning of the number.
    result.chunks.insert(result.chunks.begin(), chunk_shift, 0);

    // Shift the bits within the remaining chunks.
    if (bit_shift != 0) {
        ChunkType carry = 0;

        for (size_t i = chunk_shift; i < result.chunks.size(); ++i) {
            // Get the bits that will be shifted out of the current chunk and store them in carry.
            // Append the carry from the previous chunk to the current chunk.
            ChunkType new_carry = result.chunks[i] >> (chunk_bits - bit_shift);
            result.chunks[i] = (result.chunks[i] << bit_shift) | carry;
            carry = new_carry;
        }

        // If there is a carry left, add it to the end of the number.
        if (carry != 0) {
            result.chunks.push_back(carry);
        }
    }

    return result;
}

auto BigInt::operator>>(size_t rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs == 0) {
        return *this;
    }

    BigInt result{*this};
    // Number of whole chunks to shift.
    size_t chunk_shift = rhs / chunk_bits;
    // Number of bits to shift within a chunk.
    size_t bit_shift = rhs % chunk_bits;

    // Shift is larger than the number of bits in the number, return 0.
    if (chunk_shift >= result.chunks.size()) {
        return BigInt{};
    }

    // Erase the whole chunks that will be shifted.
    result.chunks.erase(result.chunks.begin(), std::next(result.chunks.begin(), to_signed(chunk_shift)));

    // Shift the bits within the remaining chunks.
    if (bit_shift != 0) {
        ChunkType carry = 0;

        for (size_t i = result.chunks.size(); i-- > 0;) {
            // Get the bits that will be shifted out of the current chunk and store them in carry.
            // Append the carry from the previous chunk to the current chunk.
            ChunkType new_carry = result.chunks[i] << (chunk_bits - bit_shift);
            result.chunks[i] = (result.chunks[i] >> bit_shift) | carry;
            carry = new_carry;
        }

        // Clear out any leading zero chunks that may have been created.
        result.remove_leading_zeroes();
    }

    return result;
}

auto BigInt::operator+=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this + rhs;
    return *this;
}

auto BigInt::operator-=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this - rhs;
    return *this;
}

auto BigInt::operator*=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this * rhs;
    return *this;
}

auto BigInt::operator/=(BigInt const &rhs) -> BigInt &
{
    *this = *this / rhs;
    return *this;
}

auto BigInt::operator%=(BigInt const &rhs) -> BigInt &
{
    *this = *this % rhs;
    return *this;
}

auto BigInt::operator<<=(size_t rhs) noexcept -> BigInt &
{
    *this = *this << rhs;
    return *this;
}

auto BigInt::operator>>=(size_t rhs) noexcept -> BigInt &
{
    *this = *this >> rhs;
    return *this;
}

auto BigInt::operator++() noexcept -> BigInt &
{
    *this += one;
    return *this;
}

auto BigInt::operator--() noexcept -> BigInt &
{
    *this -= one;
    return *this;
}

auto BigInt::operator++(int) noexcept -> BigInt
{
    BigInt result{*this};
    *this += one;
    return result;
}

auto BigInt::operator--(int) noexcept -> BigInt
{
    BigInt result{*this};
    *this -= one;
    return result;
}

auto BigInt::operator<=>(BigInt const &rhs) const noexcept -> std::strong_ordering
{
    if (this->is_zero() && rhs.is_zero()) {
        return std::strong_ordering::equal;
    }

    if (negative != rhs.negative) {
        return negative ? std::strong_ordering::less : std::strong_ordering::greater;
    }

    return negative ? rhs.compare_magnitude(*this) : compare_magnitude(rhs);
}

auto BigInt::operator==(BigInt const &rhs) const noexcept -> bool
{
    return (*this <=> rhs) == std::strong_ordering::equal;
}

BigInt::operator std::string() const
{
    return format_to_base(Base::Decimal);
}

auto BigInt::abs() const noexcept -> BigInt
{
    BigInt result{*this};
    result.negative = false;
    return result;
}

auto operator<<(std::ostream &os, BigInt const &num) -> std::ostream &
{
    return os << static_cast<std::string>(num);
}

auto operator""_bi(char const *num) -> BigInt
{
    return BigInt{num};
}

auto BigInt::div(BigInt const &num, BigInt const &denom) -> std::pair<BigInt, BigInt>
{
    if (denom.is_zero()) {
        throw std::domain_error("Division by zero");
    }

    if (num.is_zero()) {
        return {BigInt{}, BigInt{}};
    }

    if (num.compare_magnitude(denom) == std::strong_ordering::less) {
        return {BigInt{}, num};
    }

    BigInt quotient{};
    BigInt remainder{num.abs()};

    // Perform long division:
    // 1. Find the largest multiple of the denominator that fits in the current remainder.
    // 2. Subtract the multiple from the remainder.
    // 3. Repeat until the remainder is less than the denominator.
    while (remainder.compare_magnitude(denom) != std::strong_ordering::less) {
        BigInt temp{denom.abs()};
        // Approximate the amount of shifts needed to align the most significant bit of the denominator with the
        // most significant bit of the remainder.
        size_t shift = remainder.bit_count() - temp.bit_count();

        // Align the most significant bit of the denominator with the most significant bit of the remainder.
        temp <<= shift;

        // If the denominator is still greater than the remainder, shift it to the right once.
        // This will guarantee that the denominator is less than the remainder.
        if (temp.compare_magnitude(remainder) == std::strong_ordering::greater) {
            temp >>= 1;
            --shift;
        }

        // Subtract the multiple from the remainder, and add the multiplier to the quotient.
        remainder -= temp;
        quotient += one << shift;
    }

    // For remainder, the sign is always the same as the dividend.
    remainder.negative = num.negative;
    quotient.negative = num.negative != denom.negative;

    return {quotient, remainder};
}

auto BigInt::pow(size_t power) const noexcept -> BigInt
{
    // x^0 = 1
    // NOTE: 0^0 also returns 1.
    if (power == 0) {
        return one;
    }

    // x^1 = x
    // 0^x = 0
    // 1^x = 1
    if (power == 1 || is_zero() || *this == one) {
        return *this;
    }

    static constexpr auto mask = static_cast<size_t>(1) << (sizeof(size_t) * 8 - 1);
    auto const power_leading_zeroes = static_cast<size_t>(std::countl_zero(power));
    // Get amount of bits in the power excluding leading zeroes.
    auto const power_bit_count = (sizeof(size_t) * 8) - power_leading_zeroes;

    // Get rid of leading zeroes so that we can iterate through the bits of the power.
    power <<= power_leading_zeroes;

    BigInt result(1);

    // Iterate through the bits of the power, starting from the most significant bit.
    // Square the result in each iteration, and multiply it by the base if the bit is set.
    // After each iteration, left shift the power by 1 to get the next bit.
    for (size_t i = 0; i < power_bit_count; ++i) {
        result *= result;
        if ((power & mask) != 0) {
            result *= *this;
        }

        power <<= 1;
    }

    return result;
}

auto BigInt::bit_count() const -> size_t
{
    return (chunks.size() * chunk_bits) - static_cast<size_t>(std::countl_zero(chunks.back()));
}

auto BigInt::get_bit_at(size_t index) const -> bool
{
    size_t const chunk_index = index / chunk_bits;
    size_t const bit_index = index % chunk_bits;

    // Even though the chunks are stored in little endian, the bits in a chunk are stored in big endian.
    return static_cast<bool>((chunks[chunk_index] >> bit_index) & 1);
}

auto BigInt::is_zero() const -> bool
{
    return chunks.size() == 1 && chunks[0] == 0;
}

void BigInt::remove_leading_zeroes()
{
    while (chunks.size() > 1 && chunks.back() == 0) {
        chunks.pop_back();
    }
}

auto BigInt::compare_magnitude(BigInt const &rhs) const noexcept -> std::strong_ordering
{
    if (chunks.size() != rhs.chunks.size()) {
        return chunks.size() <=> rhs.chunks.size();
    }

    for (size_t i = chunks.size(); i-- > 0;) {
        if (chunks[i] != rhs.chunks[i]) {
            return chunks[i] <=> rhs.chunks[i];
        }
    }

    return std::strong_ordering::equal;
}

auto BigInt::add_magnitude(BigInt const &rhs) const noexcept -> BigInt
{
    assert(compare_magnitude(rhs) != std::strong_ordering::less);

    BigInt result{*this};
    ChunkType carry = 0;

    for (size_t i = 0; i < rhs.chunks.size(); ++i) {
        // Check if the current chunk can be added without overflow.
        if (result.chunks[i] < chunk_max - rhs.chunks[i] - carry) {
            result.chunks[i] += rhs.chunks[i] + carry;
            carry = 0;
        } else {
            // Add with overflow.
            result.chunks[i] += rhs.chunks[i] + carry;
            carry = 1;
        }
    }

    // Add carry to the rest of the *this number.
    if (carry != 0) {
        for (size_t i = rhs.chunks.size(); i < result.chunks.size(); ++i) {
            if (result.chunks[i] == chunk_max) {
                result.chunks[i] = 0;
            } else {
                result.chunks[i] += carry;
                carry = 0;
                break;
            }
        }
    }

    // Add carry to the end of the number.
    if (carry != 0) {
        result.chunks.push_back(1);
    }

    return result;
}

auto BigInt::subtract_magnitude(BigInt const &rhs) const noexcept -> BigInt
{
    assert(compare_magnitude(rhs) != std::strong_ordering::less);

    BigInt result{*this};
    ChunkType borrow = 0;

    for (size_t i = 0; i < rhs.chunks.size(); ++i) {
        // Check if the current chunk can be subtracted without underflow.
        if (result.chunks[i] >= rhs.chunks[i] + borrow) {
            result.chunks[i] -= rhs.chunks[i] + borrow;
            borrow = 0;
        } else {
            // Subtract with underflow.
            result.chunks[i] -= rhs.chunks[i] + borrow;
            borrow = 1;
        }
    }

    // Subtract borrow from the rest of the *this number.
    if (borrow != 0) {
        for (size_t i = rhs.chunks.size(); i < result.chunks.size(); ++i) {
            if (result.chunks[i] == 0) {
                result.chunks[i] = chunk_max;
            } else {
                result.chunks[i] -= borrow;
                borrow = 0;
                break;
            }
        }
    }

    // Borrow cannot be 1 at the end of the number.
    assert(borrow == 0);

    // Remove leading zeroes.
    result.remove_leading_zeroes();

    return result;
}

src/string_parser.cpp:

#include "bigint.hpp"

#include <cassert>
#include <cmath>
#include <utility>

using namespace BI;
using namespace BI::detail;

using ChunkType = BigInt::ChunkType;
using DataType = BigInt::DataType;

// Avoid having to convert the number to a BigInt over and over again.
static auto const one = BigInt(1);

BigInt::BigInt()
{
    chunks.push_back(0);
}

BigInt::BigInt(std::string_view num)
{
    auto throw_invalid_number = [&num]() {
        throw std::invalid_argument(std::format("Invalid number: \"{}\"", num));
    };

    if (num.empty()) {
        throw_invalid_number();
    }

    negative = num[0] == '-';
    size_t index = negative ? 1z : 0z;
    Base base{Base::Decimal};

    if (num.size() > index + 1 && num[index] == '0') {
        if (std::tolower(num[index + 1]) == 'x') {
            base = Base::Hexadecimal;
            index += 2;
        } else if (std::tolower(num[index + 1]) == 'b') {
            base = Base::Binary;
            index += 2;
        } else {
            base = Base::Octal;
            index += 1;
        }
    }

    if (index >= num.size()) {
        throw_invalid_number();
    }

    // Convert the number to binary and store it in chunks.
    try {
        base_to_binary(num.substr(index), base);
    } catch (std::invalid_argument const &e) {
        throw_invalid_number();
    }

    // Remove leading zeroes.
    remove_leading_zeroes();
}

auto BigInt::operator+() const noexcept -> BigInt
{
    return *this;
}

auto BigInt::operator-() const noexcept -> BigInt
{
    BigInt result{*this};
    result.negative = !result.negative;
    return result;
}

auto BigInt::operator+(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero()) {
        return rhs;
    }

    if (rhs.is_zero()) {
        return *this;
    }

    bool magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt result;

    if (negative == rhs.negative) {
        result = magnitude_greater ? add_magnitude(rhs) : rhs.add_magnitude(*this);
    } else {
        result = magnitude_greater ? subtract_magnitude(rhs) : rhs.subtract_magnitude(*this);
    }

    result.negative = magnitude_greater ? negative : rhs.negative;
    return result;
}

auto BigInt::operator-(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero()) {
        return -rhs;
    }

    if (rhs.is_zero()) {
        return *this;
    }

    bool magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt result;

    if (negative == rhs.negative) {
        result = magnitude_greater ? subtract_magnitude(rhs) : rhs.subtract_magnitude(*this);
    } else {
        result = magnitude_greater ? add_magnitude(rhs) : rhs.add_magnitude(*this);
    }

    result.negative = magnitude_greater ? negative : !rhs.negative;
    return result;
}

auto BigInt::operator*(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs.is_zero()) {
        return BigInt{};
    }
    if (*this == one) {
        return rhs;
    }
    if (rhs == one) {
        return *this;
    }

    bool const magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt const &larger = magnitude_greater ? *this : rhs;
    BigInt const &smaller = magnitude_greater ? rhs : *this;

    BigInt result{};

    // Iterate through each bit of the smaller number in reverse order.
    // Shift the result by one bit and add the larger number to the result if the bit is set.
    for (size_t i = smaller.bit_count(); i-- > 0;) {
        result <<= 1;

        if (smaller.get_bit_at(i)) {
            result += larger;
        }
    }

    result.negative = negative != rhs.negative;

    return result;
}

auto BigInt::operator/(BigInt const &rhs) const -> BigInt
{
    return div(*this, rhs).first;
}

auto BigInt::operator%(BigInt const &rhs) const -> BigInt
{
    return div(*this, rhs).second;
}

auto BigInt::operator<<(size_t rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs == 0) {
        return *this;
    }

    BigInt result{*this};
    // Number of whole chunks to shift.
    size_t chunk_shift = rhs / chunk_bits;
    // Number of bits to shift within a chunk.
    size_t bit_shift = rhs % chunk_bits;

    // Add whole chunks of zeroes to the beginning of the number.
    result.chunks.insert(result.chunks.begin(), chunk_shift, 0);

    // Shift the bits within the remaining chunks.
    if (bit_shift != 0) {
        ChunkType carry = 0;

        for (size_t i = chunk_shift; i < result.chunks.size(); ++i) {
            // Get the bits that will be shifted out of the current chunk and store them in carry.
            // Append the carry from the previous chunk to the current chunk.
            ChunkType new_carry = result.chunks[i] >> (chunk_bits - bit_shift);
            result.chunks[i] = (result.chunks[i] << bit_shift) | carry;
            carry = new_carry;
        }

        // If there is a carry left, add it to the end of the number.
        if (carry != 0) {
            result.chunks.push_back(carry);
        }
    }

    return result;
}

auto BigInt::operator>>(size_t rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs == 0) {
        return *this;
    }

    BigInt result{*this};
    // Number of whole chunks to shift.
    size_t chunk_shift = rhs / chunk_bits;
    // Number of bits to shift within a chunk.
    size_t bit_shift = rhs % chunk_bits;

    // Shift is larger than the number of bits in the number, return 0.
    if (chunk_shift >= result.chunks.size()) {
        return BigInt{};
    }

    // Erase the whole chunks that will be shifted.
    result.chunks.erase(result.chunks.begin(), std::next(result.chunks.begin(), to_signed(chunk_shift)));

    // Shift the bits within the remaining chunks.
    if (bit_shift != 0) {
        ChunkType carry = 0;

        for (size_t i = result.chunks.size(); i-- > 0;) {
            // Get the bits that will be shifted out of the current chunk and store them in carry.
            // Append the carry from the previous chunk to the current chunk.
            ChunkType new_carry = result.chunks[i] << (chunk_bits - bit_shift);
            result.chunks[i] = (result.chunks[i] >> bit_shift) | carry;
            carry = new_carry;
        }

        // Clear out any leading zero chunks that may have been created.
        result.remove_leading_zeroes();
    }

    return result;
}

auto BigInt::operator+=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this + rhs;
    return *this;
}

auto BigInt::operator-=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this - rhs;
    return *this;
}

auto BigInt::operator*=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this * rhs;
    return *this;
}

auto BigInt::operator/=(BigInt const &rhs) -> BigInt &
{
    *this = *this / rhs;
    return *this;
}

auto BigInt::operator%=(BigInt const &rhs) -> BigInt &
{
    *this = *this % rhs;
    return *this;
}

auto BigInt::operator<<=(size_t rhs) noexcept -> BigInt &
{
    *this = *this << rhs;
    return *this;
}

auto BigInt::operator>>=(size_t rhs) noexcept -> BigInt &
{
    *this = *this >> rhs;
    return *this;
}

auto BigInt::operator++() noexcept -> BigInt &
{
    *this += one;
    return *this;
}

auto BigInt::operator--() noexcept -> BigInt &
{
    *this -= one;
    return *this;
}

auto BigInt::operator++(int) noexcept -> BigInt
{
    BigInt result{*this};
    *this += one;
    return result;
}

auto BigInt::operator--(int) noexcept -> BigInt
{
    BigInt result{*this};
    *this -= one;
    return result;
}

auto BigInt::operator<=>(BigInt const &rhs) const noexcept -> std::strong_ordering
{
    if (this->is_zero() && rhs.is_zero()) {
        return std::strong_ordering::equal;
    }

    if (negative != rhs.negative) {
        return negative ? std::strong_ordering::less : std::strong_ordering::greater;
    }

    return negative ? rhs.compare_magnitude(*this) : compare_magnitude(rhs);
}

auto BigInt::operator==(BigInt const &rhs) const noexcept -> bool
{
    return (*this <=> rhs) == std::strong_ordering::equal;
}

BigInt::operator std::string() const
{
    return format_to_base(Base::Decimal);
}

auto BigInt::abs() const noexcept -> BigInt
{
    BigInt result{*this};
    result.negative = false;
    return result;
}

auto operator<<(std::ostream &os, BigInt const &num) -> std::ostream &
{
    return os << static_cast<std::string>(num);
}

auto operator""_bi(char const *num) -> BigInt
{
    return BigInt{num};
}

auto BigInt::div(BigInt const &num, BigInt const &denom) -> std::pair<BigInt, BigInt>
{
    if (denom.is_zero()) {
        throw std::domain_error("Division by zero");
    }

    if (num.is_zero()) {
        return {BigInt{}, BigInt{}};
    }

    if (num.compare_magnitude(denom) == std::strong_ordering::less) {
        return {BigInt{}, num};
    }

    BigInt quotient{};
    BigInt remainder{num.abs()};

    // Perform long division:
    // 1. Find the largest multiple of the denominator that fits in the current remainder.
    // 2. Subtract the multiple from the remainder.
    // 3. Repeat until the remainder is less than the denominator.
    while (remainder.compare_magnitude(denom) != std::strong_ordering::less) {
        BigInt temp{denom.abs()};
        // Approximate the amount of shifts needed to align the most significant bit of the denominator with the
        // most significant bit of the remainder.
        size_t shift = remainder.bit_count() - temp.bit_count();

        // Align the most significant bit of the denominator with the most significant bit of the remainder.
        temp <<= shift;

        // If the denominator is still greater than the remainder, shift it to the right once.
        // This will guarantee that the denominator is less than the remainder.
        if (temp.compare_magnitude(remainder) == std::strong_ordering::greater) {
            temp >>= 1;
            --shift;
        }

        // Subtract the multiple from the remainder, and add the multiplier to the quotient.
        remainder -= temp;
        quotient += one << shift;
    }

    // For remainder, the sign is always the same as the dividend.
    remainder.negative = num.negative;
    quotient.negative = num.negative != denom.negative;

    return {quotient, remainder};
}

auto BigInt::pow(size_t power) const noexcept -> BigInt
{
    // x^0 = 1
    // NOTE: 0^0 also returns 1.
    if (power == 0) {
        return one;
    }

    // x^1 = x
    // 0^x = 0
    // 1^x = 1
    if (power == 1 || is_zero() || *this == one) {
        return *this;
    }

    static constexpr auto mask = static_cast<size_t>(1) << (sizeof(size_t) * 8 - 1);
    auto const power_leading_zeroes = static_cast<size_t>(std::countl_zero(power));
    // Get amount of bits in the power excluding leading zeroes.
    auto const power_bit_count = (sizeof(size_t) * 8) - power_leading_zeroes;

    // Get rid of leading zeroes so that we can iterate through the bits of the power.
    power <<= power_leading_zeroes;

    BigInt result(1);

    // Iterate through the bits of the power, starting from the most significant bit.
    // Square the result in each iteration, and multiply it by the base if the bit is set.
    // After each iteration, left shift the power by 1 to get the next bit.
    for (size_t i = 0; i < power_bit_count; ++i) {
        result *= result;
        if ((power & mask) != 0) {
            result *= *this;
        }

        power <<= 1;
    }

    return result;
}

auto BigInt::bit_count() const -> size_t
{
    return (chunks.size() * chunk_bits) - static_cast<size_t>(std::countl_zero(chunks.back()));
}

auto BigInt::get_bit_at(size_t index) const -> bool
{
    size_t const chunk_index = index / chunk_bits;
    size_t const bit_index = index % chunk_bits;

    // Even though the chunks are stored in little endian, the bits in a chunk are stored in big endian.
    return static_cast<bool>((chunks[chunk_index] >> bit_index) & 1);
}

auto BigInt::is_zero() const -> bool
{
    return chunks.size() == 1 && chunks[0] == 0;
}

void BigInt::remove_leading_zeroes()
{
    while (chunks.size() > 1 && chunks.back() == 0) {
        chunks.pop_back();
    }
}

auto BigInt::compare_magnitude(BigInt const &rhs) const noexcept -> std::strong_ordering
{
    if (chunks.size() != rhs.chunks.size()) {
        return chunks.size() <=> rhs.chunks.size();
    }

    for (size_t i = chunks.size(); i-- > 0;) {
        if (chunks[i] != rhs.chunks[i]) {
            return chunks[i] <=> rhs.chunks[i];
        }
    }

    return std::strong_ordering::equal;
}

auto BigInt::add_magnitude(BigInt const &rhs) const noexcept -> BigInt
{
    assert(compare_magnitude(rhs) != std::strong_ordering::less);

    BigInt result{*this};
    ChunkType carry = 0;

    for (size_t i = 0; i < rhs.chunks.size(); ++i) {
        // Check if the current chunk can be added without overflow.
        if (result.chunks[i] < chunk_max - rhs.chunks[i] - carry) {
            result.chunks[i] += rhs.chunks[i] + carry;
            carry = 0;
        } else {
            // Add with overflow.
            result.chunks[i] += rhs.chunks[i] + carry;
            carry = 1;
        }
    }

    // Add carry to the rest of the *this number.
    if (carry != 0) {
        for (size_t i = rhs.chunks.size(); i < result.chunks.size(); ++i) {
            if (result.chunks[i] == chunk_max) {
                result.chunks[i] = 0;
            } else {
                result.chunks[i] += carry;
                carry = 0;
                break;
            }
        }
    }

    // Add carry to the end of the number.
    if (carry != 0) {
        result.chunks.push_back(1);
    }

    return result;
}

auto BigInt::subtract_magnitude(BigInt const &rhs) const noexcept -> BigInt
{
    assert(compare_magnitude(rhs) != std::strong_ordering::less);

    BigInt result{*this};
    ChunkType borrow = 0;

    for (size_t i = 0; i < rhs.chunks.size(); ++i) {
        // Check if the current chunk can be subtracted without underflow.
        if (result.chunks[i] >= rhs.chunks[i] + borrow) {
            result.chunks[i] -= rhs.chunks[i] + borrow;
            borrow = 0;
        } else {
            // Subtract with underflow.
            result.chunks[i] -= rhs.chunks[i] + borrow;
            borrow = 1;
        }
    }

    // Subtract borrow from the rest of the *this number.
    if (borrow != 0) {
        for (size_t i = rhs.chunks.size(); i < result.chunks.size(); ++i) {
            if (result.chunks[i] == 0) {
                result.chunks[i] = chunk_max;
            } else {
                result.chunks[i] -= borrow;
                borrow = 0;
                break;
            }
        }
    }

    // Borrow cannot be 1 at the end of the number.
    assert(borrow == 0);

    // Remove leading zeroes.
    result.remove_leading_zeroes();

    return result;
}

src/utils.hpp:

#pragma once

#include <limits>
#include <source_location>
#include <stdexcept>
#include <utility>

namespace BI::detail
{
template<typename T>
static constexpr auto to_unsigned(T const &num) -> std::make_unsigned_t<T>
{
    if constexpr (std::is_unsigned_v<T>) {
        return num;
    } else {
        if (num < 0) {
            throw std::underflow_error("Number is negative and cannot be converted to an unsigned type");
        }
        return static_cast<std::make_unsigned_t<T>>(num);
    }
}

template<typename T>
static constexpr auto to_signed(T const &num) -> std::make_signed_t<T>
{
    if constexpr (std::is_signed_v<T>) {
        return num;
    } else {
        if (num > std::numeric_limits<std::make_signed_t<T>>::max()) {
            throw std::overflow_error("Number is too large to be converted to a signed type");
        }
        return static_cast<std::make_signed_t<T>>(num);
    }
}

/// @brief Get the raw function name of a function with the type as a template parameter.
template<typename T>
constexpr auto type_to_string_raw() -> std::string_view
{
    return std::source_location::current().function_name();
}

/// @brief Raw function name of long double type, used to find the length of the prefix and suffix of the raw string.
constexpr std::string_view long_double_raw_string = type_to_string_raw<long double>();
constexpr std::string_view long_double_string = "long double";

template<std::size_t... Idxs>
constexpr auto substring_as_array(std::string_view str, std::index_sequence<Idxs...> /*unused*/)
{
    return std::array{str[Idxs]...};
}

template<typename T>
constexpr auto type_name_array()
{
    // Find prefix and suffix of every raw string using known type, by searching for the known type in the raw string.
    static constexpr auto prefix_len = long_double_raw_string.find(long_double_string);
    static constexpr auto suffix_len = long_double_raw_string.size() - prefix_len - long_double_string.size();

    constexpr std::string_view type_raw_string = type_to_string_raw<T>();
    constexpr auto start = prefix_len;
    constexpr auto end = type_raw_string.size() - suffix_len;

    static_assert(start < end);

    constexpr auto name = type_raw_string.substr(start, (end - start));

    // Convert the substring to a std::array so it can be used in a constexpr context.
    return substring_as_array(name, std::make_index_sequence<name.size()>{});
};

template<typename T>
struct type_name_holder
{
    static constexpr auto value = type_name_array<T>();
};

/// @brief Get the name of a type as a string.
///
/// @tparam T The type to get the name of.
/// @return The name of the type.
///
/// @see https://rodusek.com/posts/2021/03/09/getting-an-unmangled-type-name-at-compile-time/
template<typename T>
constexpr auto type_name() -> std::string_view
{
    constexpr auto &value = type_name_holder<T>::value;
    return std::string_view(value.data(), value.size());
}
}  // namespace BI::detail
```
Famiu
  • 378
  • 1
  • 9