「日常」树规子树并归优化

学长留了三道题。

才写完然后来写一下总结。

我在蓝皮书上学的树规全都是n3的dp,当时学的时候没什么感觉,但是后来做软件安装的时候就有一点感觉好像这个背包的很多转移都是浪费的,因为当前的物品远远不能达到整个子树的大小,而每次外层枚举的空间大小全部都是整个子树大小,而不是当前正在当作背包物品的大小和之前合并的物品的大小,也就是,多余的空间,全是浪费的。

然后wq学长就讲了一下他对熟练剖分(你为什么那么熟练啊你到底和雪菜亲过多少次了啊)那道题的另外一个写法。同时他发现我们根本不会n2合并的思路。

所以留了三道题。

先说一下考试题熟练剖分,这题真的熟练的让我头疼。

题解的思路看懂了一半,后来老吕把std放上来我就看明白了,然后注释了一下扔到数组练习,开始想wq学长的方案计数思路。

在转移的时候挨个子树向上合并,那么设tmp[2][0/1][j]为已经合并了的子树中,(存在/不存在)一个重儿子,子树中的最长轻链长度为j的方案数,那么挨个滚动tmp的转移就非常简单了。

紧接着就给dp数组赋上这个tmp就可以了。

第一题是树上染色。

这道题我一直在想怎么从点来转移,我同桌也是,但是发现算点的贡献太难了,我一开始想要存一个最优子结构情况数组,存一下每种情况下的黑点和白点深度和,但是这样转移代价太大。

昨天刚刚看到了一个关于dp的本质解释的blog,上面说dp其实是从小的决策来影响转台,从而使得阶段进步,而不需要管转移来的状态的决策是什么。这样看来,我这种方式其实是记录了子结构的决策,根本就不是dp。

反思一下接着来正解。

计算每条边的贡献,我们发现这个东西只需要每个子树的状态就可以解决。

设dp[x][i]为以x为子树,含有i个黑点的最大贡献。

这条从x链接到儿子c的边的贡献就是

i*(K-i)+(sz[c]-i)*(n-sz[c]-(K-i))

背包转移即可,子树合并就是n2了。

第二题可怜与超市。

正解是秒掉的。但是子树合并和一些复杂度的东西导致我一直在T70,最后从n3换成了子树合并才A掉,还好这道题卡掉我了,因为我熟练剖分和树上染色全都是n3的暴力做法,我在认真的开始想这道题的归并子树做法,仍然是设一个tmp数组在表示归并代价。

最后赋值转移就可以了。

第三题SAO

发现不是无向树,是一个有向的,一开始觉得自己想的差不多了,但其实还差一个组合数乘法转移,总之还是差一点的。

设一下dp[x][i]是以x为子树中,x是第i个“发生”的关卡,还是子树转移,乘一下插板组合数即可。

证明一下子树归并优化的复杂度。

发现每一个后面的子树枚举的大小都乘了前面已经合并的子树大小,所以这个转移是以x为LCA的节点对个数((x,y)和(y,x)是一对)

那么以每个点都为LCA的个数累加起来就是整个树的点对数,也就是n2。

这一块暂时算弄完了。

