This documentation is automatically generated by online-judge-tools/verification-helper
#include "linear_algebra_matrix/hessenberg_system.hpp"
体上の $n \times n$ 正則 Hessenberg 行列 $\mathbf{A}$ と $n$ 次元ベクトル $\mathbf{b}$ に対して,線形方程式系 $\mathbf{A} \mathbf{x} = \mathbf{b}$ を解く.計算量は $O(n^2)$.
// Usage sketch: A is the N x N coefficient matrix and b the right-hand side.
// mint is the caller's scalar type (presumably a modular-integer field type; it is
// not defined in this snippet).
vector<vector<mint>> A(N, vector<mint>(N));
vector<mint> b(N);
// The two calls below are alternatives; use the one that matches the shape of A.
// A: lower Hessenberg, regular (non-singular)
vector<mint> x = solve_lower_hessenberg(A, b);
// A: upper Hessenberg, regular (non-singular)
vector<mint> x = solve_upper_hessenberg(A, b);
#pragma once
#include "../number/dual_number.hpp"
#include <algorithm>
#include <cassert>
#include <vector>
// Solve Ax = b, where A is n x n (square), lower Hessenberg, and non-singular.
//
// Method: rows are processed in chains. At the start of a chain, x[h] is treated
// as an unknown t and stored as the linear form 0 + 1 * t. While the
// superdiagonal entry A[i][i + 1] is nonzero, row i expresses x[i + 1] as a
// linear form in t. The first row whose superdiagonal entry is zero (or the last
// row) only involves unknowns that are already expressed, so it determines t;
// the chain is then evaluated and the next chain starts right after it.
//
// Parameters:
//   A  coefficient matrix; every row must have exactly n entries.
//   b  right-hand side of length n.
// Returns: the solution vector x (empty when n == 0).
// Shape preconditions are checked with assert. Non-singularity is not checked:
// root() divides by the coefficient of t and assumes it is nonzero.
// Complexity: O(n^2)
// Verified: https://atcoder.jp/contests/abc249/tasks/abc249_h
template <class T>
std::vector<T>
solve_lower_hessenberg(const std::vector<std::vector<T>> &A, const std::vector<T> &b) {
    const int N = A.size();
    if (!N) return {};
    assert(int(b.size()) == N);
    // Every row is indexed up to column i + 1 below, so the whole matrix (not
    // only its first row) has to be N wide.
    assert(std::all_of(A.begin(), A.end(),
                       [N](const std::vector<T> &row) { return int(row.size()) == N; }));

    using dual = DualNumber<T>;
    std::vector<dual> sol(N); // sol[j] = x[j] as a linear form in the current chain's t

    for (int h = 0; h < N;) {
        sol[h] = dual(0, 1); // x[h] = t
        for (int i = h;; ++i) {
            // y = b[i] - sum_{j <= i} A[i][j] * x[j], as a linear form in t.
            dual y = b[i];
            for (int j = 0; j <= i; ++j) y -= sol[j] * A[i][j];

            // Superdiagonal entry; treated as zero on the last row.
            T a = i + 1 < N ? A[i][i + 1] : T();
            if (!(a == T())) {
                // Row i reads y = a * x[i + 1], which defines the next unknown.
                sol[i + 1] = y / a;
                continue;
            }

            // Row i reads y = 0, which pins down t; resolve x[h..i] and advance h
            // to the start of the next chain.
            T x0 = y.root();
            for (; h <= i; ++h) sol[h] = sol[h].eval(x0);
            break;
        }
    }

    // All forms have been evaluated, so only the constant part carries a value.
    std::vector<T> ret(N);
    for (int i = 0; i < N; ++i) ret[i] = sol[i].a;
    return ret;
}
// Solve Ax = b, where A is n x n (square), upper Hessenberg, and non-singular.
// Reversing both the row order and the column order of A turns it into a lower
// Hessenberg matrix; b is reversed to match, the lower solver is applied, and
// the resulting vector is reversed back.
// A and b are taken by value, so the caller's data is left untouched.
// Complexity: O(n^2)
template <class T>
std::vector<T> solve_upper_hessenberg(std::vector<std::vector<T>> A, std::vector<T> b) {
    // Mirror the columns, then the rows, then the right-hand side.
    for (std::vector<T> &row : A) std::reverse(row.begin(), row.end());
    std::reverse(A.begin(), A.end());
    std::reverse(b.begin(), b.end());

    const std::vector<T> mirrored = solve_lower_hessenberg(A, b);

    // Undo the mirroring on the way out.
    return std::vector<T>(mirrored.rbegin(), mirrored.rend());
}
#line 2 "number/dual_number.hpp"
#include <type_traits>
namespace dual_number_ {
// Detection helper: the first check() overload is viable only when the
// expression T_::id() is well-formed; otherwise the variadic fallback is chosen.
struct has_id_method_impl {
    template <class T_> static auto check(T_ *) -> decltype(T_::id(), std::true_type());
    template <class T_> static auto check(...) -> std::false_type;
};
// has_id<T_>::value is true iff the expression T_::id() is well-formed.
template <class T_> struct has_id : decltype(has_id_method_impl::check<T_>(nullptr)) {};
} // namespace dual_number_

