cplib-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub hitonanode/cplib-cpp

:heavy_check_mark: other_algorithms/test/north_east_lattice_paths.bruteforce.test.cpp

Depends on

Code

#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A" // DUMMY

#include "../north_east_lattice_paths.hpp"
#include "../../modint.hpp"
#include "../../convolution/ntt.hpp"

using mint = ModInt998244353;

#include <iostream>
#include <map>
#include <utility>
using namespace std;

void Check() {
    const int n = 30, d = 10;
    vector<mint> init(d);
    for (int i = 0; i < d; ++i) init[i] = mint(2).pow(i);

    map<pair<int, int>, mint> dp;
    for (int x = 0; x < d; ++x) dp[{x, 0}] = init[x];

    for (int x = -n; x <= n; ++x) {
        for (int y = -n; y <= n; ++y) {
            if (x + 1 <= n) dp[{x + 1, y}] += dp[{x, y}];
            if (y + 1 <= n) dp[{x, y + 1}] += dp[{x, y}];
        }
    }

    for (int dx_init = -n; dx_init <= n; ++dx_init) {
        for (int y = -n; y <= n; ++y) {
            for (int len = 1; dx_init + len - 1 <= n; ++len) {
                vector<mint> expected(len);
                for (int i = 0; i < len; ++i) expected[i] = dp[{dx_init + i, y}];

                auto res = NorthEastLatticePathsParallel<mint>(
                    init, dx_init, y, len, [&](auto &&a, auto &&b) { return nttconv(a, b); });
                if (res != expected) {
                    cerr << "Failed Parallel: " << dx_init << ' ' << y << ' ' << len
                         << "\nExpected: ";
                    for (auto e : expected) cerr << " " << e;
                    cerr << "\nResult: ";
                    for (auto e : res) cerr << " " << e;
                    cerr << '\n';
                    exit(1);
                }
            }
        }
    }

    for (int x = -n; x <= n; ++x) {
        for (int dy_init = -n; dy_init <= n; ++dy_init) {
            for (int len = 1; dy_init + len - 1 <= n; ++len) {
                vector<mint> expected(len);
                for (int i = 0; i < len; ++i) expected[i] = dp[{x, dy_init + i}];

                auto res = NorthEastLatticePathsVertical<mint>(
                    init, x, dy_init, len, [&](auto &&a, auto &&b) { return nttconv(a, b); });
                if (res != expected) {
                    cerr << "Failed Vertical: " << x << ' ' << dy_init << ' ' << len
                         << "\nExpected: ";
                    for (auto e : expected) cerr << " " << e;
                    cerr << "\nResult: ";
                    for (auto e : res) cerr << " " << e;
                    cerr << '\n';
                    exit(1);
                }
            }
        }
    }
}

int main() {
    // Check() exits with status 1 on the first mismatch; reaching the line below means
    // every comparison passed.
    Check();
    puts("Hello World");
    return 0;
}
#line 1 "other_algorithms/test/north_east_lattice_paths.bruteforce.test.cpp"
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A" // DUMMY

#line 2 "other_algorithms/north_east_lattice_paths.hpp"
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// (i, 0) (0 <= i < bottom.size()) -> (dx_init + j, y) (0 <= j < len)
// Transfer weights from the bottom row to a horizontal segment, moving only right (+x) or
// up (+y).
// Input: bottom[i] = Initial weight at (i, 0)
// Output: top[j] = weight at (dx_init + j, y) after transition, i.e.
//         sum over i <= dx_init + j of bottom[i] * binom((dx_init + j - i) + y, y).
// Cells with negative x receive 0, as does every cell when y < 0 or bottom is empty.
// `convolve(f, g)` must return the linear convolution of f and g (size |f| + |g| - 1).
template <class MODINT>
std::vector<MODINT>
NorthEastLatticePathsParallel(const std::vector<MODINT> &bottom, long long dx_init, long long y,
                              int len, auto convolve) {
    const long long min_x = std::max(dx_init, 0LL), max_x = dx_init + len - 1;
    // Nothing reachable: the segment lies entirely left of x = 0, below the bottom row, or
    // there are no sources at all (an empty bottom would also break the index math below).
    if (max_x < 0 or y < 0 or bottom.empty()) return std::vector<MODINT>(len);

    const long long last = (long long)bottom.size() - 1;
    // Horizontal displacement X - i needed for targets X in [min_x, max_x]. Displacements
    // below min_shift could only come from sources i > last, which do not exist.
    const long long min_shift = std::max<long long>(0, min_x - last), max_shift = max_x;

    std::vector<MODINT> trans(max_shift - min_shift + 1);
    for (int i = 0; i < (int)trans.size(); ++i)
        trans[i] = MODINT::binom(min_shift + i + y, y); // can be made faster if needed
    std::vector<MODINT> top = convolve(trans, bottom);

    // top[k] now holds the weight at X = min_shift + k; keep only X in [min_x, max_x].
    top.erase(top.begin(), top.begin() + std::min(min_x, last));
    top.resize(max_x - min_x + 1);
    if (dx_init < 0) top.insert(top.begin(), -dx_init, MODINT{}); // columns with x < 0
    top.shrink_to_fit();
    assert((int)top.size() == len);

    return top;
}

