This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub hitonanode/cplib-cpp
#include "linear_algebra_matrix/hessenberg_system.hpp"
体上の $n \times n$ 正則 Hessenberg 行列 $\mathbf{A}$ (下 Hessenberg・上 Hessenberg のいずれにも対応) と $n$ 次元ベクトル $\mathbf{b}$ に対して,線形方程式系 $\mathbf{A} \mathbf{x} = \mathbf{b}$ を解く.計算量は $O(n^2)$.
vector<vector<mint>> A(N, vector<mint>(N));
vector<mint> b(N);

// A: lower Hessenberg, non-singular (invertible)
vector<mint> x = solve_lower_hessenberg(A, b);

// A: upper Hessenberg, non-singular (invertible)
vector<mint> x = solve_upper_hessenberg(A, b);
#pragma once #include "../number/dual_number.hpp" #include <algorithm> #include <cassert> #include <vector> // Solve Ax = b, where A is n x n (square), lower Hessenberg, and non-singular. // Complexity: O(n^2) // Verified: https://atcoder.jp/contests/abc249/tasks/abc249_h template <class T> std::vector<T> solve_lower_hessenberg(const std::vector<std::vector<T>> &A, const std::vector<T> &b) { const int N = A.size(); if (!N) return {}; assert(int(A[0].size()) == N and int(b.size()) == N); using dual = DualNumber<T>; std::vector<dual> sol(N); for (int h = 0; h < N;) { sol[h] = dual(0, 1); for (int i = h;; ++i) { dual y = b[i]; for (int j = 0; j <= i; ++j) y -= sol[j] * A[i][j]; T a = i + 1 < N ? A[i][i + 1] : T(); if (a == T()) { T x0 = y.root(); while (h <= i) sol[h] = sol[h].eval(x0), ++h; break; } else { sol[i + 1] = y / a; } } } std::vector<T> ret(N); for (int i = 0; i < N; ++i) ret[i] = sol[i].a; return ret; } // Solve Ax = b, where A is n x n (square), upper Hessenberg, and non-singular // Complexity: O(n^2) template <class T> std::vector<T> solve_upper_hessenberg(std::vector<std::vector<T>> A, std::vector<T> b) { std::reverse(A.begin(), A.end()); for (auto &v : A) std::reverse(v.begin(), v.end()); std::reverse(b.begin(), b.end()); auto ret = solve_lower_hessenberg(A, b); std::reverse(ret.begin(), ret.end()); return ret; }
#line 1 "number/dual_number.hpp"
#include <type_traits>

namespace dual_number_ {
// Detects whether T_ provides a static member function id() (used as the multiplicative
// identity; types without it fall back to T_(1) in DualNumber::_T_id).
struct has_id_method_impl {
    template <class T_> static auto check(T_ *) -> decltype(T_::id(), std::true_type());
    template <class T_> static auto check(...) -> std::false_type;
};
template <class T_> struct has_id : decltype(has_id_method_impl::check<T_>(nullptr)) {};
} // namespace dual_number_

// Dual number a + b x with x^2 = 0: the product below keeps only the constant and linear terms.
// Also usable as an affine function a + b x (see eval / root).
// Verified: https://atcoder.jp/contests/abc235/tasks/abc235_f
template <class T> struct DualNumber {
    T a, b; // a + bx

    // Multiplicative identity of T: T::id() when available, otherwise T(1).
    template <typename T2, typename std::enable_if<dual_number_::has_id<T2>::value>::type * = nullptr>
    static T2 _T_id() {
        return T2::id();
    }
    template <typename T2, typename std::enable_if<!dual_number_::has_id<T2>::value>::type * = nullptr>
    static T2 _T_id() {
        return T2(1);
    }

    // Implicit on purpose: lets plain T values be mixed into DualNumber arithmetic.
    DualNumber(T x = T(), T y = T()) : a(x), b(y) {}
    static DualNumber id() { return DualNumber(_T_id<T>(), T()); }
    explicit operator bool() const { return a != T() or b != T(); }

    DualNumber operator+(const DualNumber &x) const { return DualNumber(a + x.a, b + x.b); }
    DualNumber operator-(const DualNumber &x) const { return DualNumber(a - x.a, b - x.b); }
    // (a + b e)(c + d e) = ac + (bc + ad) e
    DualNumber operator*(const DualNumber &x) const { return DualNumber(a * x.a, b * x.a + a * x.b); }
    // (a + b e) / (c + d e) = a / c + (bc - ad) / c^2 e; requires c to be invertible.
    DualNumber operator/(const DualNumber &x) const {
        T cinv = _T_id<T>() / x.a;
        return DualNumber(a * cinv, (b * x.a - a * x.b) * cinv * cinv);
    }
    DualNumber operator-() const { return DualNumber(-a, -b); }
    DualNumber &operator+=(const DualNumber &x) { return *this = *this + x; }
    DualNumber &operator-=(const DualNumber &x) { return *this = *this - x; }
    DualNumber &operator*=(const DualNumber &x) { return *this = *this * x; }
    DualNumber &operator/=(const DualNumber &x) { return *this = *this / x; }

    bool operator==(const DualNumber &x) const { return a == x.a and b == x.b; }
    bool operator!=(const DualNumber &x) const { return !(*this == x); }
    // Lexicographic order on (a, b).
    bool operator<(const DualNumber &x) const { return (a != x.a ? a < x.a : b < x.b); }

    template <class OStream> friend OStream &operator<<(OStream &os, const DualNumber &x) {
        return os << '{' << x.a << ',' << x.b << '}';
    }

    // Value of a + b x at the given x.
    T eval(const T &x) const { return a + b * x; }
    T root() const { return (-a) / b; } // Solve a + bx = 0 (b \neq 0 is assumed)
};
#line 3 "linear_algebra_matrix/hessenberg_system.hpp"
#include <algorithm>
#include <cassert>
#include <vector>

