cplib-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub hitonanode/cplib-cpp

:warning: k-d tree (2D)
(data_structure/kd_tree_2d.hpp)

$k$-d tree の二次元平面の場合の実装.

木の各ノードにおいて平面を切断する方向は,原則として $x$, $y$ 方向のうち(部分木に含まれる点の)最大値と最小値の差が大きい方を採用する.ただし,ワーストケースでの計算量悪化を回避するため,同じ方向の切断が 4 回連続した場合は,次のノードで必ずもう一方の方向に切断する.

vector<pair<long long, long long>> xys;

kd_tree<long long> kdt(xys);

long long xmin, xmax, ymin, ymax;
vector<int> ids = kdt.get_rect(xmin, xmax, ymin, ymax);  // 閉矩形 [xmin, xmax] × [ymin, ymax] に含まれる点の番号(xys の添字)を順不同で取得

問題例

Code

#pragma once
#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>

// 2D k-d tree over a fixed set of points.
// Supports enumerating every point that lies inside a closed axis-aligned rectangle.
// Verified: abc234h https://atcoder.jp/contests/abc234/submissions/28456735
template <class T> struct kd_tree {
    // Copy of the input points; the id of a point is its index in this vector.
    std::vector<std::pair<T, T>> _ps;

    struct Node {
        T xmin, xmax, ymin, ymax; // bounding box of the points stored in this subtree
        std::vector<int> ids;     // ids of every point in this subtree
        int lch, rch;             // indices of the children in _nodes, or -1 if absent
        template <class OStream> friend OStream &operator<<(OStream &os, const Node &n) {
            os << "{Node[" << n.xmin << ", " << n.xmax << "]x[" << n.ymin << ", " << n.ymax
               << "], ids=(";
            for (auto i : n.ids) os << i << ',';
            os << "), chs=" << n.lch << ',' << n.rch << '}';
            return os;
        }
    };
    std::vector<Node> _nodes; // _nodes[0] is the root when the tree is non-empty

    using Tpl = std::tuple<int, T, T>; // (id, x, y)
    std::vector<Tpl> _tmp;             // working buffer, reordered in place during construction

    // Builds the subtree for _tmp[l, r) and returns its node index (-1 if the range is empty).
    // nsplitx / nsplity count how many consecutive ancestors split along x / y.
    int _build(int l, int r, int nsplitx, int nsplity) {
        if (l >= r) return -1;
        T xmin = std::get<1>(_tmp[l]), xmax = xmin, ymin = std::get<2>(_tmp[l]), ymax = ymin;
        std::vector<int> ids(r - l);
        for (int i = l; i < r; ++i) {
            const T x = std::get<1>(_tmp[i]), y = std::get<2>(_tmp[i]);
            if (x < xmin) xmin = x;
            if (x > xmax) xmax = x;
            if (y < ymin) ymin = y;
            if (y > ymax) ymax = y;
            ids[i - l] = std::get<0>(_tmp[i]);
        }
        const int node_id = int(_nodes.size());
        _nodes.push_back({xmin, xmax, ymin, ymax, std::move(ids), -1, -1});

        if (r - l <= 1) return node_id; // single point: leaf

        // Split along the wider side of the bounding box, but never more than 4 times in a row
        // along the same axis, so degenerate inputs cannot keep cutting only one direction.
        bool split_x = xmax - xmin > ymax - ymin;
        if (nsplitx > 3) split_x = false;
        if (nsplity > 3) split_x = true;

        const int c = (l + r) / 2;
        int lch, rch;
        if (split_x) {
            std::nth_element(
                _tmp.begin() + l, _tmp.begin() + c, _tmp.begin() + r,
                [](const Tpl &a, const Tpl &b) { return std::get<1>(a) < std::get<1>(b); });
            lch = _build(l, c, nsplitx + 1, 0);
            rch = _build(c, r, nsplitx + 1, 0);
        } else {
            std::nth_element(
                _tmp.begin() + l, _tmp.begin() + c, _tmp.begin() + r,
                [](const Tpl &a, const Tpl &b) { return std::get<2>(a) < std::get<2>(b); });
            lch = _build(l, c, 0, nsplity + 1);
            rch = _build(c, r, 0, nsplity + 1);
        }
        // Index _nodes only after the recursive calls: they push_back and may reallocate it.
        _nodes[node_id].lch = lch;
        _nodes[node_id].rch = rch;
        return node_id;
    }

    // (Re)builds the tree from vs; point vs[i] gets id i.
    void _initialize(const std::vector<std::pair<T, T>> &vs) {
        _ps = vs;
        _tmp.resize(_ps.size());
        for (int i = 0; i < int(vs.size()); ++i)
            _tmp[i] = std::make_tuple(i, vs[i].first, vs[i].second);
        _nodes.clear();
        _nodes.reserve(2 * vs.size()); // a tree over n >= 1 points has exactly 2n - 1 nodes
        _build(0, int(_tmp.size()), 0, 0);
    }
    kd_tree(const std::vector<std::pair<T, T>> &vs) { _initialize(vs); }

    // Returns the ids of all points in the closed rectangle [xmin, xmax] x [ymin, ymax],
    // in no particular order.
    std::vector<int> get_rect(T xmin, T xmax, T ymin, T ymax) const {
        std::vector<int> ret;
        std::vector<int> stack;
        if (!_nodes.empty()) stack.push_back(0);
        while (!stack.empty()) {
            const Node &now = _nodes[stack.back()];
            stack.pop_back();
            // Bounding box disjoint from the query: nothing to report in this subtree.
            if (xmax < now.xmin or now.xmax < xmin or ymax < now.ymin or now.ymax < ymin) continue;
            if (xmin <= now.xmin and now.xmax <= xmax and ymin <= now.ymin and now.ymax <= ymax) {
                // Bounding box fully inside the query: report the whole subtree at once.
                ret.insert(ret.end(), now.ids.begin(), now.ids.end());
            } else {
                if (now.lch >= 0) stack.push_back(now.lch);
                if (now.rch >= 0) stack.push_back(now.rch);
            }
        }
        return ret;
    }
};
#line 2 "data_structure/kd_tree_2d.hpp"
#include <algorithm>
#include <tuple>
#include <utility>
#include <vector>