// (i, 0) (0 <= i < bottom.size()) -> (x, dy_init + j) (0 <= j < len)
// Transfer weights from the bottom row to a vertical segment on column x, moving only
// right (+x) or up (+y).
// Input: bottom[i] = Initial weight at (i, 0)
// Output: right[j] = weight at (x, dy_init + j) after transition, i.e.
//         sum over i <= x of bottom[i] * binom((x - i) + y, y) with y = dy_init + j.
// Rows with y < 0 receive 0, as does every row when x < 0 or bottom is empty.
// `convolve(f, g)` must return the linear convolution of f and g (size |f| + |g| - 1).
template <class MODINT>
std::vector<MODINT> NorthEastLatticePathsVertical(const std::vector<MODINT> &bottom, int x,
                                                  int dy_init, int len, auto convolve) {
    const int ylo = std::max(dy_init, 0), yhi = dy_init + len; // rows [ylo, yhi) are computed
    // Nothing to compute: no row y >= 0 in the segment, the column is left of every source,
    // or there are no sources (an empty bottom would make the convolution empty).
    if (yhi <= ylo or x < 0 or bottom.empty()) return std::vector<MODINT>(len);

    // (i, 0) -> (x, y) contributes binom(x - i + y, y) = (x - i + y)! / ((x - i)! y!).
    // Scale f[i] = bottom[i] / (x - i)! (only i <= x can reach column x) and convolve with
    // g[k] = (k + ylo)!, so conv[x + y - ylo] = sum_i f[i] * (x - i + y)!.
    std::vector<MODINT> f(bottom.begin(), bottom.begin() + std::min<int>(bottom.size(), x + 1));
    for (int i = 0; i < (int)f.size(); ++i) f[i] *= MODINT::facinv(x - i);

    // The largest index read below is x + (yhi - 1) - ylo, so g needs x + yhi - ylo terms.
    std::vector<MODINT> trans(x + yhi - ylo);
    for (int i = 0; i < (int)trans.size(); ++i) trans[i] = MODINT::fac(i + ylo);
    const std::vector<MODINT> conv = convolve(trans, f);

    std::vector<MODINT> right(yhi - ylo);
    for (int y = ylo; y < yhi; ++y) right.at(y - ylo) = conv.at(x + y - ylo) * MODINT::facinv(y);

    if (dy_init < 0) right.insert(right.begin(), -dy_init, MODINT{}); // rows with y < 0
    right.shrink_to_fit();
    assert((int)right.size() == len);

    return right;
}

// Solve the rectangular DP below in O((h + w) log(h + w)) (if `convolve()` is O(n log n)),
// where h = left.size() and w = bottom.size():
// 1. dp[0, 0:h] += left[:]
// 2. dp[0:w, 0] += bottom[:]
// 3. dp[i, j] += dp[i-1, j] + dp[i, j-1]
// 4. return (right = dp[w-1, 0:h], top = dp[0:w, h-1])
// If h == 0 or w == 0, {left, bottom} is returned unchanged.
template <class MODINT>
auto NorthEastLatticePathsInRectangle(const std::vector<MODINT> &left,
                                      const std::vector<MODINT> &bottom, auto convolve) {
    struct Result {
        std::vector<MODINT> right, top;
    };

    const int h = left.size(), w = bottom.size();
    if (h == 0 or w == 0) return Result{left, bottom};

    // Transposing the grid swaps the left column and the bottom row. Each source edge
    // therefore needs one "parallel" transfer (to the opposite edge) and one "vertical"
    // transfer (to the adjacent far edge); the four pieces are summed per target edge.
    Result res{
        NorthEastLatticePathsParallel(left, 0, w - 1, h, convolve),   // left   -> right
        NorthEastLatticePathsParallel(bottom, 0, h - 1, w, convolve), // bottom -> top
    };
    const auto bottom_to_right = NorthEastLatticePathsVertical(bottom, w - 1, 0, h, convolve);
    const auto left_to_top = NorthEastLatticePathsVertical(left, h - 1, 0, w, convolve);
    for (int j = 0; j < (int)bottom_to_right.size(); ++j) res.right[j] += bottom_to_right[j];
    for (int i = 0; i < (int)left_to_top.size(); ++i) res.top[i] += left_to_top[i];
    return res;
}

