読者です 読者をやめる 読者になる 読者になる

Starry Sky

難しいのは Starry Sky Tree を組むところではない、というのが面白いよね(面白くない)。
StarrySkyTree を struct にしただけで実行時間が 2600 ms → 2200 ms ぐらいに縮んで、頭を抱えている。

#include <vector>
#include <iostream>
#include <set>
#include <cstdio>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cassert>
#include <ctime>
#include <algorithm>
#include <tuple>
#include <algorithm>
#include <limits>

using namespace std;

// Short aliases used throughout the solution.
using ll = long long;         // 64-bit accumulator type
using P = pair<int, int>;     // (key, star index) pairs

// Sentinel magnitude (2^23) that main adds to / subtracts from tree leaves to
// switch them on and off; it only has to dominate the per-leaf counts.
constexpr int INT_INF = 1 << 23;
// Capacity of every per-star array.
constexpr int MAX_N = 4010;
 
// Range-add / range-max segment tree ("Starry Sky Tree") over 2^N positions
// (e.g. N=12 -> 4096 positions, N=17 -> 131072).
//
// Nodes form an implicit heap: the root is 0, the children of k are 2k+1 and
// 2k+2, and position i lives in leaf size-1+i.
//   seg[k].second : amount added to k's whole range and never pushed down.
//   seg[k].first  : seg[k].second + max(children's .first), i.e. the maximum
//                   over k's range counting adds recorded at k and below (adds
//                   still pending at k's ancestors are applied by get()).
template <class T, int N>
struct StarrySkyTree {
    using P = std::pair<T, T>;
    using uint = unsigned int;
    static constexpr int size = 1<<N;
    P seg[size*2];

    // Set every position to x and clear all pending adds.
    void init(T x) {
        for (int i = 0; i < size*2; i++) {
            seg[i] = P(x, 0);
        }
    }

    // Add x to the single position i. Positions >= size are ignored instead of
    // writing past the array.
    void add(uint i, T x) {
        if (i >= static_cast<uint>(size)) return;
        i += size - 1;
        seg[i].first += x;
        while (i) {
            i = (i - 1) / 2;
            // Keep the children's maxima in T: narrowing them to int truncated
            // large long long values and any fractional part.
            const T s1 = seg[i*2+1].first;
            const T s2 = seg[i*2+2].first;
            seg[i].first = seg[i].second + std::max(s1, s2);
        }
    }

    // Add x to every position in [a, b). Empty ranges and b > size are ignored.
    // k, l, r are recursion state: node k covers [l, r).
    void add(uint a, uint b, T x, uint k = 0, uint l = 0, uint r = size) {
        if (a >= b || b > size) return;
        if (r <= a || b <= l) return;
        if (a <= l && r <= b) {
            // Whole node covered: record the add lazily at this node.
            seg[k].first += x;
            seg[k].second += x;
            return;
        }
        add(a, b, x, k*2+1, l, (l+r)/2);
        add(a, b, x, k*2+2, (l+r)/2, r);
        seg[k].first = seg[k].second + std::max(seg[k*2+1].first, seg[k*2+2].first);
    }

    // Maximum over positions [a, b) (the whole array by default).
    // Returns numeric_limits<T>::lowest() for an empty range or b > size.
    // lowest() is used rather than min(), which is the smallest *positive*
    // value for floating-point T and would beat every negative entry.
    T get(uint a = 0, uint b = size, uint k = 0, uint l = 0, uint r = size) {
        const T none = std::numeric_limits<T>::lowest();
        if (a >= b || b > size) return none;
        if (a <= l && r <= b) return seg[k].first;
        if (r <= a || b <= l) return none;
        const T vl = get(a, b, k*2+1, l, (l+r)/2);
        const T vr = get(a, b, k*2+2, (l+r)/2, r);
        // Apply this node's pending add on the way back up.
        return seg[k].second + std::max(vl, vr);
    }
};



// Raw per-star input, indexed by star id (read in main).
int X[MAX_N];
int Y[MAX_N];
int L[MAX_N];
// Y coordinates sorted ascending; the table my_find searches.
int Y2[MAX_N];
// Scratch for the current sweep: for candidate j, the Y2 index of the lower
// end (its Y minus l) of that candidate's Y range.
int YL[MAX_N];
// (X, id) and (L, id) pairs, sorted in main.
P X2[MAX_N];
P L2[MAX_N];
// Number of stars.
int n;
// Index of the first element of the sorted table Y2[0, n) that is >= x
// (std::lower_bound semantics), or n when every element is smaller than x.
//
// Assumes the global Y2[0, n) is already sorted ascending. Equal keys always
// map to the same (first) index. The former hand-rolled search returned an
// arbitrary duplicate and answered n-1, which is not a valid lower bound,
// when x exceeded every element.
inline int my_find(int x) {
    return static_cast<int>(std::lower_bound(Y2, Y2 + n, x) - Y2);
}

int main() {
    cin >> n;
    bool used[MAX_N] = {};
    for (int i = 0; i < n; i++) {
        scanf("%d %d %d", X+i, Y+i, L+i);
        L2[i] = P(L[i], i);
        X2[i] = P(X[i], i);
        Y2[i] = Y[i];
    }
    sort(L2, L2+n);
    sort(X2, X2+n);
    sort(Y2, Y2+n);
    int Y3[MAX_N];
    for (int i = 0; i < n; i++) {
        Y3[i] = my_find(Y[i]);
    }
 
    ll r = 0;
    int buff[MAX_N];
    StarrySkyTree<ll, 12> s;
    for (int li = 0; li < n; li++) {
        int l = L2[li].first;
        int i = L2[li].second;
        int minj = lower_bound(X2, X2+n, P(X[i] - l, 0)) - X2;
        int maxj = upper_bound(X2, X2+n, P(X[i] + l, MAX_N)) - X2;
        s.init(-INT_INF);
        int c = 0;
        for (int j = minj; j < maxj; j++) {
            P p = X2[j];
            if (!used[p.second]) {
                if (Y[i] - l <= Y[p.second] && Y[p.second] <= Y[i] + l) {
                    buff[c] = j;
                    c++;
                }
            }
        }
        used[i] = true;
        if (r >= n-li) break;
        if (r >= c) continue;
        for (int j = 0; j < c; j++) {
            YL[j] = my_find(Y[X2[buff[j]].second] - l);
        }
        int k = 0;
        for (int j = 0; j < c; j++) {
            P pj = X2[buff[j]];
            for (;k < c; k++) {
                P pk = X2[buff[k]];
                if (pk.first > pj.first + l) break;
                int yk = Y3[pk.second];
                s.add(yk, INT_INF);
                int lk = YL[k];
                s.add(lk, yk+1, 1);
            }
            r = max(r, s.get());
            int yj = Y3[pj.second];
            s.add(yj, -INT_INF);
            int lj = YL[j];
            s.add(lj, yj+1, -1);
        }
    }
    printf("%lld\n", r);
    return 0;
}