cplib-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub hitonanode/cplib-cpp

:heavy_check_mark: Hessenberg linear system
(linear_algebra_matrix/hessenberg_system.hpp)

体上の $n \times n$ 正則 Hessenberg 行列 $\mathbf{A}$ と $n$ 次元ベクトル $\mathbf{b}$ に対して,線形方程式系 $\mathbf{A} \mathbf{x} = \mathbf{b}$ を解く.計算量は $O(n^2)$.

使用方法

vector<vector<mint>> A(N, vector<mint>(N));
vector<mint> b(N);

// A: lower Hessenberg, non-singular
vector<mint> x = solve_lower_hessenberg(A, b);

// A: upper Hessenberg, non-singular
vector<mint> x = solve_upper_hessenberg(A, b);

問題例

Depends on

Verified with

Code

#pragma once
#include "../number/dual_number.hpp"
#include <algorithm>
#include <cassert>
#include <vector>

// Solve Ax = b, where A is n x n (square), lower Hessenberg, and non-singular.
// Complexity: O(n^2)
// Verified: https://atcoder.jp/contests/abc249/tasks/abc249_h
template <class T>
std::vector<T>
solve_lower_hessenberg(const std::vector<std::vector<T>> &A, const std::vector<T> &b) {
    const int N = A.size();
    if (!N) return {};
    assert(int(A[0].size()) == N and int(b.size()) == N);

    using dual = DualNumber<T>;
    std::vector<dual> sol(N);
    for (int h = 0; h < N;) {
        sol[h] = dual(0, 1);
        for (int i = h;; ++i) {
            dual y = b[i];
            for (int j = 0; j <= i; ++j) y -= sol[j] * A[i][j];
            T a = i + 1 < N ? A[i][i + 1] : T();
            if (a == T()) {
                T x0 = y.root();
                while (h <= i) sol[h] = sol[h].eval(x0), ++h;
                break;
            } else {
                sol[i + 1] = y / a;
            }
        }
    }
    std::vector<T> ret(N);
    for (int i = 0; i < N; ++i) ret[i] = sol[i].a;
    return ret;
}

// Solve A x = b for a square, non-singular, upper Hessenberg matrix A.
//
// Reversing both the row order and the column order of A turns an upper Hessenberg
// matrix into a lower Hessenberg one; with b reversed accordingly, the solution of the
// transformed system is the reversed solution of the original system.
// Complexity: O(n^2)
template <class T>
std::vector<T> solve_upper_hessenberg(std::vector<std::vector<T>> A, std::vector<T> b) {
    for (auto &row : A) std::reverse(row.begin(), row.end());
    std::reverse(A.begin(), A.end());
    std::reverse(b.begin(), b.end());

    const std::vector<T> flipped = solve_lower_hessenberg(A, b);
    return std::vector<T>(flipped.rbegin(), flipped.rend());
}
#line 1 "number/dual_number.hpp"
#include <type_traits>

namespace dual_number_ {
// Detection helper: the first overload takes part in overload resolution only when the
// expression `T_::id()` is well-formed; otherwise the variadic overload is selected.
struct has_id_method_impl {
    template <class T_> static auto check(T_ *) -> decltype(T_::id(), std::true_type());
    template <class T_> static auto check(...) -> std::false_type;
};
// has_id<T_>::value is true iff the expression `T_::id()` is well-formed.
template <class T_> struct has_id : decltype(has_id_method_impl::check<T_>(nullptr)) {};
} // namespace dual_number_