// a) Lattice paths from (*, 0) / (0, *). Count paths ending at (w - 1, *) or absorbed at
//    (i, ub[i])s, i.e. paths that would step above the top allowed cell of column i.
// b) In other words, count sequences satisfying 0 <= a_i < ub[i]
// c) In other words, solve DP below (w = bottom.size()):
//   1. dp[0:w, 0] += bottom[:], dp[0, 0:ub[0]] += left[:]
//   2. dp[i, j + 1] += dp[i, j] (j + 1 < ub[i])
//   3. dp[i + 1, j] += dp[i, j] (j < ub[i])
//   4. return dp[w-1, 0:ub[w-1]] as right, dp[i, ub[i] - 1] as top
//      (top[i] is bottom[i] unchanged for a column with ub[i] == 0)
// Preconditions (asserted): ub.size() == bottom.size(), left.size() == ub[0], ub[0] >= 0.
// If ub is empty, {left, {}} is returned.
// Complexity: O((h + w) (log(h + w))^2) (if `convolve()` is O(n log n))
// Requirement: ub is non-decreasing
template <class MODINT>
auto NorthEastLatticePathsBounded(const std::vector<int> &ub, const std::vector<MODINT> &left,
                                  const std::vector<MODINT> &bottom, auto convolve) {
    struct Result {
        std::vector<MODINT> right, top;
    };

    assert(ub.size() == bottom.size());
    if (ub.empty()) return Result{left, {}};

    assert(ub.front() == (int)left.size());
    assert(ub.front() >= 0);
    for (int i = 1; i < (int)ub.size(); ++i) assert(ub[i] >= ub[i - 1]);

    // Every column has height 0: nothing moves, each bottom weight is its own "top".
    if (ub.back() <= 0) return Result{{}, bottom};

    if (const int n = bottom.size(); n > 64 and ub.back() > 64) { // 64: parameter
        // Split columns into [0, l) and [l, n). Because ub is non-decreasing, every column
        // in [l, n) allows at least rows [0, b), b = ub[l].
        const int l = n / 2, r = n - l;
        const int b = ub[l];

        // Left half, solved recursively. right1 = dp[l-1, 0:ub[l-1]], padded with zeros up
        // to height b (rows ub[l-1]..b-1 receive nothing from the left).
        auto [right1, top1] = NorthEastLatticePathsBounded<MODINT>(
            {ub.begin(), ub.begin() + l}, left, {bottom.begin(), bottom.begin() + l}, convolve);
        right1.resize(b);
        // Rows [0, b) of columns [l, n) form an unconstrained r x b rectangle, fed by right1
        // from the left and bottom[l:] from below. out2 is its top row (row b - 1).
        auto [right, out2] = NorthEastLatticePathsInRectangle<MODINT>(
            right1, {bottom.begin() + l, bottom.end()}, convolve);

        // Rows >= b of columns [l, n): recurse with bounds shifted down by b, fed from below
        // by out2 and with nothing entering from the left (ub_next[0] == 0).
        std::vector<int> ub_next(r);
        for (int i = 0; i < r; ++i) ub_next[i] = ub[l + i] - b;
        const auto [right3, top3] =
            NorthEastLatticePathsBounded<MODINT>(ub_next, {}, out2, convolve);
        // Column n-1 = rows [0, b) from the rectangle followed by rows [b, ub[n-1]).
        right.insert(right.end(), right3.begin(), right3.end());
        top1.insert(top1.end(), top3.begin(), top3.end());
        return Result{right, top1};
    } else {
        // Small case: sweep columns left to right. After iteration i, dp holds
        // dp[i, 0:ub[i]]; rows newly exposed by resize get nothing from column i-1.
        std::vector<MODINT> dp = left;
        std::vector<MODINT> top = bottom;
        dp.reserve(ub.back());
        for (int i = 0; i < n; ++i) {
            dp.resize(ub[i], 0);
            if (dp.empty()) continue; // ub[i] == 0: top[i] stays bottom[i]
            dp[0] += bottom[i];
            // Upward moves within column i are a prefix sum.
            for (int j = 1; j < (int)dp.size(); ++j) dp[j] += dp[j - 1];
            top[i] = dp.back();
        }
        return Result{dp, top};
    }
}

// Lattice paths from (0, *). Count paths ending at (w - 1, *). In other words, solve DP below:
//   1. dp[0, lb[0]:ub[0]] += left[:]
//   2. dp[i, j + 1] += dp[i, j] (j + 1 < ub[i])
//   3. dp[i + 1, j] += dp[i, j] (lb[i + 1] <= j)
//   4. return dp[w-1, lb[w-1]:ub[w-1]]
// Preconditions (asserted): lb.size() == ub.size(), left.size() == ub[0] - lb[0].
// If some column has lb[i] >= ub[i], every path is blocked and max(0, ub[w-1] - lb[w-1])
// zeros are returned. If lb/ub are empty, `left` is returned unchanged.
// Complexity: O((h + w) (log(h + w))^2) (if `convolve()` is O(n log n))
// Requirement: lb/ub is non-decreasing
template <class MODINT>
std::vector<MODINT>
NorthEastLatticePathsBothBounded(const std::vector<int> &lb, const std::vector<int> &ub,
                                 const std::vector<MODINT> &left, auto convolve) {
    assert(lb.size() == ub.size());

    const int n = ub.size();
    if (n == 0) return left;

    assert((int)left.size() == ub[0] - lb[0]);
    for (int i = 1; i < n; ++i) {
        assert(lb[i - 1] <= lb[i]);
        assert(ub[i - 1] <= ub[i]);
    }

    for (int i = 0; i < n; ++i) {
        // Clamp: the last band itself may be empty (lb.back() >= ub.back()), and a negative
        // size would make the vector constructor throw.
        if (lb[i] >= ub[i]) return std::vector<MODINT>(std::max(0, ub.back() - lb.back()));
    }

    // The walk proceeds in stages. dp_left holds the values of column x on rows
    // [lb[x], ub[x]) (missing rows are zero).
    int x = 0;
    std::vector<MODINT> dp_left = left;
    std::vector<int> tmp_ub;
    while (true) {
        dp_left.resize(ub[x] - lb[x], MODINT{0});

        // x1: first column after x whose ceiling is above ub[x].
        // x2: first column >= x1 where row ub[x] - 1 is below the floor (lb >= ub[x]).
        // x3: first column >= x2 whose floor is above ub[x].
        const int x1 = std::upper_bound(ub.begin() + x + 1, ub.begin() + n, ub[x]) - ub.begin();
        const int x2 = std::lower_bound(lb.begin() + x1, lb.begin() + n, ub[x]) - lb.begin();
        const int x3 = std::upper_bound(lb.begin() + x2, lb.begin() + n, ub[x]) - lb.begin();

        // Transposed bounds of the band rows [lb[x], ub[x]) over columns [x, x2): tmp_ub[d]
        // is the number of consecutive columns (starting at x) in which row lb[x] + d is
        // still above the floor. The suffix-min keeps it non-decreasing in d.
        tmp_ub.assign(dp_left.size(), x2 - x);
        for (int i = x2 - 1; i >= x; --i) {
            if (const int d = lb[i] - lb[x] - 1; d >= 0 and d < (int)tmp_ub.size()) {
                tmp_ub[d] = i - x;
            }
        }
        for (int j = (int)tmp_ub.size() - 1; j; --j)
            tmp_ub[j - 1] = std::min(tmp_ub[j - 1], tmp_ub[j]);

        // Solve the band in transposed coordinates (column x becomes the bottom row).
        // next_dp: values on row ub[x] - 1 for columns [x, x2).
        // southeast[d]: value of row lb[x] + d at the last column where it is allowed.
        auto [next_dp, southeast] = NorthEastLatticePathsBounded(
            tmp_ub, std::vector<MODINT>(tmp_ub.front()), dp_left, convolve);
        // Columns [x, x1) cannot go above ub[x] - 1; only columns [x1, x2) feed upward.
        next_dp.erase(next_dp.begin(), next_dp.begin() + (x1 - x));
        assert((int)next_dp.size() == x2 - x1);

        if (x1 < x3) {
            // Rows >= ub[x] over columns [x1, x3), fed from row ub[x] - 1 (zero for columns
            // [x2, x3), where that row is below the floor). Result: column x3 - 1 from row
            // ub[x] upward.
            next_dp.resize(x3 - x1, MODINT{0});
            tmp_ub.resize(x3 - x1);
            for (int i = x1; i < x3; ++i) tmp_ub[i - x1] = ub[i] - ub[x];
            next_dp = NorthEastLatticePathsBounded(
                          tmp_ub, std::vector<MODINT>(tmp_ub.front()), next_dp, convolve)
                          .right;
        } else {
            next_dp.clear();
        }

        if (x3 >= n) {
            // Last stage: column n-1 = band rows reaching it (southeast) followed by the rows
            // from ub[x] upward; drop rows below lb[n-1].
            std::vector<MODINT> ret = southeast;
            ret.insert(ret.end(), next_dp.begin(), next_dp.end());
            ret.erase(ret.begin(), ret.begin() + (lb[n - 1] - lb[x]));
            assert((int)ret.size() == ub[n - 1] - lb[n - 1]);
            return ret;
        } else {
            // Move right into column x3, dropping rows below its floor lb[x3] (> ub[x]).
            next_dp.erase(next_dp.begin(), next_dp.begin() + (lb[x3] - ub[x]));
            x = x3;
            dp_left = std::move(next_dp);
        }
    }
}

