LOJ#2546. 「JSOI2018」潜入行动 树形DP

现在看来这道题就简单了.    

首先要知道,树形 DP 的复杂度是 $O(n^2)$ 的(通过严格控制子树大小,均摊下来一个状态只会贡献 n 次).   

然后这道题要求选的个数最多为 $k$,所以复杂度就是 $O(nk)$ 的.      

设 4 个状态:$f[x][y][0/1],g[x][y][0/1]$ 分别代表 $x$ 点是否被控制/ $x$ 点是否选择了一个点.        

树形 DP 的时候要注意清空 $tmp$ 数组以免影响后面的过程.   

code:   

#include <cstdio>    
#include <cstring>
#include <algorithm>      
#define N 100009    
#define ll long long    
#define mod 1000000007
#define setIO(s) freopen(s".in","r",stdin)  
using namespace std;          
int edges,n,MAX;     
int size[N],hd[N],to[N<<1],nex[N<<1];     
int f[N][103][2],g[N][103][2],tp1[104][2],tp2[104][2];  
void add(int u,int v) {    
    nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;  
}        
int ADD(int x,int y) { 
    return (ll)(x+y)%mod;  
}     
int MUL(int x,int y) {  
    return (ll)x*y%mod;   
}
void dfs(int x,int ff) {            
    size[x]=1;   
    f[x][0][0]=f[x][1][1]=1;     
    for(int i=hd[x];i;i=nex[i]) {  
        int v=to[i];  
        if(v==ff) continue;   
        dfs(v,x);           
        for(int j=0;j<=min(size[x]+size[v],MAX);++j) { 
            tp1[j][0]=tp1[j][1]=0;  
            tp2[j][0]=tp2[j][1]=0;    
        }           
        for(int j=0;j<=min(MAX,size[x]);++j)     
            for(int k=0;k<=min(MAX,size[v]);++k) {     
                if(j+k>MAX) break;             
                (tp1[j+k][0]+=MUL(f[x][j][0],g[v][k][0]))%=mod;      
                (tp1[j+k][1]+=MUL(f[x][j][1],ADD(g[v][k][0],f[v][k][0])))%=mod;     
                (tp2[j+k][0]+=MUL(g[x][j][0],ADD(g[v][k][0],g[v][k][1])))%=mod;         
                (tp2[j+k][0]+=MUL(f[x][j][0],g[v][k][1]))%=mod;     
                (tp2[j+k][1]+=MUL(g[x][j][1],ADD(ADD(g[v][k][0],g[v][k][1]),ADD(f[v][k][0],f[v][k][1]))))%=mod;
                (tp2[j+k][1]+=MUL(f[x][j][1],ADD(g[v][k][1],f[v][k][1])))%=mod;    
            }   
        size[x]+=size[v];       
        for(int j=0;j<=min(MAX,size[x]);++j) { 
            f[x][j][0]=tp1[j][0];    
            f[x][j][1]=tp1[j][1];  
            g[x][j][0]=tp2[j][0];  
            g[x][j][1]=tp2[j][1];   
        }
    }     
}       
int main() {  
    // setIO("input");       
    int x,y,z;  
    scanf("%d%d",&n,&MAX);           
    for(int i=1;i<n;++i) { 
        scanf("%d%d",&x,&y);  
        add(x,y),add(y,x);   
    }   
    dfs(1,0);  
    printf("%d
",ADD(g[1][MAX][0],g[1][MAX][1]));   
    return 0;  
}

  

原文地址:https://www.cnblogs.com/guangheli/p/13292028.html