// One splay-tree node; index 0 of TR acts as the null node.
struct node {
    int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
    int val;    // stored key
    int siz;    // total multiplicity of all keys in this subtree
    int cnt;    // multiplicity of val at this node
    int pre;    // parent index (0 = no parent)
} TR[N];
// Rotate `cur` one level up over its parent `fa`, re-attaching it to the
// grandparent and moving cur's inner child under fa; sizes are refreshed.
void rotate(int cur) {
    const int fa = TR[cur].pre;
    const int gr = TR[fa].pre;
    const int side = getwh(cur);             // which side of fa cur hangs on
    const int inner = TR[cur].ch[side ^ 1];  // subtree that changes parent
    TR[gr].ch[getwh(fa)] = cur;  // must happen before fa's parent link changes
    TR[cur].pre = gr;
    TR[fa].ch[side] = inner;
    TR[inner].pre = fa;
    TR[cur].ch[side ^ 1] = fa;
    TR[fa].pre = cur;
    up(fa);   // fa is now below cur, so refresh it first
    up(cur);
}
// Report which child slot `cur` occupies under its parent: 1 = right, 0 = otherwise.
int getwh(int cur) {
    const int parent = TR[cur].pre;
    return TR[parent].ch[1] == cur ? 1 : 0;
}
// Splay `cur` upward until its parent is `to`; with to == 0 it becomes the
// root and `rt` is updated. Uses zig-zig (rotate parent first) when cur and
// its parent are on the same side, zig-zag otherwise.
void splay(int cur,int to) {
    for (int fa = TR[cur].pre; fa != to; fa = TR[cur].pre) {
        const int grand = TR[fa].pre;
        if (grand != to)
            rotate(getwh(cur) == getwh(fa) ? fa : cur);
        rotate(cur);
    }
    if (to == 0) rt = cur;
}
// Insert one copy of `val` into the subtree rooted at `cur`; `lst` is cur's
// parent (0 at the top). Subtree sizes are bumped on the way down. A new node
// is allocated from the pool (++tot) and splayed to the root. An existing
// key has its count bumped and is now also splayed, so repeated hits on a
// deep key do not each cost O(depth).
void insert(int cur,int val,int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst;
        TR[cur].siz = TR[cur].cnt = 1;
        TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;  // lst == 0 only writes the null node
        splay(cur,0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) {
        TR[cur].cnt++;
        splay(cur,0);  // keep the accessed node at the root (amortized bound)
        return;
    }
    if (val > TR[cur].val) insert(rs,val,cur);
    else insert(ls,val,cur);
}
// Join two detached trees rooted at `cur` and `y` (either may be 0 = empty)
// into one and make it the root `rt`. Every key under `cur` is expected to be
// smaller than every key under `y`. If both are non-empty and given in the
// opposite order, they are swapped. The roots' parent links are assumed
// already cleared by the caller.
void merge(int cur,int y) {
    // Handle empty sides explicitly instead of comparing against the null node's val.
    if (!cur) { rt = y; return; }
    if (!y) { rt = cur; return; }
    if (TR[cur].val > TR[y].val) std::swap(cur,y);
    while (rs) cur = rs;  // maximum of the smaller tree
    splay(cur,0);         // now the root, with an empty right subtree
    rs = y;
    TR[y].pre = cur;
    up(cur);
}
// Search down from `cur` for `val`. Returns the node holding `val`, or, if it
// is absent, the last node visited on the search path. Returns 0 when `cur`
// is 0 (empty tree).
int getpos(int cur,int val) {
    int lst = 0;  // initialized so an empty tree yields the null index 0
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}
void del(int cur,int val) {
cur = getpos(rt,val);
if(!cur) return;
if(TR[cur].val != val) return;
splay(cur,0);
if(TR[cur].cnt > 1) {TR[cur].cnt--;TR[cur].siz--;return;}
TR[ls].pre = TR[rs].pre = 0;
merge(ls,rs);
}
// Return the x-th smallest stored value (1-based, duplicates counted) in the
// subtree rooted at `cur`; returns 0 when no such position exists.
int kth(int cur,int x) {
    while (cur != 0) {
        const int leftSize = TR[ls].siz;
        if (x <= leftSize) { cur = ls; continue; }
        const int upto = leftSize + TR[cur].cnt;  // positions covered through this node
        if (x <= upto) return TR[cur].val;
        x -= upto;
        cur = rs;
    }
    return cur;
}
// Largest stored value strictly less than `val` (the incoming `cur` is ignored).
// NOTE(review): assumes a non-empty tree and that such a value exists; otherwise
// the walk starts from the null node and the result is not meaningful.
int pred(int cur,int val) {
    cur = getpos(rt,val);
    const int landed = TR[cur].val;
    if (landed < val) return landed;  // the search stopped at the predecessor itself
    splay(cur,0);                     // every smaller key is now in cur's left subtree
    int best = ls;
    while (TR[best].ch[1]) best = TR[best].ch[1];
    return TR[best].val;
}
// Smallest stored value strictly greater than `val` (the incoming `cur` is ignored).
// NOTE(review): assumes a non-empty tree and that such a value exists; otherwise
// the walk starts from the null node and the result is not meaningful.
int nex(int cur,int val) {
    cur = getpos(rt,val);
    const int landed = TR[cur].val;
    if (landed > val) return landed;  // the search stopped at the successor itself
    splay(cur,0);                     // every larger key is now in cur's right subtree
    int best = rs;
    while (TR[best].ch[0]) best = TR[best].ch[0];
    return TR[best].val;
}
#include<cstdio>
#include<iostream>
using namespace std;
typedef long long ll;
// Size of the node pool TR. Nodes are allocated as 1..tot; index 0 is the null node.
// Presumably sized for the maximum number of insertions plus slack — confirm against input limits.
const int N = 100000 + 100;
// Shorthand for the left / right child of whatever variable named `cur` is in scope.
#define ls TR[cur].ch[0]
#define rs TR[cur].ch[1]
// Read one decimal integer from stdin, skipping non-digit characters before it.
// A '-' seen while skipping makes the result negative. If EOF is reached
// before any digit, returns 0 instead of looping forever.
ll read() {
    ll x = 0, f = 1;
    int c = getchar();  // int, not char: must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
int rt;  // index of the current root; 0 when the tree is empty

// One splay-tree node; index 0 of TR acts as the null node.
struct node {
    int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
    int val;    // stored key
    int siz;    // total multiplicity of all keys in this subtree
    int cnt;    // multiplicity of val at this node
    int pre;    // parent index (0 = no parent)
} TR[N];
void up(int cur) {
TR[cur].siz = TR[ls].siz + TR[rs].siz + TR[cur].cnt;
}
// Report which child slot `cur` occupies under its parent: 1 = right, 0 = otherwise.
int getwh(int cur) {
    const int parent = TR[cur].pre;
    return TR[parent].ch[1] == cur ? 1 : 0;
}
// Rotate `cur` one level up over its parent `fa`, re-attaching it to the
// grandparent and moving cur's inner child under fa; sizes are refreshed.
void rotate(int cur) {
    const int fa = TR[cur].pre;
    const int gr = TR[fa].pre;
    const int side = getwh(cur);             // which side of fa cur hangs on
    const int inner = TR[cur].ch[side ^ 1];  // subtree that changes parent
    TR[gr].ch[getwh(fa)] = cur;  // must happen before fa's parent link changes
    TR[cur].pre = gr;
    TR[fa].ch[side] = inner;
    TR[inner].pre = fa;
    TR[cur].ch[side ^ 1] = fa;
    TR[fa].pre = cur;
    up(fa);   // fa is now below cur, so refresh it first
    up(cur);
}
void splay(int cur,int to) {
while(TR[cur].pre != to) {
if(TR[TR[cur].pre].pre != to) {
// if(getwh(cur) == getwh(TR[cur].pre)) rotate(TR[cur].pre);
// else
rotate(cur);
}
rotate(cur);
}
if(!to) rt = cur;
}
int tot;  // number of pool nodes allocated so far (indices 1..tot; 0 is the null node)

// Insert one copy of `val` into the subtree rooted at `cur`; `lst` is cur's
// parent (0 at the top). Subtree sizes are bumped on the way down. A new node
// is allocated from the pool (++tot) and splayed to the root. An existing
// key has its count bumped and is now also splayed, so repeated hits on a
// deep key do not each cost O(depth).
void insert(int cur,int val,int lst) {
    if (!cur) {
        cur = ++tot;
        TR[cur].pre = lst;
        TR[cur].siz = TR[cur].cnt = 1;
        TR[cur].val = val;
        TR[lst].ch[val > TR[lst].val] = cur;  // lst == 0 only writes the null node
        splay(cur,0);
        return;
    }
    TR[cur].siz++;
    if (val == TR[cur].val) {
        TR[cur].cnt++;
        splay(cur,0);  // keep the accessed node at the root (amortized bound)
        return;
    }
    if (val > TR[cur].val) insert(rs,val,cur);
    else insert(ls,val,cur);
}
// Join two detached trees rooted at `cur` and `y` (either may be 0 = empty)
// into one and make it the root `rt`. Every key under `cur` is expected to be
// smaller than every key under `y`. If both are non-empty and given in the
// opposite order, they are swapped. The roots' parent links are assumed
// already cleared by the caller.
void merge(int cur,int y) {
    // Handle empty sides explicitly instead of comparing against the null node's val.
    if (!cur) { rt = y; return; }
    if (!y) { rt = cur; return; }
    if (TR[cur].val > TR[y].val) std::swap(cur,y);
    while (rs) cur = rs;  // maximum of the smaller tree
    splay(cur,0);         // now the root, with an empty right subtree
    rs = y;
    TR[y].pre = cur;
    up(cur);
}
// Search down from `cur` for `val`. Returns the node holding `val`, or, if it
// is absent, the last node visited on the search path. Returns 0 when `cur`
// is 0 (empty tree).
int getpos(int cur,int val) {
    int lst = 0;  // initialized so an empty tree yields the null index 0
    while (cur) {
        lst = cur;
        if (TR[cur].val == val) return cur;
        cur = TR[cur].ch[val > TR[cur].val];
    }
    return lst;
}
void del(int cur,int val) {
cur = getpos(rt,val);
if(!cur) return;
if(TR[cur].val != val) return;
splay(cur,0);
if(TR[cur].cnt > 1) {TR[cur].cnt--;TR[cur].siz--;return;}
TR[ls].pre = TR[rs].pre = 0;
merge(ls,rs);
}
// 1-based rank of `val`: (number of stored copies strictly less than val) + 1.
// Correct whether or not `val` is stored; the incoming `cur` is ignored.
int Rank(int cur,int val) {
    if (!rt) return 1;  // empty tree: nothing is smaller
    cur = getpos(rt,val);
    splay(cur,0);       // every key smaller than cur's is now in its left subtree
    // An unsuccessful search that stops at a smaller key stops at val's
    // predecessor, so that node's copies are also < val.
    if (TR[cur].val < val) return TR[ls].siz + TR[cur].cnt + 1;
    return TR[ls].siz + 1;
}
// Return the x-th smallest stored value (1-based, duplicates counted) in the
// subtree rooted at `cur`; returns 0 when no such position exists.
int kth(int cur,int x) {
    while (cur != 0) {
        const int leftSize = TR[ls].siz;
        if (x <= leftSize) { cur = ls; continue; }
        const int upto = leftSize + TR[cur].cnt;  // positions covered through this node
        if (x <= upto) return TR[cur].val;
        x -= upto;
        cur = rs;
    }
    return cur;
}
// Largest stored value strictly less than `val` (the incoming `cur` is ignored).
// NOTE(review): assumes a non-empty tree and that such a value exists; otherwise
// the walk starts from the null node and the result is not meaningful.
int pred(int cur,int val) {
    cur = getpos(rt,val);
    const int landed = TR[cur].val;
    if (landed < val) return landed;  // the search stopped at the predecessor itself
    splay(cur,0);                     // every smaller key is now in cur's left subtree
    int best = ls;
    while (TR[best].ch[1]) best = TR[best].ch[1];
    return TR[best].val;
}
// Smallest stored value strictly greater than `val` (the incoming `cur` is ignored).
// NOTE(review): assumes a non-empty tree and that such a value exists; otherwise
// the walk starts from the null node and the result is not meaningful.
int nex(int cur,int val) {
    cur = getpos(rt,val);
    const int landed = TR[cur].val;
    if (landed > val) return landed;  // the search stopped at the successor itself
    splay(cur,0);                     // every larger key is now in cur's right subtree
    int best = rs;
    while (TR[best].ch[0]) best = TR[best].ch[0];
    return TR[best].val;
}
// Read n operations "opt x" and apply each to the splay tree:
// 1 insert, 2 delete, 3 rank, 4 k-th smallest, 5 predecessor, 6 successor.
// Query results are printed one per line.
int main() {
    int n = read();
    while (n--) {
        int opt = read();
        int x = read();
        switch (opt) {
            case 1: insert(rt,x,0); break;
            case 2: del(rt,x); break;
            case 3: printf("%d\n",Rank(rt,x)); break;
            case 4: printf("%d\n",kth(rt,x)); break;
            case 5: printf("%d\n",pred(rt,x)); break;
            case 6: printf("%d\n",nex(rt,x)); break;
            default: break;  // unknown opcodes are ignored
        }
    }
    return 0;
}
// Original article (原文): https://www.cnblogs.com/wxyww/p/10090186.html