[JSOI 2018] 潜入行动

[题目链接]

         https://www.lydsy.com/JudgeOnline/problem.php?id=5314 

[算法]

        考虑dp , 用f[i][j][0 / 1][0 / 1]表示以i为根的子树中选了j个 , 是否选i , i是否被覆盖的方案数

        树形背包进行合并 , 转移即可

        时间复杂度 : O(NK)

[代码]

        

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
const int MAXN = 1e5 + 10;
const int MAXK = 110;
const int P = 1e9 + 7;

struct edge
{
        int to , nxt;
} e[MAXN << 1];

int n , K , tot;
int head[MAXN] , dp[MAXN][MAXK][2][2] , tmp[MAXK][2][2] , size[MAXN];

#define rint register int

template <typename T> inline void chkmax(T &x,T y) { x = max(x,y); }
template <typename T> inline void chkmin(T &x,T y) { x = min(x,y); }
template <typename T> inline void read(T &x)
{
    T f = 1; x = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}
inline void addedge(int x , int y)
{
        ++tot;
        e[tot] = (edge){y , head[x]};
        head[x] = tot;
}
inline void update(int &x , int y)
{
        x += y;
        while (x >= P) x -= P;
}
inline void dfs(int u , int par)
{
        size[u] = 1;
        dp[u][0][0][0] = 1;
        dp[u][1][1][0] = 1;
        for (rint i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to;
                if (v == par) continue;
                dfs(v , u);
                for (rint j = min(K , size[u] + size[v]); j >= 0; --j)
                {
                        tmp[j][1][1] = 0;
                        tmp[j][0][0] = 0;
                        tmp[j][1][0] = 0;
                        tmp[j][0][1] = 0;   
                }
                for (rint j = 0; j <= size[u] && j <= K; ++j)
                {
                        for (rint k = 0; k <= size[v] && j + k <= K; ++k)
                        {
                                if (dp[u][j][1][1])
                                {
                                        update(tmp[j + k][1][1] , 1LL * dp[u][j][1][1] * dp[v][k][1][1] % P);
                                        update(tmp[j + k][1][1] , 1LL * dp[u][j][1][1] * dp[v][k][0][1] % P);
                                        update(tmp[j + k][1][1] , 1LL * dp[u][j][1][1] * dp[v][k][1][0] % P);
                                        update(tmp[j + k][1][1] , 1LL * dp[u][j][1][1] * dp[v][k][0][0] % P);
                                } 
                                if (dp[u][j][1][0])
                                {
                                        update(tmp[j + k][1][0] , 1LL * dp[u][j][1][0] * dp[v][k][0][0] % P);
                                        update(tmp[j + k][1][0] , 1LL * dp[u][j][1][0] * dp[v][k][0][1] % P);
                                        update(tmp[j + k][1][1] , 1LL * dp[u][j][1][0] * dp[v][k][1][0] % P);
                                        update(tmp[j + k][1][1] , 1LL * dp[u][j][1][0] * dp[v][k][1][1] % P);
                                }
                                if (dp[u][j][0][1])
                                {
                                        update(tmp[j + k][0][1] , 1LL * dp[u][j][0][1] * dp[v][k][0][1] % P);
                                        update(tmp[j + k][0][1] , 1LL * dp[u][j][0][1] * dp[v][k][1][1] % P);
                                }
                                if (dp[u][j][0][0])
                                {
                                        update(tmp[j + k][0][0] , 1LL * dp[u][j][0][0] * dp[v][k][0][1] % P);
                                        update(tmp[j + k][0][1] , 1LL * dp[u][j][0][0] * dp[v][k][1][1] % P);
                                }
                        }
                }        
                size[u] += size[v];
                for (int j = min(K , size[u]); j >= 0; --j)
                {
                        dp[u][j][0][0] = tmp[j][0][0];
                        dp[u][j][0][1] = tmp[j][0][1];
                        dp[u][j][1][0] = tmp[j][1][0];
                        dp[u][j][1][1] = tmp[j][1][1];
                }
        }        
}

int main()
{
        
        read(n); read(K);
        for (rint i = 1; i < n; ++i)
        {
                int x , y;
                read(x); read(y);
                addedge(x , y);
                addedge(y , x);
        }
        dfs(1 , 0);
        printf("%d
" , (dp[1][K][0][1] + dp[1][K][1][1]) % P);
        
        return 0;
    
}
原文地址:https://www.cnblogs.com/evenbao/p/10659983.html