// Count nonnegative non-decreasing integer sequences `a` (|a| = ub.size()) satisfying
// a[i] < ub[i] for every i. The empty sequence (ub empty) counts as one, consistent with
// the (lb, ub) overload.
// Complexity: O(n log^2(n)) (if `convolve()` is O(n log n))
template <class MODINT> MODINT CountBoundedMonotoneSequence(std::vector<int> ub, auto convolve) {
    const int n = ub.size();
    if (n == 0) return MODINT{1};
    // a is non-decreasing, so a[i-1] <= a[i] < ub[i]: tightening ub to its suffix minimum
    // keeps the count and makes ub non-decreasing (as NorthEastLatticePathsBounded requires).
    for (int i = n - 1; i; --i) ub[i - 1] = std::min(ub[i - 1], ub[i]);
    if (ub.front() <= 0) return MODINT{0};

    // A single path starts at (0, 0): column i is the index, the row is the value a[i].
    std::vector<MODINT> bottom(n), left(ub.front());
    bottom.front() = 1;
    const std::vector<MODINT> ret = NorthEastLatticePathsBounded(ub, left, bottom, convolve).right;
    return std::accumulate(ret.begin(), ret.end(), MODINT{});
}

// Count nonnegative non-decreasing integer sequences `a` satisfying lb[i] <= a[i] < ub[i].
// The empty sequence (n == 0) counts as one.
// Complexity: O(n log^2(n)) (if `convolve()` is O(n log n))
// https://noshi91.hatenablog.com/entry/2023/07/21/235339
// Verify: https://judge.yosupo.jp/problem/number_of_increasing_sequences_between_two_sequences
template <class MODINT>
MODINT CountBoundedMonotoneSequence(std::vector<int> lb, std::vector<int> ub, auto convolve) {
    assert(lb.size() == ub.size());

    const int len = ub.size();
    if (len == 0) return MODINT{1};

    // Monotonicity tightens the bounds: a[i] >= a[i-1] >= lb[i-1] and a[i-1] <= a[i] < ub[i].
    // Afterwards both bound sequences are non-decreasing.
    for (int i = 1; i < len; ++i) lb[i] = std::max(lb[i], lb[i - 1]);
    for (int i = len - 1; i > 0; --i) ub[i - 1] = std::min(ub[i], ub[i - 1]);

    for (int i = 0; i < len; ++i) {
        if (!(lb[i] < ub[i])) return MODINT{0}; // some a[i] has no admissible value
    }

    // Floors for the lattice walk, shifted one column right: column 0 keeps lb[0] and column i
    // uses lb[i - 1]. The final floor lb[len - 1] is applied when summing the answer.
    const int last_lb = lb.back();
    std::vector<int> walk_lb(len);
    walk_lb[0] = lb[0];
    for (int i = 1; i < len; ++i) walk_lb[i] = lb[i - 1];

    // One path starting at the lowest admissible row of column 0.
    std::vector<MODINT> start(ub.front() - walk_lb.front());
    start.at(0) = 1;

    // heights[j] is the weight at row walk_lb.back() + j of the last column; only rows
    // >= last_lb are admissible values of a[len - 1].
    const auto heights = NorthEastLatticePathsBothBounded(walk_lb, ub, start, convolve);
    return std::accumulate(heights.begin() + (last_lb - walk_lb.back()), heights.end(), MODINT{});
}
#line 3 "modint.hpp"
#include <iostream>
#include <set>
#line 6 "modint.hpp"