//熟练
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int maxn=3005,mod=1e9+7;
ll p,n,x,y,root,ans,d[maxn],dp[maxn][maxn],tmp[2][2][maxn];
ll qw(ll a,ll b)
{
    int ans=1;
    for(;b;b>>=1,a=a*a%mod) if(b&1) ans=ans*a%mod;
    return ans;
}
struct tree{
    ll s,ch[maxn],sz,mx;
}zt[maxn];
void dfs(int x)
{
    zt[x].sz=1;
    for(int i=1;i<=zt[x].s;i++)
    {
        dfs(zt[x].ch[i]);
        zt[x].sz+=zt[zt[x].ch[i]].sz;
        zt[x].mx=max(zt[zt[x].ch[i]].mx+1,zt[x].mx);
    }
    if(zt[x].s==0)
    {
        dp[x][0]=1;
        return ;
    }
    memset(tmp,0,sizeof(tmp));
    for(int i=0;i<=zt[zt[x].ch[1]].mx+1;i++) 
    {
        tmp[1][1][i]=dp[zt[x].ch[1]][i];
        tmp[1][0][i+1]=dp[zt[x].ch[1]][i];
    }
    for(int s=2;s<=zt[x].s;s++)
    {
        int c=zt[x].ch[s],pre=s&1^1,now=s&1;
        for(int j=0;j<=zt[x].mx+1;j++) tmp[now][0][j]=tmp[now][1][j]=0;
        for(int j=0;j<=zt[x].mx+1;j++)
            for(int k=0;k<=zt[c].mx+1;k++)
            {
                (tmp[now][0][max(j,k+1)]+=tmp[pre][0][j]*dp[c][k])%=mod;
                (tmp[now][1][max(j,k)]+=tmp[pre][0][j]*dp[c][k])%=mod;
                (tmp[now][1][max(j,k+1)]+=tmp[pre][1][j]*dp[c][k])%=mod;
            }
    }
    int now=zt[x].s&1;
    for(int i=0;i<=zt[x].mx+1;i++) dp[x][i]=tmp[now][1][i];
}
int main()
{
    scanf("%lld",&n);
    p=1;
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&x);
        if(x!=0) p=p*x%mod;
        for(int j=1;j<=x;j++)
        {
            scanf("%lld",&y);
            zt[i].s++;zt[i].ch[j]=y;
            d[y]=1;
        }
    }
    for(int i=1;i<=n;i++)
        if(d[i]==0)
        {
            root=i;
            break;
        }
    dfs(root);
    for(int s=1;s<=zt[root].mx+1;s++)
    {
    //    cout<<s<<" "<<root<<" "<<dp[root][s]<<endl;
        ans=(ans+dp[root][s]*s%mod)%mod;
    }