// Solve Ax = b, where A is n x n (square), lower Hessenberg (A[i][j] == 0 whenever j > i + 1),
// and non-singular.
//
// Method: rows are swept top to bottom in segments [h, i]. At the start of a segment x_h is
// replaced by an unknown t, and each row k with A[k][k+1] != 0 yields x_{k+1} as an affine
// function of t, stored as DualNumber (a + b t). The first row i whose x_{i+1} coefficient is
// zero (or the last row) gives a linear equation in t alone; solving it fixes every x in the
// segment, and the next segment starts at i + 1.
//
// Preconditions (checked with assert): every row of A and b has length n; A is non-singular.
// Complexity: O(n^2)
// Verified: https://atcoder.jp/contests/abc249/tasks/abc249_h
template <class T>
std::vector<T> solve_lower_hessenberg(const std::vector<std::vector<T>> &A, const std::vector<T> &b) {
    const int N = A.size();
    if (!N) return {};
    // Check every row, not only A[0]: a short row would be read out of bounds below.
    assert(int(b.size()) == N);
    assert(std::all_of(A.begin(), A.end(),
                       [N](const std::vector<T> &row) { return int(row.size()) == N; }));

    using dual = DualNumber<T>;
    std::vector<dual> sol(N); // sol[j] represents x_j = sol[j].a + sol[j].b * t
    for (int h = 0; h < N;) {
        sol[h] = dual(0, 1); // x_h = t, the unknown of this segment
        for (int i = h;; ++i) {
            // y = b_i - sum_{j <= i} A[i][j] x_j, i.e. row i with the x_{i+1} term left out.
            dual y = b[i];
            for (int j = 0; j <= i; ++j) y -= sol[j] * A[i][j];
            T a = i + 1 < N ? A[i][i + 1] : T();
            if (a == T()) {
                // Row i now reads y == 0. If its slope y.b vanished, the slopes of
                // x_h..x_i would form a non-zero null vector of the diagonal block
                // A[h..i][h..i], so A would be singular and root() would divide by zero.
                assert(!(y.b == T()) && "solve_lower_hessenberg: matrix is singular");
                T x0 = y.root();
                while (h <= i) sol[h] = sol[h].eval(x0), ++h;
                break;
            } else {
                sol[i + 1] = y / a; // from A[i][i+1] * x_{i+1} = y
            }
        }
    }
    std::vector<T> ret(N);
    for (int i = 0; i < N; ++i) ret[i] = sol[i].a;
    return ret;
}

// Solve Ax = b, where A is n x n (square), upper Hessenberg (A[i][j] == 0 whenever j < i - 1),
// and non-singular.
// Reversing the order of both the rows and the columns turns an upper Hessenberg matrix into a
// lower Hessenberg one, so the system is solved by solve_lower_hessenberg on the reversed data.
// Complexity: O(n^2)
template <class T>
std::vector<T> solve_upper_hessenberg(std::vector<std::vector<T>> A, std::vector<T> b) {
    std::reverse(A.begin(), A.end());
    for (auto &v : A) std::reverse(v.begin(), v.end());
    std::reverse(b.begin(), b.end());
    auto ret = solve_lower_hessenberg(A, b);
    std::reverse(ret.begin(), ret.end());
    return ret;
}