// Modular integer with compile-time modulus `md` (md > 1); the stored value is kept in
// [0, md). inv() / division rely on factorial tables and Fermat's little theorem
// (exponent md - 2), so they are meaningful only when md is prime.
// Factorials, inverse factorials and inverses are cached in the static tables `facs`,
// `facinvs` and `invs`. The tables grow on demand (never beyond md entries) and are
// mutated without any synchronization.
template <int md> struct ModInt {
    static_assert(md > 1);
    using lint = long long;
    constexpr static int mod() { return md; }
    // Smallest primitive root modulo md: factor md - 1, then try g = 1, 2, ... .
    // Computed on the first call and cached; -1 if none is found.
    static int get_primitive_root() {
        static int primitive_root = 0;
        if (!primitive_root) {
            primitive_root = [&]() {
                std::set<int> fac;
                int v = md - 1;
                for (lint i = 2; i * i <= v; i++)
                    while (v % i == 0) fac.insert(i), v /= i;
                if (v > 1) fac.insert(v);
                for (int g = 1; g < md; g++) {
                    bool ok = true;
                    for (auto i : fac)
                        if (ModInt(g).pow((md - 1) / i) == 1) {
                            ok = false;
                            break;
                        }
                    if (ok) return g;
                }
                return -1;
            }();
        }
        return primitive_root;
    }
    int val_;
    int val() const noexcept { return val_; }
    constexpr ModInt() : val_(0) {}
    // Internal setter: expects 0 <= v < 2 * md and subtracts md at most once.
    constexpr ModInt &_setval(lint v) { return val_ = (v >= md ? v - md : v), *this; }
    // Reduces any integer, including negative ones, into [0, md).
    constexpr ModInt(lint v) : val_(0) { _setval(v % md + md); }
    constexpr explicit operator bool() const { return val_ != 0; }
    constexpr ModInt operator+(const ModInt &x) const {
        return ModInt()._setval((lint)val_ + x.val_);
    }
    constexpr ModInt operator-(const ModInt &x) const {
        return ModInt()._setval((lint)val_ - x.val_ + md);
    }
    constexpr ModInt operator*(const ModInt &x) const {
        return ModInt()._setval((lint)val_ * x.val_ % md);
    }
    constexpr ModInt operator/(const ModInt &x) const {
        return ModInt()._setval((lint)val_ * x.inv().val() % md);
    }
    constexpr ModInt operator-() const { return ModInt()._setval(md - val_); }
    constexpr ModInt &operator+=(const ModInt &x) { return *this = *this + x; }
    constexpr ModInt &operator-=(const ModInt &x) { return *this = *this - x; }
    constexpr ModInt &operator*=(const ModInt &x) { return *this = *this * x; }
    constexpr ModInt &operator/=(const ModInt &x) { return *this = *this / x; }
    friend constexpr ModInt operator+(lint a, const ModInt &x) { return ModInt(a) + x; }
    friend constexpr ModInt operator-(lint a, const ModInt &x) { return ModInt(a) - x; }
    friend constexpr ModInt operator*(lint a, const ModInt &x) { return ModInt(a) * x; }
    friend constexpr ModInt operator/(lint a, const ModInt &x) { return ModInt(a) / x; }
    constexpr bool operator==(const ModInt &x) const { return val_ == x.val_; }
    constexpr bool operator!=(const ModInt &x) const { return val_ != x.val_; }
    constexpr bool operator<(const ModInt &x) const {
        return val_ < x.val_;
    } // To use std::map<ModInt, T>
    friend std::istream &operator>>(std::istream &is, ModInt &x) {
        lint t;
        return is >> t, x = ModInt(t), is;
    }
    constexpr friend std::ostream &operator<<(std::ostream &os, const ModInt &x) {
        return os << x.val_;
    }

    // this^n by binary exponentiation. n must be non-negative.
    constexpr ModInt pow(lint n) const {
        ModInt ans = 1, tmp = *this;
        while (n) {
            if (n & 1) ans *= tmp;
            tmp *= tmp, n >>= 1;
        }
        return ans;
    }

    // inv() uses the cached table for values below this limit.
    static constexpr int cache_limit = std::min(md, 1 << 21);
    static std::vector<ModInt> facs, facinvs, invs;

    // Extends facs / facinvs / invs to N entries (N clamped to md); no-op if already long enough.
    constexpr static void _precalculation(int N) {
        const int l0 = facs.size();
        if (N > md) N = md;
        if (N <= l0) return;
        facs.resize(N), facinvs.resize(N), invs.resize(N);
        for (int i = l0; i < N; i++) facs[i] = facs[i - 1] * i;
        facinvs[N - 1] = facs.back().pow(md - 2);
        for (int i = N - 2; i >= l0; i--) facinvs[i] = facinvs[i + 1] * (i + 1);
        for (int i = N - 1; i >= l0; i--) invs[i] = facinvs[i] * facs[i - 1];
    }

    // Multiplicative inverse (the inverse of 0 is reported as 0). Values below cache_limit
    // come from the table, larger ones from this^(md - 2).
    constexpr ModInt inv() const {
        if (this->val_ < cache_limit) {
            if (facs.empty()) facs = {1}, facinvs = {1}, invs = {0};
            while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
            return invs[this->val_];
        } else {
            return this->pow(md - 2);
        }
    }

    // n! (0 for n >= md).
    constexpr static ModInt fac(int n) {
        assert(n >= 0);
        if (n >= md) return ModInt(0);
        while (n >= int(facs.size())) _precalculation(facs.size() * 2);
        return facs[n];
    }

    // 1 / n! (0 for n >= md).
    constexpr static ModInt facinv(int n) {
        assert(n >= 0);
        if (n >= md) return ModInt(0);
        while (n >= int(facs.size())) _precalculation(facs.size() * 2);
        return facinvs[n];
    }

    // Double factorial n!! (0 for n >= md).
    constexpr static ModInt doublefac(int n) {
        assert(n >= 0);
        if (n >= md) return ModInt(0);
        long long k = (n + 1) / 2;
        return (n & 1) ? ModInt::fac(k * 2) / (ModInt(2).pow(k) * ModInt::fac(k))
                       : ModInt::fac(k) * ModInt(2).pow(k);
    }

    // C(n, r) via the factorial tables; 0 if r < 0 or r > n. Requires n < md for a
    // meaningful value (fac(n) is 0 otherwise).
    constexpr static ModInt nCr(int n, int r) {
        assert(n >= 0);
        if (r < 0 or n < r) return ModInt(0);
        return ModInt::fac(n) * ModInt::facinv(r) * ModInt::facinv(n - r);
    }

    // n! / (n - r)!; 0 if r < 0 or r > n.
    constexpr static ModInt nPr(int n, int r) {
        assert(n >= 0);
        if (r < 0 or n < r) return ModInt(0);
        return ModInt::fac(n) * ModInt::facinv(n - r);
    }

    // C(n, r) mod md for possibly huge n.
    // - n < md: uses the factorial tables when they already cover n, or when the brute-force
    //   work done so far (bruteforce_times) is at least n, so building the table is cheaper.
    // - otherwise: multiplies the falling factorial n (n - 1) ... (n - r + 1) by 1 / r!
    //   with r replaced by min(r, n - r); this is valid while that r is below md.
    // Lucas' theorem is not applied, so for n >= md and min(r, n - r) >= md the result is 0.
    static ModInt binom(long long n, long long r) {
        static long long bruteforce_times = 0;

        if (r < 0 or n < r) return ModInt(0);
        // nCr() goes through fac(n), which is 0 for n >= md, and takes `int` arguments, so the
        // table-based path is valid only for n < md.
        if (n < md and (n <= bruteforce_times or n < (int)facs.size())) return ModInt::nCr(n, r);

        r = std::min(r, n - r);
        assert((int)r == r);

        ModInt ret = ModInt::facinv(r);
        for (int i = 0; i < r; ++i) ret *= n - i;
        bruteforce_times += r;

        return ret;
    }

    // Multinomial coefficient, (k_1 + k_2 + ... + k_m)! / (k_1! k_2! ... k_m!)
    // Complexity: O(sum(ks))
    // Verify: https://yukicoder.me/problems/no/3178
    template <class Vec> static ModInt multinomial(const Vec &ks) {
        ModInt ret{1};
        int sum = 0;
        for (int k : ks) {
            assert(k >= 0);
            ret *= ModInt::facinv(k), sum += k;
        }
        return ret * ModInt::fac(sum);
    }
    template <class... Args> static ModInt multinomial(Args... args) {
        int sum = (0 + ... + args);
        ModInt result = (1 * ... * ModInt::facinv(args));
        return ModInt::fac(sum) * result;
    }

    // Catalan number, C_n = binom(2n, n) / (n + 1) = # of Dyck words of length 2n
    // C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ...
    // https://oeis.org/A000108
    // Complexity: O(n)
    static ModInt catalan(int n) {
        if (n < 0) return ModInt(0);
        return ModInt::fac(n * 2) * ModInt::facinv(n + 1) * ModInt::facinv(n);
    }

    // Square root mod md (Tonelli-Shanks); returns the smaller of the two roots, or 0 when
    // this is 0 or a quadratic non-residue.
    ModInt sqrt() const {
        if (val_ == 0) return 0;
        if (md == 2) return val_;
        if (pow((md - 1) / 2) != 1) return 0;
        ModInt b = 1;
        while (b.pow((md - 1) / 2) == 1) b += 1;
        int e = 0, m = md - 1;
        while (m % 2 == 0) m >>= 1, e++;
        ModInt x = pow((m - 1) / 2), y = (*this) * x * x;
        x *= (*this);
        ModInt z = b.pow(m);
        while (y != 1) {
            int j = 0;
            ModInt t = y;
            while (t != 1) j++, t *= t;
            z = z.pow(1LL << (e - j - 1));
            x *= z, z *= z, y *= z;
            e = j;
        }
        return ModInt(std::min(x.val_, md - x.val_));
    }
};
template <int md> std::vector<ModInt<md>> ModInt<md>::facs = {1};
template <int md> std::vector<ModInt<md>> ModInt<md>::facinvs = {1};
template <int md> std::vector<ModInt<md>> ModInt<md>::invs = {0};