// 2D k-d tree over a fixed set of points.
// Supports enumerating every point that lies inside a closed axis-aligned rectangle.
// Verified: abc234h https://atcoder.jp/contests/abc234/submissions/28456735
template <class T> struct kd_tree {
    // Copy of the input points; the id of a point is its index in this vector.
    std::vector<std::pair<T, T>> _ps;

    struct Node {
        T xmin, xmax, ymin, ymax; // bounding box of the points stored in this subtree
        std::vector<int> ids;     // ids of every point in this subtree
        int lch, rch;             // indices of the children in _nodes, or -1 if absent
        template <class OStream> friend OStream &operator<<(OStream &os, const Node &n) {
            os << "{Node[" << n.xmin << ", " << n.xmax << "]x[" << n.ymin << ", " << n.ymax
               << "], ids=(";
            for (auto i : n.ids) os << i << ',';
            os << "), chs=" << n.lch << ',' << n.rch << '}';
            return os;
        }
    };
    std::vector<Node> _nodes; // _nodes[0] is the root when the tree is non-empty

    using Tpl = std::tuple<int, T, T>; // (id, x, y)
    std::vector<Tpl> _tmp;             // working buffer, reordered in place during construction

    // Builds the subtree for _tmp[l, r) and returns its node index (-1 if the range is empty).
    // nsplitx / nsplity count how many consecutive ancestors split along x / y.
    int _build(int l, int r, int nsplitx, int nsplity) {
        if (l >= r) return -1;
        T xmin = std::get<1>(_tmp[l]), xmax = xmin, ymin = std::get<2>(_tmp[l]), ymax = ymin;
        std::vector<int> ids(r - l);
        for (int i = l; i < r; ++i) {
            const T x = std::get<1>(_tmp[i]), y = std::get<2>(_tmp[i]);
            if (x < xmin) xmin = x;
            if (x > xmax) xmax = x;
            if (y < ymin) ymin = y;
            if (y > ymax) ymax = y;
            ids[i - l] = std::get<0>(_tmp[i]);
        }
        const int node_id = int(_nodes.size());
        _nodes.push_back({xmin, xmax, ymin, ymax, std::move(ids), -1, -1});

        if (r - l <= 1) return node_id; // single point: leaf

        // Split along the wider side of the bounding box, but never more than 4 times in a row
        // along the same axis, so degenerate inputs cannot keep cutting only one direction.
        bool split_x = xmax - xmin > ymax - ymin;
        if (nsplitx > 3) split_x = false;
        if (nsplity > 3) split_x = true;

        const int c = (l + r) / 2;
        int lch, rch;
        if (split_x) {
            std::nth_element(
                _tmp.begin() + l, _tmp.begin() + c, _tmp.begin() + r,
                [](const Tpl &a, const Tpl &b) { return std::get<1>(a) < std::get<1>(b); });
            lch = _build(l, c, nsplitx + 1, 0);
            rch = _build(c, r, nsplitx + 1, 0);
        } else {
            std::nth_element(
                _tmp.begin() + l, _tmp.begin() + c, _tmp.begin() + r,
                [](const Tpl &a, const Tpl &b) { return std::get<2>(a) < std::get<2>(b); });
            lch = _build(l, c, 0, nsplity + 1);
            rch = _build(c, r, 0, nsplity + 1);
        }
        // Index _nodes only after the recursive calls: they push_back and may reallocate it.
        _nodes[node_id].lch = lch;
        _nodes[node_id].rch = rch;
        return node_id;
    }

    // (Re)builds the tree from vs; point vs[i] gets id i.
    void _initialize(const std::vector<std::pair<T, T>> &vs) {
        _ps = vs;
        _tmp.resize(_ps.size());
        for (int i = 0; i < int(vs.size()); ++i)
            _tmp[i] = std::make_tuple(i, vs[i].first, vs[i].second);
        _nodes.clear();
        _nodes.reserve(2 * vs.size()); // a tree over n >= 1 points has exactly 2n - 1 nodes
        _build(0, int(_tmp.size()), 0, 0);
    }
    kd_tree(const std::vector<std::pair<T, T>> &vs) { _initialize(vs); }

    // Returns the ids of all points in the closed rectangle [xmin, xmax] x [ymin, ymax],
    // in no particular order.
    std::vector<int> get_rect(T xmin, T xmax, T ymin, T ymax) const {
        std::vector<int> ret;
        std::vector<int> stack;
        if (!_nodes.empty()) stack.push_back(0);
        while (!stack.empty()) {
            const Node &now = _nodes[stack.back()];
            stack.pop_back();
            // Bounding box disjoint from the query: nothing to report in this subtree.
            if (xmax < now.xmin or now.xmax < xmin or ymax < now.ymin or now.ymax < ymin) continue;
            if (xmin <= now.xmin and now.xmax <= xmax and ymin <= now.ymin and now.ymax <= ymax) {
                // Bounding box fully inside the query: report the whole subtree at once.
                ret.insert(ret.end(), now.ids.begin(), now.ids.end());
            } else {
                if (now.lch >= 0) stack.push_back(now.lch);
                if (now.rch >= 0) stack.push_back(now.rch);
            }
        }
        return ret;
    }
};
Back to top page