// Dual number a + b * x over T.
// Products and quotients drop the x^2 term (x^2 = 0), so a value can also be
// read as a first-order (linear) form in x; see eval() and root().
// Verified: https://atcoder.jp/contests/abc235/tasks/abc235_f
template <class T> struct DualNumber {
    T a, b; // a + bx

    // Multiplicative identity of T: T2::id() when that expression exists,
    // otherwise T2(1).
    template <typename T2, typename std::enable_if<dual_number_::has_id<T2>::value>::type * = nullptr>
    static T2 _T_id() {
        return T2::id();
    }
    template <typename T2, typename std::enable_if<!dual_number_::has_id<T2>::value>::type * = nullptr>
    static T2 _T_id() {
        return T2(1);
    }

    // Intentionally implicit, so a plain T converts to the constant x + 0 * (unknown).
    DualNumber(T x = T(), T y = T()) : a(x), b(y) {}

    // Multiplicative identity 1 + 0x.
    static DualNumber id() { return DualNumber(_T_id<T>(), T()); }

    // True unless both coefficients equal T().
    explicit operator bool() const { return a != T() or b != T(); }

    DualNumber operator+(const DualNumber &x) const { return DualNumber(a + x.a, b + x.b); }
    DualNumber operator-(const DualNumber &x) const { return DualNumber(a - x.a, b - x.b); }
    // (a + bx)(c + dx) = ac + (bc + ad)x, with the x^2 term dropped.
    DualNumber operator*(const DualNumber &x) const {
        return DualNumber(a * x.a, b * x.a + a * x.b);
    }
    // (a + bx) / (c + dx) = a/c + ((bc - ad) / c^2) x; divides by x.a.
    DualNumber operator/(const DualNumber &x) const {
        T cinv = _T_id<T>() / x.a;
        return DualNumber(a * cinv, (b * x.a - a * x.b) * cinv * cinv);
    }
    DualNumber operator-() const { return DualNumber(-a, -b); }

    DualNumber &operator+=(const DualNumber &x) { return *this = *this + x; }
    DualNumber &operator-=(const DualNumber &x) { return *this = *this - x; }
    DualNumber &operator*=(const DualNumber &x) { return *this = *this * x; }
    DualNumber &operator/=(const DualNumber &x) { return *this = *this / x; }

    bool operator==(const DualNumber &x) const { return a == x.a and b == x.b; }
    bool operator!=(const DualNumber &x) const { return !(*this == x); }
    // Lexicographic order on (a, b).
    bool operator<(const DualNumber &x) const { return (a != x.a ? a < x.a : b < x.b); }

    // Writes "{a,b}" and returns the stream that was passed in.
    // The chained << expression yields a reference to the stream's base class,
    // which cannot be returned as OStream& when OStream is a derived stream type
    // (e.g. std::ostringstream), so os itself is returned instead.
    template <class OStream> friend OStream &operator<<(OStream &os, const DualNumber &x) {
        os << '{' << x.a << ',' << x.b << '}';
        return os;
    }

    // Value of the linear form a + b * x at the given x.
    T eval(const T &x) const { return a + b * x; }
    T root() const { return (-a) / b; } // Solve a + bx = 0 (b \neq 0 is assumed)
};
#line 3 "linear_algebra_matrix/hessenberg_system.hpp"
#include <algorithm>
#include <cassert>
#include <vector>
// Solve Ax = b, where A is n x n (square), lower Hessenberg, and non-singular.
//
// Method: rows are processed in chains. At the start of a chain, x[h] is treated
// as an unknown t and stored as the linear form 0 + 1 * t. While the
// superdiagonal entry A[i][i + 1] is nonzero, row i expresses x[i + 1] as a
// linear form in t. The first row whose superdiagonal entry is zero (or the last
// row) only involves unknowns that are already expressed, so it determines t;
// the chain is then evaluated and the next chain starts right after it.
//
// Parameters:
//   A  coefficient matrix; every row must have exactly n entries.
//   b  right-hand side of length n.
// Returns: the solution vector x (empty when n == 0).
// Shape preconditions are checked with assert. Non-singularity is not checked:
// root() divides by the coefficient of t and assumes it is nonzero.
// Complexity: O(n^2)
// Verified: https://atcoder.jp/contests/abc249/tasks/abc249_h
template <class T>
std::vector<T>
solve_lower_hessenberg(const std::vector<std::vector<T>> &A, const std::vector<T> &b) {
    const int N = A.size();
    if (!N) return {};
    assert(int(b.size()) == N);
    // Every row is indexed up to column i + 1 below, so the whole matrix (not
    // only its first row) has to be N wide.
    assert(std::all_of(A.begin(), A.end(),
                       [N](const std::vector<T> &row) { return int(row.size()) == N; }));

    using dual = DualNumber<T>;
    std::vector<dual> sol(N); // sol[j] = x[j] as a linear form in the current chain's t

    for (int h = 0; h < N;) {
        sol[h] = dual(0, 1); // x[h] = t
        for (int i = h;; ++i) {
            // y = b[i] - sum_{j <= i} A[i][j] * x[j], as a linear form in t.
            dual y = b[i];
            for (int j = 0; j <= i; ++j) y -= sol[j] * A[i][j];

            // Superdiagonal entry; treated as zero on the last row.
            T a = i + 1 < N ? A[i][i + 1] : T();
            if (!(a == T())) {
                // Row i reads y = a * x[i + 1], which defines the next unknown.
                sol[i + 1] = y / a;
                continue;
            }

            // Row i reads y = 0, which pins down t; resolve x[h..i] and advance h
            // to the start of the next chain.
            T x0 = y.root();
            for (; h <= i; ++h) sol[h] = sol[h].eval(x0);
            break;
        }
    }

    // All forms have been evaluated, so only the constant part carries a value.
    std::vector<T> ret(N);
    for (int i = 0; i < N; ++i) ret[i] = sol[i].a;
    return ret;
}
// Solve Ax = b, where A is n x n (square), upper Hessenberg, and non-singular.
// Reversing both the row order and the column order of A turns it into a lower
// Hessenberg matrix; b is reversed to match, the lower solver is applied, and
// the resulting vector is reversed back.
// A and b are taken by value, so the caller's data is left untouched.
// Complexity: O(n^2)
template <class T>
std::vector<T> solve_upper_hessenberg(std::vector<std::vector<T>> A, std::vector<T> b) {
    // Mirror the columns, then the rows, then the right-hand side.
    for (std::vector<T> &row : A) std::reverse(row.begin(), row.end());
    std::reverse(A.begin(), A.end());
    std::reverse(b.begin(), b.end());

    const std::vector<T> mirrored = solve_lower_hessenberg(A, b);

    // Undo the mirroring on the way out.
    return std::vector<T>(mirrored.rbegin(), mirrored.rend());
}