This documentation is automatically generated by online-judge-tools/verification-helper
#include "segmenttree/persistent_segtree.hpp"完全永続版のセグメント木.各点更新のたびに新しい版の根を返し,過去の任意の版に対して 1 点更新や区間積クエリを行える.インターフェースは atcoder::segtree に近く,第一引数に更新のもととなる版の根を与える点が異なる.
struct S {
unsigned long long sum;
int len;
};
S op(S l, S r) { return {l.sum + r.sum, l.len + r.len}; }
S e() { return {0, 0}; }
vector<S> A(N, {0, 1});
persistent_segtree<S, op, e> seg(A);
auto root0 = seg.get_root(); // 初期版
auto root1 = seg.set(root0, idx, {x, 1}); // idx 番目を更新した新しい版
S x = seg.get(root0, idx); // root0 版の idx 番目の値
S y = seg.prod(root1, l, r); // root1 版の [l, r) の積
S z = seg.all_prod(root1); // root1 版の列全体の積
int i = seg.max_right(root1, l, [](S x) { return x.sum <= lim; });
int j = seg.min_left(root1, r, [](S x) { return x.sum <= lim; });
max_right, min_left の意味は atcoder::segtree と同じ.すなわち,f(e()) = true を満たす単調な述語 f に対して,
max_right(root, l, f) は prod(root, l, r) が f を満たすような最大の r を返す.min_left(root, r, f) は prod(root, l, r) が f を満たすような最小の l を返す.計算量は以下の通り.
set $O(\log N)$ 時間,更新 1 回あたり追加ノード数 $O(\log N)$get, prod, max_right, min_left $O(\log N)$all_prod $O(1)$#pragma once
#include <bit>
#include <cassert>
#include <functional>
#include <vector>
template <class S, auto op, auto e> struct persistent_segtree {
static_assert(std::is_convertible_v<decltype(op), std::function<S(S, S)>>,
"op must work as S(S, S)");
static_assert(std::is_convertible_v<decltype(e), std::function<S()>>, "e must work as S()");
struct Root {
int id;
};
explicit persistent_segtree(int n) : persistent_segtree(std::vector<S>(n, e())) {}
explicit persistent_segtree(const std::vector<S> &v) : _n(int(v.size())) {
size = std::bit_ceil((unsigned int)_n);
lg = std::countr_zero((unsigned int)size);
nodes.assign(2 * size, Node{e(), -1, -1});
for (int i = 0; i < _n; i++) nodes[size + i].val = v[i];
for (int i = size - 1; i >= 1; i--) {
nodes[i].left = 2 * i;
nodes[i].right = 2 * i + 1;
nodes[i].val = op(nodes[2 * i].val, nodes[2 * i + 1].val);
}
}
Root set(const Root &root, int p, const S &x) {
assert(0 <= p && p < _n);
std::vector<int> ids(lg + 1);
ids[lg] = root.id;
for (int i = lg - 1; i >= 0; --i) {
const Node &par = nodes[ids[i + 1]];
ids[i] = ((p >> i) & 1) ? par.right : par.left;
}
int copy_cur = new_node(x, -1, -1);
for (int i = 1; i <= lg; i++) {
const int par = ids[i], cur = ids[i - 1];
const Node &par_node = nodes[par];
const int left = par_node.left == cur ? copy_cur : par_node.left;
const int right = par_node.right == cur ? copy_cur : par_node.right;
copy_cur = new_node(op(nodes[left].val, nodes[right].val), left, right);
}
return Root{copy_cur};
}
S get(const Root &root, int p) const {
assert(0 <= p && p < _n);
int i = root.id;
for (int bit = lg - 1; bit >= 0; --bit) {
i = ((p >> bit) & 1) ? nodes[i].right : nodes[i].left;
}
return nodes[i].val;
}
S prod(const Root &root, int l, int r) const {
assert(0 <= l && l <= r && r <= _n);
auto rec = [&](auto &&self, int i, int lo, int hi) -> S {
if (r <= lo || hi <= l) return e();
if (l <= lo && hi <= r) return nodes[i].val;
const int mid = (lo + hi) >> 1;
return op(self(self, nodes[i].left, lo, mid), self(self, nodes[i].right, mid, hi));
};
return rec(rec, root.id, 0, size);
}
S all_prod(const Root &root) const { return nodes[root.id].val; }
template <bool (*f)(S)> int max_right(const Root &root, int l) const {
return max_right(root, l, [](S x) { return f(x); });
}
template <class F> int max_right(const Root &root, int l, F f) const {
assert(0 <= l && l <= _n);
assert(f(e()));
if (l == _n) return _n;
S sm = e();
auto rec = [&](auto &&self, int i, int lo, int hi) -> int {
if (hi <= l) return hi;
if (l <= lo) {
const S nxt = op(sm, nodes[i].val);
if (f(nxt)) {
sm = nxt;
return hi;
}
if (hi - lo == 1) return lo;
}
const int mid = (lo + hi) >> 1;
if (l < mid) {
const int left_res = self(self, nodes[i].left, lo, mid);
if (left_res < mid) return left_res;
}
return self(self, nodes[i].right, mid, hi);
};
return std::min(rec(rec, root.id, 0, size), _n);
}
template <bool (*f)(S)> int min_left(const Root &root, int r) const {
return min_left(root, r, [](S x) { return f(x); });
}
template <class F> int min_left(const Root &root, int r, F f) const {
assert(0 <= r && r <= _n);
assert(f(e()));
if (r == 0) return 0;
S sm = e();
auto rec = [&](auto &&self, int i, int lo, int hi) -> int {
if (r <= lo) return lo;
if (hi <= r) {
const S nxt = op(nodes[i].val, sm);
if (f(nxt)) {
sm = nxt;
return lo;
}
if (hi - lo == 1) return hi;
}
const int mid = (lo + hi) >> 1;
if (mid < r) {
const int right_res = self(self, nodes[i].right, mid, hi);
if (mid < right_res) return right_res;
}
return self(self, nodes[i].left, lo, mid);
};
return rec(rec, root.id, 0, size);
}
Root get_root() const { return {1}; }
struct Node {
S val;
int left, right;
};
int _n, size, lg;
std::vector<Node> nodes;
int new_node(const S &val, int left, int right) {
nodes.push_back(Node{val, left, right});
return int(nodes.size()) - 1;
}
};#line 2 "segmenttree/persistent_segtree.hpp"
#include <bit>
#include <cassert>
#include <functional>
#include <vector>
template <class S, auto op, auto e> struct persistent_segtree {
static_assert(std::is_convertible_v<decltype(op), std::function<S(S, S)>>,
"op must work as S(S, S)");
static_assert(std::is_convertible_v<decltype(e), std::function<S()>>, "e must work as S()");
struct Root {
int id;
};
explicit persistent_segtree(int n) : persistent_segtree(std::vector<S>(n, e())) {}
explicit persistent_segtree(const std::vector<S> &v) : _n(int(v.size())) {
size = std::bit_ceil((unsigned int)_n);
lg = std::countr_zero((unsigned int)size);
nodes.assign(2 * size, Node{e(), -1, -1});
for (int i = 0; i < _n; i++) nodes[size + i].val = v[i];
for (int i = size - 1; i >= 1; i--) {
nodes[i].left = 2 * i;
nodes[i].right = 2 * i + 1;
nodes[i].val = op(nodes[2 * i].val, nodes[2 * i + 1].val);
}
}
Root set(const Root &root, int p, const S &x) {
assert(0 <= p && p < _n);
std::vector<int> ids(lg + 1);
ids[lg] = root.id;
for (int i = lg - 1; i >= 0; --i) {
const Node &par = nodes[ids[i + 1]];
ids[i] = ((p >> i) & 1) ? par.right : par.left;
}
int copy_cur = new_node(x, -1, -1);
for (int i = 1; i <= lg; i++) {
const int par = ids[i], cur = ids[i - 1];
const Node &par_node = nodes[par];
const int left = par_node.left == cur ? copy_cur : par_node.left;
const int right = par_node.right == cur ? copy_cur : par_node.right;
copy_cur = new_node(op(nodes[left].val, nodes[right].val), left, right);
}
return Root{copy_cur};
}
S get(const Root &root, int p) const {
assert(0 <= p && p < _n);
int i = root.id;
for (int bit = lg - 1; bit >= 0; --bit) {
i = ((p >> bit) & 1) ? nodes[i].right : nodes[i].left;
}
return nodes[i].val;
}
S prod(const Root &root, int l, int r) const {
assert(0 <= l && l <= r && r <= _n);
auto rec = [&](auto &&self, int i, int lo, int hi) -> S {
if (r <= lo || hi <= l) return e();
if (l <= lo && hi <= r) return nodes[i].val;
const int mid = (lo + hi) >> 1;
return op(self(self, nodes[i].left, lo, mid), self(self, nodes[i].right, mid, hi));
};
return rec(rec, root.id, 0, size);
}
S all_prod(const Root &root) const { return nodes[root.id].val; }
template <bool (*f)(S)> int max_right(const Root &root, int l) const {
return max_right(root, l, [](S x) { return f(x); });
}
template <class F> int max_right(const Root &root, int l, F f) const {
assert(0 <= l && l <= _n);
assert(f(e()));
if (l == _n) return _n;
S sm = e();
auto rec = [&](auto &&self, int i, int lo, int hi) -> int {
if (hi <= l) return hi;
if (l <= lo) {
const S nxt = op(sm, nodes[i].val);
if (f(nxt)) {
sm = nxt;
return hi;
}
if (hi - lo == 1) return lo;
}
const int mid = (lo + hi) >> 1;
if (l < mid) {
const int left_res = self(self, nodes[i].left, lo, mid);
if (left_res < mid) return left_res;
}
return self(self, nodes[i].right, mid, hi);
};
return std::min(rec(rec, root.id, 0, size), _n);
}
template <bool (*f)(S)> int min_left(const Root &root, int r) const {
return min_left(root, r, [](S x) { return f(x); });
}
template <class F> int min_left(const Root &root, int r, F f) const {
assert(0 <= r && r <= _n);
assert(f(e()));
if (r == 0) return 0;
S sm = e();
auto rec = [&](auto &&self, int i, int lo, int hi) -> int {
if (r <= lo) return lo;
if (hi <= r) {
const S nxt = op(nodes[i].val, sm);
if (f(nxt)) {
sm = nxt;
return lo;
}
if (hi - lo == 1) return hi;
}
const int mid = (lo + hi) >> 1;
if (mid < r) {
const int right_res = self(self, nodes[i].right, mid, hi);
if (mid < right_res) return right_res;
}
return self(self, nodes[i].left, lo, mid);
};
return rec(rec, root.id, 0, size);
}
Root get_root() const { return {1}; }
struct Node {
S val;
int left, right;
};
int _n, size, lg;
std::vector<Node> nodes;
int new_node(const S &val, int left, int right) {
nodes.push_back(Node{val, left, right});
return int(nodes.size()) - 1;
}
};