[TJOI2013]最长上升子序列 (splay)

题目描述

给定一个序列,初始为空。现在我们将1到N的数字插入到序列中,每次将一个数字插入到一个特定的位置。每插入一个数字,我们都想知道此时最长上升子序列长度是多少?

输入格式:

第一行一个整数N,表示我们要将1到N插入序列中。

接下是N个数字,第k个数字Xk,表示我们将k插入到位置Xk(0<=Xk<=k-1,1<=k<=N)

输出格式:

N行,第i行表示i插入Xi位置后序列的最长上升子序列的长度是多少。

输入样例:

6

0 0 2 3 4 1

输出样例:

1

1

2

3

4

4

题解:

  1. 在第i次操作将i加入,那么在当前数列中i是最大的。
  2. 如果答案被更新,那么新的LIS以i结尾。
  3. 以i为结尾的LIS不会再被更新,因为后面插入的数都比i大。

所以我们对最终的序列求一遍LIS,就可以得到以i结尾的LIS长度,那么ans[i]=max(dp[i],dp[j]),j<i;

 那么现在的问题就是如何求序列,可以用splay维护中序遍历,先放一个最大值和最小值进去为了避免玄学数组越界。

插入时,将x+1旋到根,x+2旋到根的右儿子,然后再把要插入的数接到x+2的左儿子即可。

注意插入后要更新x+1和x+2的信息。

#include<bits/stdc++.h>
using namespace std;

const int maxn=100005;
const int oo=0x3f3f3f;
int n,num,root,mx;
int a[maxn],f[maxn];
struct Splay{
    int v,size,fa,s[2];
}tr[maxn];
struct answer{
    int id,cx;
}ans[maxn];

template<class T>inline void read(T &x){
    x=0;char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
} 

bool cmp(answer a,answer b){
    return a.id<b.id;
}

int identify(int x){
    return x==tr[tr[x].fa].s[1];
}

void connect(int x,int y,int d){
    tr[x].fa=y;tr[y].s[d]=x;
}

void update(int x){
    tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
}

void rotate(int x){
    int f=tr[x].fa,ff=tr[f].fa;
    int d1=identify(x),d2=identify(f);
    int cs=tr[x].s[d1^1];
    connect(cs,f,d1);
    connect(f,x,d1^1);
    connect(x,ff,d2);
    update(x);update(f);
}

void splay(int x,int go){
    if(go==root) root=x;
    go=tr[go].fa;
    while(tr[x].fa!=go){
        int f=tr[x].fa;
        if(tr[f].fa==go) rotate(x);
        else if(identify(x)==identify(f)){rotate(f);rotate(x);}
        else {rotate(x);rotate(x);}
    }
}

int find(int x){
    int now=root;
    while(1){
        if(tr[tr[now].s[0]].size>=x) now=tr[now].s[0];
        else{
            x-=tr[tr[now].s[0]].size;
            if(x==1) return now;
            x-=1;
            now=tr[now].s[1];
        }
    }
}

void insert(int go,int val){
    int x=find(go),y=find(go+1);
    splay(x,root);splay(y,tr[root].s[1]);
    tr[++num]=(Splay){val,1,y,{0,0}};
    tr[y].s[0]=num;
    update(y);update(x);
    splay(num,root);
}

void nice(int x){
    if(tr[x].s[0]) nice(tr[x].s[0]);
    if(tr[x].v!=oo&&tr[x].v!=-oo) a[++num]=tr[x].v;
    if(tr[x].s[1]) nice(tr[x].s[1]);
}

int divi(int x){
    int l=0,r=n;
    while(l<=r){
        int mid=(l+r)>>1;
        if(f[mid]>=x) r=mid-1;
        else l=mid+1;
    }
    return r;
}

int main(){
    read(n);
    tr[++num]=(Splay){-oo,2,0,{0,2}};
    tr[++num]=(Splay){oo,1,1,{0,0}};
    root=1;
    for(int i=1;i<=n;i++){
        int x;read(x);
        insert(x+1,i);
    }
    num=0;
    nice(root);
    //for(int i=1;i<=n;i++) printf("%d ",a[i]);
    //putchar(10);
    for(int i=1;i<=n;i++) f[i]=oo;
    for(int i=1;i<=n;i++){
        int k=divi(a[i])+1;
        f[k]=a[i];
        ans[i].cx=k;
        ans[i].id=a[i];
    }
    //for(int i=1;i<=n;i++) printf("%d ",ans[i]);
    sort(ans+1,ans+n+1,cmp);
    for(int i=1;i<=n;i++) printf("%d
",mx=max(mx,ans[i].cx));
}
View Code
原文地址:https://www.cnblogs.com/sto324/p/11171511.html