竞赛讨论区 > 多校第六场 I 题加难版的 treap 做法,欢迎吐槽
头像
Aerix
发布于 2018-08-05 18:29
+ 关注

多校第六场 I 题加难版的 treap 做法,欢迎吐槽

题意是给定 n(<= 2e5) 个区间,然后给出 m(<= 2e5) 个操作,操作数为 p 表示把含有 p 这个数的区间都删去,且记录删除时的操作号码;最终输出每次操作删除的区间数,以及每个区间被删除时的操作号码。

这个题其实还能再难一点,每个操作给出 p,q 两个操作数,要求把与 [p,q] 相交的所有区间删去。(这道题就看成是 p=q 的特例吧)

对于一个区间 [l,r],与 [p,q] 相交,可以分两种情况讨论:

  1. 对于 l<= p 的:p <= r
  2. 对于 l>p 的:l<= q

我们建一棵 treap, 它是可以在 O(log n) 的复杂度进行合并和分割的平衡二叉树,把区间左端点为第一关键字,把区间右端点为第二关键字,都是从小到大排,操作时把 treap 按 [p,q] 分割成两棵,叫左半边和右半边吧,左半边有着情况 1 要的点,右半边有着情况 2 要的点。

观察1: 右半边关键字最小的点都不满足 l<= q 的话,整个右半边都不满足。

观察2: 对于左半边的点,我们维护它的子树的区间右界能达到的最大值 rmax, 如果不满足 p <= rmax ,那么整棵子树都不满足。

图片说明

这个图可以作为一个例子(这个图是网上借来的) ,每个节点下面打粗体的就是 rmax ,比方说 p=42, 到了根节点一看,左孩子的 rmax 不够,左孩子那棵 3 个节点的子树统统不用考虑,直接看右孩子就行了。

值得注意的是这是一个摊还算法,应对每个操作的复杂度最好情况下是 O(log n),最坏情况下是 O(nlog n), 因为可能一次操作把所有区间都给删了,但是总体复杂度还是 O((n+m)log n)。
最后贴个代码。。
#include <cassert>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <tuple>
using namespace std;

typedef long long LL;
const int mod=998244353;
inline int addmod(int x, int y){
    x+=y; return x>=mod?x-mod:x;
}
inline int mulmod(LL x, int y){
    x*=y; return x>=mod?x%mod:x;
}

const int N=2e5+5, T=N;
const int oo=mod<<1;
#define left get<0>
#define right get<1>
#define index get<2>
typedef tuple<int,int,int> P;
int fix[T], lch[T], rch[T], siz[T];
P key[T]; int rmax[T], ans[N];
int tot, root;

inline void reset() {
    rmax[0]=-oo, root=0; tot=1;
}
inline int newnode(P const& x){
    int t=tot; ++tot; assert(tot<=T);
    lch[t]=rch[t]=0, siz[t]=1;
    fix[t]=rand()<<15|rand();
    key[t]=x; rmax[t]=right(key[t]);
    return t;
}
inline void pushdown(int){}
inline void update(int t){
    siz[t]=siz[lch[t]]+1+siz[rch[t]];
    rmax[t]=max(right(key[t]), max(rmax[lch[t]],rmax[rch[t]]));
}
void split(int t, int k, int &l, int &r){
    if(!t){l=r=0; assert(k==1); return;}
    pushdown(t);
    int tk=siz[lch[t]]+1;
    if(k<=tk){
        split(lch[t], k, l,lch[t]); r=t;
    }else{
        split(rch[t], k-tk, rch[t],r); l=t;
    }
    update(t);
}
void join(int &t, int l, int r){
    if(!l){t=r; return;}
    if(!r){t=l; return;}
    if(fix[l]>fix[r]){
        pushdown(t=l); join(rch[l],rch[l],r);
    }else{
        pushdown(t=r); join(lch[r],l,lch[r]);
    }
    update(t);
}
int order(int t, P x){
    if(!t)return 1;
    pushdown(t);
    if(key[t]<x)return siz[lch[t]]+1+order(rch[t],x);
    return order(lch[t],x);
}
inline void insert(P const& x){
    int l, r, k=order(root,x), t=newnode(x);
    split(root, k, l, r);
    join(r, t, r);
    join(root, l, r);
}
inline void erase(P const& x){
    int l, r, t, k=order(root,x);
    split(root, k, l, r);
    split(r, 2, t, r);
    join(root, l, r);
}
int _checkrmax(int t, int p){
    if(p <= right(key[t])) return t;
    if(p <= rmax[lch[t]]) return _checkrmax(lch[t], p);
    if(p <= rmax[rch[t]]) return _checkrmax(rch[t], p);
    abort();
}
int _checkleft(int t, int q){
    if(!t) return 0;
    if(left(key[t]) <= q) return t;
    return _checkleft(lch[t], q);
}
int intersect(P x){
    int l, r, k=order(root,x);
    split(root, k, l,r);
    int res;
    if(left(x) <= rmax[l]){
        res = _checkrmax(l, left(x));
    }else{
        res = _checkleft(r, right(x));
    }
    join(root, l, r);
    return res;
}

int main(){
#ifdef LOCAL
    freopen("in.txt","r",stdin);
#endif // LOCAL
    srand(19970518);
    int tac; scanf("%d",&tac);
    for(int tic=1; tic<=tac; ++tic){ printf("Case #%d:\n",tic);
        int n, zo; scanf("%d%d",&n,&zo);
        reset();
        for(int i=1; i<=n; ++i){
            int l, r; scanf("%d%d",&l,&r);
            insert(P(l,r,i));
            ans[i] = 0;
        }
        int res=0;
        for(int ko=1; ko<=zo; ++ko){
            int y; scanf("%d",&y);
            int x = res^y;
            res=1;
            int cnt=0, t;
            while((t=intersect(P(x,x,0)))){
                ++cnt;
                int u = index(key[t]);
                res=mulmod(res,u);
                ans[u]=ko;
                erase(key[t]);
            }
            if(cnt==0)res=0;
            printf("%d\n",cnt);
        }
        for(int i=1; i<=n; ++i){
            printf("%d%c",ans[i]," \n"[i==n]);
        }
    }
    return 0;
}

全部评论

(1) 回帖
加载中...
话题 回帖

等你来战

查看全部

热门推荐