This documentation is automatically generated by online-judge-tools/verification-helper
#include "heuristic/multivariate_gaussian.hpp"
多変量正規分布のパラメータを管理するクラス.線形変換・ノイズの加算・観測による事後確率の更新が行える.カルマンフィルタの実装に利用可能.
線形システムのカルマンフィルタの実装例を以下に示す.
#include "linear_algebra_matrix/matrix.hpp"
// 初期化
MultivariateGaussian<matrix<double>> kf;
vector<double> mu(dim);
matrix<double> Sigma(dim, dim);
kf.set(mu, Sigma); // N(mu, Sigma) で初期化
// 以下の「時間発展」「雑音の付与」「制御信号の注入」「推定」を任意の順序で任意の回数行ってよい。
// 時間発展
matrix<double> F(dim, dim); // 時間発展行列
kf.linear_transform(F);
// 雑音の付与
matrix<double> Q(dim, dim); // 正規雑音の分散・共分散行列
kf.add_noise(Q);
// 制御信号の注入
vector<double> u(dim); // 制御入力
kf.shift(u);
// 観測
matrix<double> H(o, dim); // 観測行列
matrix<double> R(o, o); // 観測に重畳される正規雑音の分散・共分散行列
vector<double> z(o); // 観測行列による観測結果
double regularize = 1e-9; // 逆行列数値計算の安定のためのパラメータ
kf.measure(H, R, z, regularize);
// 推定
vector<double> est = kf.x;
#ifndef MULTIVARIATE_GAUSSIAN_HPP
#define MULTIVARIATE_GAUSSIAN_HPP
#include <cassert>
#include <vector>
// #include "linear_algebra_matrix/matrix.hpp"
// Multivariate Gausssian distribution / Kalman filter
// 多変量正規分布の数値計算・カルマンフィルタ
template <class Matrix> struct MultivariateGaussian {
// 正規分布 N(x, P)
std::vector<double> x; // 期待値
Matrix P; // 分散共分散行列
void set(const std::vector<double> &x0, const Matrix &P0) {
const int dim = x0.size();
assert(P0.height() == dim and P0.width() == dim);
x = x0;
P = P0;
}
// 加算
// すなわち x <- x + dx
void shift(const std::vector<double> &dx) {
const int n = x.size();
assert(dx.size() == n);
for (int i = 0; i < n; ++i) x.at(i) += dx.at(i);
}
// F: 状態遷移行列 正方行列を想定
// すなわち x <- Fx
void linear_transform(const Matrix &F) {
x = F * x;
P = F * P * F.transpose();
}
// Q: ゼロ平均ガウシアンノイズの分散共分散行列
// すなわち x <- x + w, w ~ N(0, Q)
void add_noise(const Matrix &Q) { P = P + Q; }
// 現在の x の分布を P(x | *) として、条件付き確率 P(x | *, z) で更新する
// H: 観測行列, R: 観測ノイズの分散共分散行列, z: 観測値
// すなわち z = Hx + v, v ~ N(0, R)
void measure(const Matrix &H, const Matrix &R, const std::vector<double> &z,
double regularlize = 1e-9) {
const int nobs = z.size();
// 残差 e = z - Hx
const auto Hx = H * x;
std::vector<double> e(nobs);
for (int i = 0; i < nobs; ++i) e.at(i) = z.at(i) - Hx.at(i);
// 残差共分散 S = R + H P H^T
Matrix Sinv = R + H * P * H.transpose();
Sinv = Sinv + Matrix::Identity(nobs) * regularlize; // 不安定かも?
Sinv.inverse();
// カルマンゲイン K = P H^T S^-1
Matrix K = P * H.transpose() * Sinv;
// Update x
const auto dx = K * e;
for (int i = 0; i < (int)x.size(); ++i) x.at(i) += dx.at(i);
P = P - K * H * P;
}
};
#endif
#line 1 "heuristic/multivariate_gaussian.hpp"
#include <cassert>
#include <vector>
// #include "linear_algebra_matrix/matrix.hpp"
// Multivariate Gausssian distribution / Kalman filter
// 多変量正規分布の数値計算・カルマンフィルタ
template <class Matrix> struct MultivariateGaussian {
// 正規分布 N(x, P)
std::vector<double> x; // 期待値
Matrix P; // 分散共分散行列
void set(const std::vector<double> &x0, const Matrix &P0) {
const int dim = x0.size();
assert(P0.height() == dim and P0.width() == dim);
x = x0;
P = P0;
}
// 加算
// すなわち x <- x + dx
void shift(const std::vector<double> &dx) {
const int n = x.size();
assert(dx.size() == n);
for (int i = 0; i < n; ++i) x.at(i) += dx.at(i);
}
// F: 状態遷移行列 正方行列を想定
// すなわち x <- Fx
void linear_transform(const Matrix &F) {
x = F * x;
P = F * P * F.transpose();
}
// Q: ゼロ平均ガウシアンノイズの分散共分散行列
// すなわち x <- x + w, w ~ N(0, Q)
void add_noise(const Matrix &Q) { P = P + Q; }
// 現在の x の分布を P(x | *) として、条件付き確率 P(x | *, z) で更新する
// H: 観測行列, R: 観測ノイズの分散共分散行列, z: 観測値
// すなわち z = Hx + v, v ~ N(0, R)
void measure(const Matrix &H, const Matrix &R, const std::vector<double> &z,
double regularlize = 1e-9) {
const int nobs = z.size();
// 残差 e = z - Hx
const auto Hx = H * x;
std::vector<double> e(nobs);
for (int i = 0; i < nobs; ++i) e.at(i) = z.at(i) - Hx.at(i);
// 残差共分散 S = R + H P H^T
Matrix Sinv = R + H * P * H.transpose();
Sinv = Sinv + Matrix::Identity(nobs) * regularlize; // 不安定かも?
Sinv.inverse();
// カルマンゲイン K = P H^T S^-1
Matrix K = P * H.transpose() * Sinv;
// Update x
const auto dx = K * e;
for (int i = 0; i < (int)x.size(); ++i) x.at(i) += dx.at(i);
P = P - K * H * P;
}
};