using ModInt998244353 = ModInt<998244353>;
// using mint = ModInt<998244353>;
// using mint = ModInt<1000000007>;
#line 3 "convolution/ntt.hpp"

#line 5 "convolution/ntt.hpp"
#include <array>
#line 7 "convolution/ntt.hpp"
#include <tuple>
#line 9 "convolution/ntt.hpp"

// CUT begin
// Integer convolution for arbitrary mod
// with NTT (and Garner's algorithm) for ModInt / ModIntRuntime class.
// We skip Garner's algorithm if `skip_garner` is true or mod is in `nttprimes`.
// input: a (size: n), b (size: m)
// return: vector (size: n + m - 1); empty if either input is empty
template <typename MODINT>
std::vector<MODINT> nttconv(std::vector<MODINT> a, std::vector<MODINT> b, bool skip_garner);

// NTT-friendly primes: 998244353 = 119 * 2^23 + 1, 167772161 = 5 * 2^25 + 1,
// 469762049 = 7 * 2^26 + 1. They are the moduli of the three-prime (Garner) path, so that
// path supports transform lengths up to 2^23 (ntt() asserts (mod - 1) % n == 0).
constexpr int nttprimes[3] = {998244353, 167772161, 469762049};

// Integer FFT (Fast Fourier Transform) for ModInt class
// (Also known as Number Theoretic Transform, NTT)
// Transforms `a` in place. is_inverse: inverse transform (also multiplies by 1/n).
// ** Input size must be 2^n ** and must divide mod - 1 (both asserted; an empty vector
// fails the assertion, a size-1 vector is returned unchanged).
// NOTE(review): twiddles are indexed by block number k, which suggests the forward output
// is in bit-reversed order and the inverse expects that order. This is fine for pointwise
// products, but confirm before using the spectrum directly.
template <typename MODINT> void ntt(std::vector<MODINT> &a, bool is_inverse = false) {
    int n = a.size();
    if (n == 1) return;
    // Computed on the first call for this MODINT and reused afterwards.
    static const int mod = MODINT::mod();
    static const MODINT root = MODINT::get_primitive_root();
    assert(__builtin_popcount(n) == 1 and (mod - 1) % n == 0);

    // Twiddle tables shared by all calls with this MODINT, grown on demand (unsynchronized).
    // Each doubling step sets w[m + i] = w[i] * root^((mod - 1) / (4m)); iw holds inverses.
    static std::vector<MODINT> w{1}, iw{1};
    for (int m = w.size(); m < n / 2; m *= 2) {
        MODINT dw = root.pow((mod - 1) / (4 * m)), dwinv = 1 / dw;
        w.resize(m * 2), iw.resize(m * 2);
        for (int i = 0; i < m; i++) w[m + i] = w[i] * dw, iw[m + i] = iw[i] * dwinv;
    }

    if (!is_inverse) {
        // Forward: half-length m goes n/2, n/4, ..., 1; block k uses twiddle w[k].
        for (int m = n; m >>= 1;) {
            for (int s = 0, k = 0; s < n; s += 2 * m, k++) {
                for (int i = s; i < s + m; i++) {
                    MODINT x = a[i], y = a[i + m] * w[k];
                    a[i] = x + y, a[i + m] = x - y;
                }
            }
        }
    } else {
        // Inverse: butterflies in the opposite order (m = 1, 2, ..., n/2) with iw.
        for (int m = 1; m < n; m *= 2) {
            for (int s = 0, k = 0; s < n; s += 2 * m, k++) {
                for (int i = s; i < s + m; i++) {
                    MODINT x = a[i], y = a[i + m];
                    a[i] = x + y, a[i + m] = (x - y) * iw[k];
                }
            }
        }
        // Final 1/n scaling.
        int n_inv = MODINT(n).inv().val();
        for (auto &v : a) v *= n_inv;
    }
}
// Cyclic convolution modulo MOD of two equal-length int sequences whose length is a power of
// two (asserted). Entries are converted through ModInt<MOD>'s reducing constructor; the
// result has the same length as the inputs.
template <int MOD>
std::vector<ModInt<MOD>> nttconv_(const std::vector<int> &a, const std::vector<int> &b) {
    using mint_t = ModInt<MOD>;
    const int len = a.size();
    assert(a.size() == b.size() and __builtin_popcount(len) == 1);
    std::vector<mint_t> fa(len), fb(len);
    for (int i = 0; i < len; ++i) {
        fa[i] = a[i];
        fb[i] = b[i];
    }
    ntt(fa, false);
    // Identical operands (squaring): the transform of b is the transform of a.
    if (a == b) {
        fb = fa;
    } else {
        ntt(fb, false);
    }
    for (int i = 0; i < len; ++i) fa[i] *= fb[i];
    ntt(fa, true);
    return fa;
}
long long garner_ntt_(int r0, int r1, int r2, int mod) {
    using mint2 = ModInt<nttprimes[2]>;
    static const long long m01 = 1LL * nttprimes[0] * nttprimes[1];
    static const long long m0_inv_m1 = ModInt<nttprimes[1]>(nttprimes[0]).inv().val();
    static const long long m01_inv_m2 = mint2(m01).inv().val();

    int v1 = (m0_inv_m1 * (r1 + nttprimes[1] - r0)) % nttprimes[1];
    auto v2 = (mint2(r2) - r0 - mint2(nttprimes[0]) * v1) * m01_inv_m2;
    return (r0 + 1LL * nttprimes[0] * v1 + m01 % mod * v2.val()) % mod;
}
// Linear convolution of `a` and `b` modulo MODINT::mod().
// Returns c of size n + m - 1 with c[k] = sum_{i + j = k} a[i] * b[j], or an empty vector if
// either input is empty.
// - Small outputs (padded length <= 16): schoolbook O(nm) product.
// - skip_garner, or a modulus listed in `nttprimes`: a single NTT over MODINT (the modulus
//   must support the padded transform length; ntt() asserts this).
// - Otherwise: the exact integer convolution of the residues is computed modulo the three
//   `nttprimes` and recombined with Garner's algorithm.
template <typename MODINT>
std::vector<MODINT> nttconv(std::vector<MODINT> a, std::vector<MODINT> b, bool skip_garner) {
    if (a.empty() or b.empty()) return {};
    const int n = a.size(), m = b.size();
    // A cyclic convolution of length sz is exact iff sz >= n + m - 1 (the output length);
    // padding further only doubles the transform size.
    int sz = 1;
    while (sz < n + m - 1) sz <<= 1;
    if (sz <= 16) {
        std::vector<MODINT> ret(n + m - 1);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) ret[i + j] += a[i] * b[j];
        }
        return ret;
    }
    const int mod = MODINT::mod();
    if (skip_garner or
        std::find(std::begin(nttprimes), std::end(nttprimes), mod) != std::end(nttprimes)) {
        a.resize(sz), b.resize(sz);
        if (a == b) {
            ntt(a, false);
            b = a; // squaring: one forward transform suffices
        } else {
            ntt(a, false), ntt(b, false);
        }
        for (int i = 0; i < sz; i++) a[i] *= b[i];
        ntt(a, true);
        a.resize(n + m - 1);
    } else {
        std::vector<int> ai(sz), bi(sz);
        for (int i = 0; i < n; i++) ai[i] = a[i].val();
        for (int i = 0; i < m; i++) bi[i] = b[i].val();
        const auto ntt0 = nttconv_<nttprimes[0]>(ai, bi);
        const auto ntt1 = nttconv_<nttprimes[1]>(ai, bi);
        const auto ntt2 = nttconv_<nttprimes[2]>(ai, bi);
        a.resize(n + m - 1);
        for (int i = 0; i < n + m - 1; i++)
            a[i] = garner_ntt_(ntt0[i].val(), ntt1[i].val(), ntt2[i].val(), mod);
    }
    return a;
}