//    cout<<ans<<" "<<p<<endl;
    ans=ans*qw(p,mod-2)%mod;
    printf("%lld
",(ans%mod+mod)%mod);
    return 0;
}
//可怜与超市
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<cstdlib>
using namespace std;
typedef long long ll;
const int maxn=5005;
int n,b,x,ans,s[maxn],d[maxn],sz[maxn],c[maxn],dp[maxn][maxn][2],tmp[2][maxn][2];
vector<int> son[maxn];
char xB[(1<<15)+10],*xS=xB,*xT=xB;
#define gtc (xS==xT&&(xT=(xS=xB)+fread(xB,1,1<<15,stdin),xS==xT)?0:*xS++)
inline void read(int &x){
  register char ch=gtc;
  for(x=0;ch<'0'||ch>'9';ch=gtc);
  for(;ch>='0'&&ch<='9';x=(x<<1)+(x<<3)+ch-'0',ch=gtc);
}
void dfs(int x)
{
    sz[x]=1;
    for(int i=0;i<son[x].size();i++) 
    {
        dfs(son[x][i]);
        sz[x]+=sz[son[x][i]];
    }
    dp[x][0][0]=0;dp[x][1][0]=c[x];dp[x][1][1]=c[x]-d[x];
    if(sz[x]==1) return ;
    memset(tmp,0x3f,sizeof(tmp));
    int mx=1;
    for(int i=0;i<son[x].size();i++)
    {
        int t=son[x][i],pre=i&1^1,now=i&1;
        memset(tmp[pre],0x3f,sizeof(tmp[pre]));
        tmp[now][0][0]=0;tmp[now][1][1]=c[x]-d[x];
        for(int j=mx;j>=0;j--)
            for(int k=0;k<=sz[t];k++)
            {
                if(j+k>=1) tmp[pre][j+k][1]=min(tmp[pre][j+k][1],tmp[now][j][1]+min(dp[t][k][1],dp[t][k][0]));
                tmp[pre][j+k][0]=min(tmp[pre][j+k][0],tmp[now][j][0]+dp[t][k][0]);
            }
        mx+=sz[t];
    }
    int now=son[x].size()&1;
    for(int i=sz[x];i>=1;i--) tmp[now][i][0]=min(tmp[now][i][0],tmp[now][i-1][0]+c[x]);
    for(int i=sz[x];i>=0;i--) 
    {
        dp[x][i][0]=tmp[now][i][0];
        dp[x][i][1]=tmp[now][i][1];
    }
}
int main()
{
    read(n);read(b);
    read(c[1]);read(d[1]);
    for(int i=2;i<=n;i++)
    {
        read(c[i]);read(d[i]);read(x);
        son[x].push_back(i);
    }
    dfs(1);
    for(int i=0;i<=n;i++)
        if(dp[1][i][0]<=b||dp[1][i][1]<=b)
            ans=i;
    printf("%d
",ans);
    return 0;
}
可怜与超市
//SAO
#include<iostream>
#include<cstdio>
#include<vector>
#include<cstring>
using namespace std;
typedef long long ll;
const int maxn=1005,mod=1e9+7;
int T,n,x,y,tot,first[maxn],sz[maxn],dp[maxn][maxn],tmp[2][maxn],fac[maxn],inv[maxn];
char c;
vector<pair<int,int> > ch[maxn];
struct road{
    int u,t,w,nxt;    //1前/-1后;
}eage[maxn<<1];
void add(int x,int y,int z)
{
    eage[++tot].u=x;
    eage[tot].t=y;
    eage[tot].w=z;
    eage[tot].nxt=first[x];
    first[x]=tot;
}
void clear()
{
    tot=0;
    memset(first,0,sizeof(first));
    memset(sz,0,sizeof(sz));
    memset(dp,0,sizeof(dp));
    for(int i=1;i<=n;i++) ch[i].clear();
}
int qw(int a,int b)
{
    int ans=1;
    for(;b;b>>=1,a=1LL*a*a%mod) if(b&1) ans=1LL*ans*a%mod;
    return ans;
}
void Get_facinv()
{
    fac[0]=inv[0]=1;
    for(int i=1;i<=1000;i++) fac[i]=1LL*fac[i-1]*i%mod;
    inv[1000]=qw(fac[1000],mod-2);
    for(int i=999;i>=1;i--) inv[i]=1LL*inv[i+1]*(i+1)%mod;
}
int Get_C(int n,int m)
{
    if(n<m) return 0;
    return 1LL*fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void dfs(int x,int fa)
{
    sz[x]=1;
    for(int i=first[x];i;i=eage[i].nxt)
        if(eage[i].t!=fa)
        {
            dfs(eage[i].t,x);
            sz[x]+=sz[eage[i].t];
            ch[x].push_back(make_pair(eage[i].t,eage[i].w));
        }
    if(ch[x].size()==0)
    {
        dp[x][1]=1;
//        cout<<x<<":"<<endl;
//        puts("1");
        return ;
    }
    memset(tmp,0,sizeof(tmp));
    int mx=1;tmp[1][1]=1;
    for(int i=0;i<ch[x].size();i++)
    {
        int c=ch[x][i].first,p=ch[x][i].second,now=i&1,pre=i&1^1;
        memset(tmp[now],0,sizeof(tmp[now]));
        for(int j=1;j<=mx;j++)
        {
            for(int k=1;k<=sz[c];k++)
            {
                if(p>0) (tmp[now][j+k]+=1LL*tmp[pre][j]*dp[c][k]%mod*Get_C(j+k-1,k)%mod*Get_C(mx+sz[c]-k-j,sz[c]-k)%mod)%=mod;
                if(p<0) (tmp[now][j+k-1]+=1LL*tmp[pre][j]*(dp[c][sz[c]]-dp[c][k-1])%mod*Get_C(j+k-2,k-1)%mod*Get_C(mx+sz[c]-k-j+1,sz[c]-k+1)%mod)%=mod;
            }
        }
        mx+=sz[c];
    }
    int now=(ch[x].size()-1)&1;
    for(int i=1;i<=sz[x];i++) dp[x][i]=(dp[x][i-1]+tmp[now][i])%mod;
/*    cout<<x<<":"<<endl;
    for(int i=1;i<=sz[x];i++) cout<<dp[x][i]-dp[x][i-1]<<" ";
    cout<<endl;*/
}
int main()
{
    Get_facinv();
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        for(int i=1;i<n;i++)
        {
            scanf("%d %c %d",&x,&c,&y);
            x++;y++;
            if(c=='<') add(x,y,1),add(y,x,-1);
            if(c=='>') add(x,y,-1),add(y,x,1);
        }
        dfs(1,0);
        printf("%d
",(dp[1][sz[1]]%mod+mod)%mod);
        clear();
    }
    return 0;
}
SAO
原文地址:https://www.cnblogs.com/Lrefrain/p/11199023.html