// Dual number a + b*x, where products discard the x^2 term.
// Verified: https://atcoder.jp/contests/abc235/tasks/abc235_f
template <class T> struct DualNumber {
    T a, b; // represents a + b*x

    // Multiplicative identity of the coefficient type:
    // T2::id() when that expression is available, T2(1) otherwise.
    template <typename T2, typename std::enable_if<dual_number_::has_id<T2>::value>::type * = nullptr>
    static T2 _T_id() {
        return T2::id();
    }
    template <typename T2, typename std::enable_if<!dual_number_::has_id<T2>::value>::type * = nullptr>
    static T2 _T_id() {
        return T2(1);
    }

    // Intentionally non-explicit: a plain coefficient converts to the constant (value + 0*x).
    DualNumber(T constant = T(), T linear = T()) : a(constant), b(linear) {}

    // Multiplicative identity: 1 + 0*x.
    static DualNumber id() { return DualNumber(_T_id<T>(), T()); }

    // True unless both coefficients equal T().
    explicit operator bool() const { return a != T() || b != T(); }

    // ---- arithmetic ----
    DualNumber operator+(const DualNumber &rhs) const {
        return DualNumber(a + rhs.a, b + rhs.b);
    }
    DualNumber operator-(const DualNumber &rhs) const {
        return DualNumber(a - rhs.a, b - rhs.b);
    }
    // (a + b x)(c + d x) = ac + (bc + ad) x
    DualNumber operator*(const DualNumber &rhs) const {
        return DualNumber(a * rhs.a, b * rhs.a + a * rhs.b);
    }
    // (a + b x) / (c + d x) = a/c + (bc - ad)/c^2 x
    DualNumber operator/(const DualNumber &rhs) const {
        const T inv = _T_id<T>() / rhs.a;
        const T numer = b * rhs.a - a * rhs.b;
        return DualNumber(a * inv, numer * inv * inv);
    }
    DualNumber operator-() const { return DualNumber(-a, -b); }

    // ---- compound assignment (defined through the binary operators) ----
    DualNumber &operator+=(const DualNumber &rhs) {
        *this = *this + rhs;
        return *this;
    }
    DualNumber &operator-=(const DualNumber &rhs) {
        *this = *this - rhs;
        return *this;
    }
    DualNumber &operator*=(const DualNumber &rhs) {
        *this = *this * rhs;
        return *this;
    }
    DualNumber &operator/=(const DualNumber &rhs) {
        *this = *this / rhs;
        return *this;
    }

    // ---- comparison ----
    bool operator==(const DualNumber &rhs) const { return a == rhs.a && b == rhs.b; }
    bool operator!=(const DualNumber &rhs) const { return !(*this == rhs); }
    // Lexicographic order on (a, b).
    bool operator<(const DualNumber &rhs) const {
        if (a != rhs.a) return a < rhs.a;
        return b < rhs.b;
    }

    // Prints as "{a,b}".
    template <class OStream> friend OStream &operator<<(OStream &os, const DualNumber &x) {
        return os << '{' << x.a << ',' << x.b << '}';
    }

    // Value of a + b*x at the given point.
    T eval(const T &point) const { return a + b * point; }

    // Solution x of a + b*x = 0 (b != 0 is assumed).
    T root() const { return (-a) / b; }
};
#line 3 "linear_algebra_matrix/hessenberg_system.hpp"
#include <algorithm>
#include <cassert>
#include <vector>

// Solve Ax = b, where A is n x n (square), lower Hessenberg, and non-singular.
// Complexity: O(n^2)
// Verified: https://atcoder.jp/contests/abc249/tasks/abc249_h
template <class T>
std::vector<T>
solve_lower_hessenberg(const std::vector<std::vector<T>> &A, const std::vector<T> &b) {
    const int N = A.size();
    if (!N) return {};
    assert(int(A[0].size()) == N and int(b.size()) == N);

    using dual = DualNumber<T>;
    std::vector<dual> sol(N);
    for (int h = 0; h < N;) {
        sol[h] = dual(0, 1);
        for (int i = h;; ++i) {
            dual y = b[i];
            for (int j = 0; j <= i; ++j) y -= sol[j] * A[i][j];
            T a = i + 1 < N ? A[i][i + 1] : T();
            if (a == T()) {
                T x0 = y.root();
                while (h <= i) sol[h] = sol[h].eval(x0), ++h;
                break;
            } else {
                sol[i + 1] = y / a;
            }
        }
    }
    std::vector<T> ret(N);
    for (int i = 0; i < N; ++i) ret[i] = sol[i].a;
    return ret;
}

// Solve A x = b for a square, non-singular, upper Hessenberg matrix A.
//
// Reversing both the row order and the column order of A turns an upper Hessenberg
// matrix into a lower Hessenberg one; with b reversed accordingly, the solution of the
// transformed system is the reversed solution of the original system.
// Complexity: O(n^2)
template <class T>
std::vector<T> solve_upper_hessenberg(std::vector<std::vector<T>> A, std::vector<T> b) {
    for (auto &row : A) std::reverse(row.begin(), row.end());
    std::reverse(A.begin(), A.end());
    std::reverse(b.begin(), b.end());

    const std::vector<T> flipped = solve_lower_hessenberg(A, b);
    return std::vector<T>(flipped.rbegin(), flipped.rend());
}
Back to top page