#include "bigint.hpp"
#include <cassert>
#include <cmath>
#include <utility>
using namespace BI;
using namespace BI::detail;
using ChunkType = BigInt::ChunkType;
using DataType = BigInt::DataType;
// Avoid having to convert the number to a BigInt over and over again.
static auto const one = BigInt(1);
BigInt::BigInt()
{
    chunks.push_back(0);
}
BigInt::BigInt(std::string_view num)
{
    auto throw_invalid_number = [&num]() {
        throw std::invalid_argument(std::format("Invalid number: \"{}\"", num));
    };
    if (num.empty()) {
        throw_invalid_number();
    }
    negative = num[0] == '-';
    size_t index = negative ? 1z : 0z;
    Base base{Base::Decimal};
    if (num.size() > index + 1 && num[index] == '0') {
        if (std::tolower(num[index + 1]) == 'x') {
            base = Base::Hexadecimal;
            index += 2;
        } else if (std::tolower(num[index + 1]) == 'b') {
            base = Base::Binary;
            index += 2;
        } else {
            base = Base::Octal;
            index += 1;
        }
    }
    if (index >= num.size()) {
        throw_invalid_number();
    }
    // Convert the number to binary and store it in chunks.
    try {
        base_to_binary(num.substr(index), base);
    } catch (std::invalid_argument const &e) {
        throw_invalid_number();
    }
    // Remove leading zeroes.
    remove_leading_zeroes();
}
auto BigInt::operator+() const noexcept -> BigInt
{
    return *this;
}
auto BigInt::operator-() const noexcept -> BigInt
{
    BigInt result{*this};
    result.negative = !result.negative;
    return result;
}
auto BigInt::operator+(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero()) {
        return rhs;
    }
    if (rhs.is_zero()) {
        return *this;
    }
    bool magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt result;
    if (negative == rhs.negative) {
        result = magnitude_greater ? add_magnitude(rhs) : rhs.add_magnitude(*this);
    } else {
        result = magnitude_greater ? subtract_magnitude(rhs) : rhs.subtract_magnitude(*this);
    }
    result.negative = magnitude_greater ? negative : rhs.negative;
    return result;
}
auto BigInt::operator-(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero()) {
        return -rhs;
    }
    if (rhs.is_zero()) {
        return *this;
    }
    bool magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt result;
    if (negative == rhs.negative) {
        result = magnitude_greater ? subtract_magnitude(rhs) : rhs.subtract_magnitude(*this);
    } else {
        result = magnitude_greater ? add_magnitude(rhs) : rhs.add_magnitude(*this);
    }
    result.negative = magnitude_greater ? negative : !rhs.negative;
    return result;
}
auto BigInt::operator*(BigInt const &rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs.is_zero()) {
        return BigInt{};
    }
    if (*this == one) {
        return rhs;
    }
    if (rhs == one) {
        return *this;
    }
    bool const magnitude_greater = compare_magnitude(rhs) == std::strong_ordering::greater;
    BigInt const &larger = magnitude_greater ? *this : rhs;
    BigInt const &smaller = magnitude_greater ? rhs : *this;
    BigInt result{};
    // Iterate through each bit of the smaller number in reverse order.
    // Shift the result by one bit and add the larger number to the result if the bit is set.
    for (size_t i = smaller.bit_count(); i-- > 0;) {
        result <<= 1;
        if (smaller.get_bit_at(i)) {
            result += larger;
        }
    }
    result.negative = negative != rhs.negative;
    return result;
}
auto BigInt::operator/(BigInt const &rhs) const -> BigInt
{
    return div(*this, rhs).first;
}
auto BigInt::operator%(BigInt const &rhs) const -> BigInt
{
    return div(*this, rhs).second;
}
auto BigInt::operator<<(size_t rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs == 0) {
        return *this;
    }
    BigInt result{*this};
    // Number of whole chunks to shift.
    size_t chunk_shift = rhs / chunk_bits;
    // Number of bits to shift within a chunk.
    size_t bit_shift = rhs % chunk_bits;
    // Add whole chunks of zeroes to the beginning of the number.
    result.chunks.insert(result.chunks.begin(), chunk_shift, 0);
    // Shift the bits within the remaining chunks.
    if (bit_shift != 0) {
        ChunkType carry = 0;
        for (size_t i = chunk_shift; i < result.chunks.size(); ++i) {
            // Get the bits that will be shifted out of the current chunk and store them in carry.
            // Append the carry from the previous chunk to the current chunk.
            ChunkType new_carry = result.chunks[i] >> (chunk_bits - bit_shift);
            result.chunks[i] = (result.chunks[i] << bit_shift) | carry;
            carry = new_carry;
        }
        // If there is a carry left, add it to the end of the number.
        if (carry != 0) {
            result.chunks.push_back(carry);
        }
    }
    return result;
}
auto BigInt::operator>>(size_t rhs) const noexcept -> BigInt
{
    if (is_zero() || rhs == 0) {
        return *this;
    }
    BigInt result{*this};
    // Number of whole chunks to shift.
    size_t chunk_shift = rhs / chunk_bits;
    // Number of bits to shift within a chunk.
    size_t bit_shift = rhs % chunk_bits;
    // Shift is larger than the number of bits in the number, return 0.
    if (chunk_shift >= result.chunks.size()) {
        return BigInt{};
    }
    // Erase the whole chunks that will be shifted.
    result.chunks.erase(result.chunks.begin(), std::next(result.chunks.begin(), to_signed(chunk_shift)));
    // Shift the bits within the remaining chunks.
    if (bit_shift != 0) {
        ChunkType carry = 0;
        for (size_t i = result.chunks.size(); i-- > 0;) {
            // Get the bits that will be shifted out of the current chunk and store them in carry.
            // Append the carry from the previous chunk to the current chunk.
            ChunkType new_carry = result.chunks[i] << (chunk_bits - bit_shift);
            result.chunks[i] = (result.chunks[i] >> bit_shift) | carry;
            carry = new_carry;
        }
        // Clear out any leading zero chunks that may have been created.
        result.remove_leading_zeroes();
    }
    return result;
}
auto BigInt::operator+=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this + rhs;
    return *this;
}
auto BigInt::operator-=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this - rhs;
    return *this;
}
auto BigInt::operator*=(BigInt const &rhs) noexcept -> BigInt &
{
    *this = *this * rhs;
    return *this;
}
auto BigInt::operator/=(BigInt const &rhs) -> BigInt &
{
    *this = *this / rhs;
    return *this;
}
auto BigInt::operator%=(BigInt const &rhs) -> BigInt &
{
    *this = *this % rhs;
    return *this;
}
auto BigInt::operator<<=(size_t rhs) noexcept -> BigInt &
{
    *this = *this << rhs;
    return *this;
}
auto BigInt::operator>>=(size_t rhs) noexcept -> BigInt &
{
    *this = *this >> rhs;
    return *this;
}
auto BigInt::operator++() noexcept -> BigInt &
{
    *this += one;
    return *this;
}
auto BigInt::operator--() noexcept -> BigInt &
{
    *this -= one;
    return *this;
}
auto BigInt::operator++(int) noexcept -> BigInt
{
    BigInt result{*this};
    *this += one;
    return result;
}
auto BigInt::operator--(int) noexcept -> BigInt
{
    BigInt result{*this};
    *this -= one;
    return result;
}
auto BigInt::operator<=>(BigInt const &rhs) const noexcept -> std::strong_ordering
{
    if (this->is_zero() && rhs.is_zero()) {
        return std::strong_ordering::equal;
    }
    if (negative != rhs.negative) {
        return negative ? std::strong_ordering::less : std::strong_ordering::greater;
    }
    return negative ? rhs.compare_magnitude(*this) : compare_magnitude(rhs);
}
auto BigInt::operator==(BigInt const &rhs) const noexcept -> bool
{
    return (*this <=> rhs) == std::strong_ordering::equal;
}
BigInt::operator std::string() const
{
    return format_to_base(Base::Decimal);
}
auto BigInt::abs() const noexcept -> BigInt
{
    BigInt result{*this};
    result.negative = false;
    return result;
}
auto operator<<(std::ostream &os, BigInt const &num) -> std::ostream &
{
    return os << static_cast<std::string>(num);
}
auto operator""_bi(char const *num) -> BigInt
{
    return BigInt{num};
}
auto BigInt::div(BigInt const &num, BigInt const &denom) -> std::pair<BigInt, BigInt>
{
    if (denom.is_zero()) {
        throw std::domain_error("Division by zero");
    }
    if (num.is_zero()) {
        return {BigInt{}, BigInt{}};
    }
    if (num.compare_magnitude(denom) == std::strong_ordering::less) {
        return {BigInt{}, num};
    }
    BigInt quotient{};
    BigInt remainder{num.abs()};
    // Perform long division:
    // 1. Find the largest multiple of the denominator that fits in the current remainder.
    // 2. Subtract the multiple from the remainder.
    // 3. Repeat until the remainder is less than the denominator.
    while (remainder.compare_magnitude(denom) != std::strong_ordering::less) {
        BigInt temp{denom.abs()};
        // Approximate the amount of shifts needed to align the most significant bit of the denominator with the
        // most significant bit of the remainder.
        size_t shift = remainder.bit_count() - temp.bit_count();
        // Align the most significant bit of the denominator with the most significant bit of the remainder.
        temp <<= shift;
        // If the denominator is still greater than the remainder, shift it to the right once.
        // This will guarantee that the denominator is less than the remainder.
        if (temp.compare_magnitude(remainder) == std::strong_ordering::greater) {
            temp >>= 1;
            --shift;
        }
        // Subtract the multiple from the remainder, and add the multiplier to the quotient.
        remainder -= temp;
        quotient += one << shift;
    }
    // For remainder, the sign is always the same as the dividend.
    remainder.negative = num.negative;
    quotient.negative = num.negative != denom.negative;
    return {quotient, remainder};
}
auto BigInt::pow(size_t power) const noexcept -> BigInt
{
    // x^0 = 1
    // NOTE: 0^0 also returns 1.
    if (power == 0) {
        return one;
    }
    // x^1 = x
    // 0^x = 0
    // 1^x = 1
    if (power == 1 || is_zero() || *this == one) {
        return *this;
    }
    static constexpr auto mask = static_cast<size_t>(1) << (sizeof(size_t) * 8 - 1);
    auto const power_leading_zeroes = static_cast<size_t>(std::countl_zero(power));
    // Get amount of bits in the power excluding leading zeroes.
    auto const power_bit_count = (sizeof(size_t) * 8) - power_leading_zeroes;
    // Get rid of leading zeroes so that we can iterate through the bits of the power.
    power <<= power_leading_zeroes;
    BigInt result(1);
    // Iterate through the bits of the power, starting from the most significant bit.
    // Square the result in each iteration, and multiply it by the base if the bit is set.
    // After each iteration, left shift the power by 1 to get the next bit.
    for (size_t i = 0; i < power_bit_count; ++i) {
        result *= result;
        if ((power & mask) != 0) {
            result *= *this;
        }
        power <<= 1;
    }
    return result;
}
auto BigInt::bit_count() const -> size_t
{
    return (chunks.size() * chunk_bits) - static_cast<size_t>(std::countl_zero(chunks.back()));
}
auto BigInt::get_bit_at(size_t index) const -> bool
{
    size_t const chunk_index = index / chunk_bits;
    size_t const bit_index = index % chunk_bits;
    // Even though the chunks are stored in little endian, the bits in a chunk are stored in big endian.
    return static_cast<bool>((chunks[chunk_index] >> bit_index) & 1);
}
auto BigInt::is_zero() const -> bool
{
    return chunks.size() == 1 && chunks[0] == 0;
}
void BigInt::remove_leading_zeroes()
{
    while (chunks.size() > 1 && chunks.back() == 0) {
        chunks.pop_back();
    }
}
auto BigInt::compare_magnitude(BigInt const &rhs) const noexcept -> std::strong_ordering
{
    if (chunks.size() != rhs.chunks.size()) {
        return chunks.size() <=> rhs.chunks.size();
    }
    for (size_t i = chunks.size(); i-- > 0;) {
        if (chunks[i] != rhs.chunks[i]) {
            return chunks[i] <=> rhs.chunks[i];
        }
    }
    return std::strong_ordering::equal;
}
auto BigInt::add_magnitude(BigInt const &rhs) const noexcept -> BigInt
{
    assert(compare_magnitude(rhs) != std::strong_ordering::less);
    BigInt result{*this};
    ChunkType carry = 0;
    for (size_t i = 0; i < rhs.chunks.size(); ++i) {
        // Check if the current chunk can be added to without overflow.
        if (result.chunks[i] < chunk_max - rhs.chunks[i] - carry) {
            result.chunks[i] += rhs.chunks[i] + carry;
            carry = 0;
        } else {
            // Add with overflow and carry to next chunk.
            result.chunks[i] += rhs.chunks[i] + carry;
            carry = 1;
        }
    }
    // Add carry to the rest of the *this number.
    if (carry != 0) {
        for (size_t i = rhs.chunks.size(); i < result.chunks.size(); ++i) {
            if (result.chunks[i] == chunk_max) {
                result.chunks[i] = 0;
            } else {
                result.chunks[i] += carry;
                carry = 0;
                break;
            }
        }
    }
    // Add carry to the end of the number.
    if (carry != 0) {
        result.chunks.push_back(1);
    }
    return result;
}
auto BigInt::subtract_magnitude(BigInt const &rhs) const noexcept -> BigInt
{
    assert(compare_magnitude(rhs) != std::strong_ordering::less);
    BigInt result{*this};
    ChunkType borrow = 0;
    for (size_t i = 0; i < rhs.chunks.size(); ++i) {
        // Check if the current chunk can be subtracted from without underflow.
        if (result.chunks[i] >= rhs.chunks[i] + borrow) {
            result.chunks[i] -= rhs.chunks[i] + borrow;
            borrow = 0;
        } else {
            // Subtract with underflow and borrow from next chunk.
            result.chunks[i] -= rhs.chunks[i] + borrow;
            borrow = 1;
        }
    }
    // Subtract borrow from the rest of the *this number.
    if (borrow != 0) {
        for (size_t i = rhs.chunks.size(); i < result.chunks.size(); ++i) {
            if (result.chunks[i] == 0) {
                result.chunks[i] = chunk_max;
            } else {
                result.chunks[i] -= borrow;
                borrow = 0;
                break;
            }
        }
    }
    // Borrow cannot be 1 at the end of the number since rhs is smaller or equal.
    assert(borrow == 0);
    // Remove leading zeroes.
    result.remove_leading_zeroes();
    return result;
}
#include <algorithm>
#include <cassert>
#include <cmath>
#include <format>
#include <stdexcept>
#include <utility>
#include "bigint.hpp"
using namespace BI;
using namespace BI::detail;
using ChunkType = BigInt::ChunkType;
using DataType = BigInt::DataType;
// Character representation of all digits
static constexpr std::array<char, 16>array digits = {
  '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
};
// Character representation of all digits in lowercase.
static constexpr std::array<char, 16>array digits_lowercase = {
  '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
};
static constexpr auto is_power_of_two(std::integral auto num) -> bool
{
    return num != 0 && (num & (num - 1)) == 0;
}
auto BigInt::is_valid_digit(Base base, char c) -> bool
{
    switch (base) {
    case Base::Binary:
        return c == '0' || c == '1';
    case Base::Octal:
        return c >= '0' && c <= '7';
    case Base::Decimal:
        return std::isdigit(c) != 0;
    case Base::Hexadecimal:
        return std::isxdigit(c) != 0;
    }
}
auto BigInt::char_to_digit(Base base, char c) -> ChunkType
{
    if (!is_valid_digit(base, c)) {
        throw std::invalid_argument(std::format("Invalid digit: {}", c));
    }
    switch (base) {
    case Base::Binary:
    case Base::Octal:
    case Base::Decimal:
        return static_cast<ChunkType>(c - '0');
    case Base::Hexadecimal:
        return static_cast<ChunkType>(std::isdigit(c) != 0 ? (c - '0') : (std::tolower(c) - 'a' + 10));
    }
}
auto BigInt::long_divide(std::string_view num, std::string "ient, Base base, ChunkType divisor) -> ChunkType
{
    // Clear quotient string and reserve space.
    quotient.clear();
    quotient.reserve(num.size());
    // Numeric value of base.
    auto const base_num = std::to_underlying(base);
    ChunkType dividend = 0;
    for (char digit : num) {
        dividend = dividend * base_num + char_to_digit(base, digit);
        // If dividend >= divisor, divide and store the quotient
        if (dividend >= divisor) {
            ChunkType quot = dividend / divisor;
            ChunkType rem = dividend % divisor;
            assert(base_num < digits.size() && quot < base_num);  // Should always be true for valid bases.
            // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-constant-array-index)
            quotient += digits[quot];
            dividend = rem;
        } else if (!quotient.empty()) {
            // If dividend is less than divisor and quotient is non-empty, add '0'.
            quotient += '0';
        }
    }
    return dividend;  // Return the remainder (which is the dividend now).
}
void BigInt::power_of_two_base_to_binary(std::string_view num, Base base)
{
    // Numeric value of base. Used for base conversion.
    auto const base_num = std::to_underlying(base);
    if (!is_power_of_two(base_num)) {
        throw std::invalid_argument("Base must be a power of 2");
    }
    // The number of bits needed to store a digit in the base.
    auto const bits_per_digit = static_cast<size_t>(std::countr_zero(base_num));
    // Clear the chunks vector.
    chunks.clear();
    ChunkType current_chunk{};
    size_t current_chunk_bits = 0;
    for (size_t i = num.size(); i-- > 0;) {
        // Bits that will fit in the current chunk.
        size_t const added_bits = std::min(bits_per_digit, chunk_bits - current_chunk_bits);
        // Bits that won't fit in the current chunk.
        size_t const remaining_bits = bits_per_digit - added_bits;
        // Digit corresponding to the current character.
        ChunkType const digit = char_to_digit(base, num[i]);
        // Digit without the bits that won't fit in the current chunk.
        ChunkType const digit_masked = digit & ((1 << added_bits) - 1);
        // Add the bits corresponding to the digit to the current chunk.
        current_chunk = (digit_masked << current_chunk_bits) | current_chunk;
        // Increment the number of bits in the current chunk.
        current_chunk_bits += added_bits;
        if (current_chunk_bits == chunk_bits) {
            // The current chunk is full, push it to the chunks vector and reset the chunk.
            chunks.push_back(current_chunk);
            // Some bits may remain in the digit, add them to the new chunk.
            current_chunk = remaining_bits > 0 ? (digit >> added_bits) : 0;
            current_chunk_bits = remaining_bits;
        }
    }
    if (current_chunk_bits > 0) {
        // If the last chunk is not full, push it to the chunks vector.
        chunks.push_back(current_chunk);
    }
}
void BigInt::decimal_base_to_binary(std::string_view num)
{
    // Clear the chunks vector.
    chunks.clear();
    // Maximum number that can fit in a half-sized chunk.
    static constexpr auto const half_chunk_bits = chunk_bits / 2;
    // The process:
    // 1. Long divide the number by 2 ^ half_chunk_bits, the remainder will span half of the chunk. Only half of the
    //    chunk is processed at a time to avoid overflow.
    // 2. Every two iterations, push the full chunk to the chunks vector.
    // 3. Repeat until the number is 0.
    bool is_half_chunk = false;
    ChunkType current_chunk{0};
    static auto divisor = static_cast<ChunkType>(1) << half_chunk_bits;
    std::string current_num{num};
    std::string new_num;
    while (!current_num.empty()) {
        // If the previous iteration was the first half of the chunk, the current iteration will be the second half,
        // and vice versa.
        is_half_chunk = !is_half_chunk;
        // Long divide the current number by the divisor and store the remainder.
        // The resulting quotient is stored in the new_num string.
        ChunkType remainder = long_divide(current_num, new_num, Base::Decimal, divisor);
        // Store the remainder in the current chunk.
        if (is_half_chunk) {
            // First half of the chunk, just store the remainder.
            current_chunk = remainder;
        } else {
            // Full chunk, push it to the chunks vector and reset the chunk.
            // The most significant half of the chunk is the remainder from this iteration.
            chunks.push_back((remainder << half_chunk_bits) | current_chunk);
            current_chunk = 0;
        }
        // Swap the current and new num strings to avoid reallocations.
        std::swap(current_num, new_num);
    }
    // If the last chunk is not full, push it to the chunks vector.
    if (is_half_chunk) {
        chunks.push_back(current_chunk);
    }
}
void BigInt::base_to_binary(std::string_view num, Base base)
{
    if (is_power_of_two(std::to_underlying(base))) {
        power_of_two_base_to_binary(num, base);
    } else {
        assert(base == Base::Decimal);
        decimal_base_to_binary(num);
    }
}
auto BigInt::format_to_power_of_two_base(Base base, bool add_prefix, bool capitalize) const noexcept -> std::string
{
    // Amount of bits that fit in a single digit of the specified base.
    auto const digit_bits = static_cast<ChunkType>(std::countr_zero(std::to_underlying(base)));
    size_t bit_count = this->bit_count();
    std::string result;
    // Reserve enough space for the result.
    result.reserve((bit_count / digit_bits) + 1 + (add_prefix ? 2 : 0));
    // Iterate through chunks of digit_bits bits and convert them to the specified base.
    size_t i = 0;
    while (i < bit_count) {
        // Number of bits to extract from the current chunk.
        size_t const extracted_bits = std::min(digit_bits, bit_count - i);
        ChunkType digit = 0;
        for (size_t j = 0; j < extracted_bits; ++j) {
            digit |= static_cast<ChunkType>(get_bit_at(i + j)) << j;
        }
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-constant-array-index)
        result += capitalize ? digits[digit] : digits_lowercase[digit];
        i += extracted_bits;
    }
    // Reverse the result string to make it MSB first.
    std::ranges::reverse(result);
    if (result.empty()) {
        result = "0";
    }
    std::string prefix;
    if (add_prefix) {
        switch (base) {
        case Base::Binary:
            prefix = capitalize ? "0B" : "0b";
            break;
        case Base::Octal:
            prefix = "0";
            break;
        case Base::Hexadecimal:
            prefix = capitalize ? "0X" : "0x";
            break;
        case Base::Decimal:
            assert(false);  // Should never happen.
        }
    }
    result.insert(0, prefix);
    return result;
}
auto BigInt::format_to_decimal() const -> std::string
{
    BigInt quotient{*this};
    BigInt remainder;
    static auto const ten = BigInt(10);
    static auto const log2_10 = std::log2(10);
    std::string result;
    // Reserve enough space for the result.
    result.reserve(static_cast<size_t>(std::ceil(static_cast<long double>(quotient.bit_count()) / log2_10) + 1));
    // Keep dividing modulo 10 and store the remainder in a string.
    while (!quotient.is_zero()) {
        std::tie(quotient, remainder) = div(quotient, ten);
        // Append the remainder to the result string.
        // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-constant-array-index)
        result += digits[remainder.chunks[0]];
    }
    std::ranges::reverse(result);
    return result.empty() ? "0" : result;
}
auto BigInt::format_to_base(Base base, bool add_prefix, bool capitalize) const -> std::string
{
    if (is_zero()) {
        return "0";
    }
    if (negative) {
        return "-" + (-(*this)).format_to_base(base, add_prefix, capitalize);
    }
    if (is_power_of_two(std::to_underlying(base))) {
        return format_to_power_of_two_base(base, add_prefix, capitalize);
    }
    assert(base == Base::Decimal);
    return format_to_decimal();
}