// Convenience overload: convolution modulo MODINT::mod(). The three-argument version
// decides between a direct NTT and the three-prime Garner path.
template <typename MODINT>
std::vector<MODINT> nttconv(const std::vector<MODINT> &a, const std::vector<MODINT> &b) {
    constexpr bool skip_garner = false;
    return nttconv<MODINT>(a, b, skip_garner);
}
#line 6 "other_algorithms/test/north_east_lattice_paths.bruteforce.test.cpp"

using mint = ModInt998244353;

#line 10 "other_algorithms/test/north_east_lattice_paths.bruteforce.test.cpp"
#include <map>
#include <utility>
using namespace std;

void Check() {
    const int n = 30, d = 10;
    vector<mint> init(d);
    for (int i = 0; i < d; ++i) init[i] = mint(2).pow(i);

    map<pair<int, int>, mint> dp;
    for (int x = 0; x < d; ++x) dp[{x, 0}] = init[x];

    for (int x = -n; x <= n; ++x) {
        for (int y = -n; y <= n; ++y) {
            if (x + 1 <= n) dp[{x + 1, y}] += dp[{x, y}];
            if (y + 1 <= n) dp[{x, y + 1}] += dp[{x, y}];
        }
    }

    for (int dx_init = -n; dx_init <= n; ++dx_init) {
        for (int y = -n; y <= n; ++y) {
            for (int len = 1; dx_init + len - 1 <= n; ++len) {
                vector<mint> expected(len);
                for (int i = 0; i < len; ++i) expected[i] = dp[{dx_init + i, y}];

                auto res = NorthEastLatticePathsParallel<mint>(
                    init, dx_init, y, len, [&](auto &&a, auto &&b) { return nttconv(a, b); });
                if (res != expected) {
                    cerr << "Failed Parallel: " << dx_init << ' ' << y << ' ' << len
                         << "\nExpected: ";
                    for (auto e : expected) cerr << " " << e;
                    cerr << "\nResult: ";
                    for (auto e : res) cerr << " " << e;
                    cerr << '\n';
                    exit(1);
                }
            }
        }
    }

    for (int x = -n; x <= n; ++x) {
        for (int dy_init = -n; dy_init <= n; ++dy_init) {
            for (int len = 1; dy_init + len - 1 <= n; ++len) {
                vector<mint> expected(len);
                for (int i = 0; i < len; ++i) expected[i] = dp[{x, dy_init + i}];

                auto res = NorthEastLatticePathsVertical<mint>(
                    init, x, dy_init, len, [&](auto &&a, auto &&b) { return nttconv(a, b); });
                if (res != expected) {
                    cerr << "Failed Vertical: " << x << ' ' << dy_init << ' ' << len
                         << "\nExpected: ";
                    for (auto e : expected) cerr << " " << e;
                    cerr << "\nResult: ";
                    for (auto e : res) cerr << " " << e;
                    cerr << '\n';
                    exit(1);
                }
            }
        }
    }
}

int main() {
    // Check() exits with status 1 on the first mismatch; reaching the line below means
    // every comparison passed.
    Check();
    puts("Hello World");
    return 0;
}
Back to top page