树形dp总结

题单来源
VJ题单

树形dp模型:

  • 求以某一个节点为根、满足一定条件下的最大结果。这类题多自上向下转移:用已经算好的父节点 u 的信息去更新子节点 v 的信息。
  • H题中的1-k问题,可以转换成第k大的模型去解决。
  • 如J题涉及两点间路径选择,一般转移很多。
  • 如F题,树形dp与背包问题相结合。

A_CF686D Kay and Snowflake

// n: number of vertices, q: number of queries (both read in main).
int n, q;
// head[u]: index of the most recently added edge leaving u, -1 when u has none.
// cnt: next free edge slot.
int head[N << 1], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same adjacency list (-1 ends it).
int to[N << 1], nxt[N << 1];
// res[u]: answer stored by dfs for the subtree of u (printed per query);
// siz[u]: subtree size; son[u]: child with the largest subtree (0 if none);
// fa[u]: parent of u as read from the input.
int res[N], siz[N], son[N], fa[N];

// Prepends the directed edge u -> v to u's adjacency list.
void add(int u, int v){
    const int id = cnt++;
    to[id] = v;
    nxt[id] = head[u];
    head[u] = id;
}

// Computes siz[], son[] and res[] for the subtree rooted at u (pre is the parent).
// When the heaviest child holds more than half of the subtree, the answer is
// found by climbing from that child's answer towards u; otherwise it is u itself.
void dfs(int u, int pre){
    siz[u] = 1;
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child != pre){
            dfs(child, u);
            siz[u] += siz[child];
            if(siz[son[u]] < siz[child]) son[u] = child;
        }
    }

    const int heavy = son[u];
    if(2 * siz[heavy] <= siz[u]){
        res[u] = u;
        return;
    }

    // Climb until the part of u's subtree outside the candidate is at most half.
    int cand = res[heavy];
    while(2 * (siz[u] - siz[cand]) > siz[u]) cand = fa[cand];
    res[u] = cand;
}

int main()
{
    scanf("%d%d",&n,&q);
    cnt = 0;
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 2; i <= n; ++ i){
        int x; scanf("%d",&x);
        fa[i] = x;
        add(x, i);
    }
    dfs(1, 0);
    while(q --){
        int x; scanf("%d",&x);
        printf("%d
",res[x]);
    }
    return 0;
}

B_CF842C Ilya And The Tree

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// a[v]: value of vertex v; dp[v]: gcd(dp[parent], a[v]), seeded with dp[1] = a[1];
// res[v]: answer printed for v.
int a[N], dp[N], res[N];
// sol[v]: sorted, de-duplicated candidate values built for v from its parent's list in dfs.
vector<int> sol[N];

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

void dfs(int u, int pre){
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        int tt = sol[u].size();
        res[v] = dp[u];
        for(int i = 0; i < tt; ++ i){
            sol[v].push_back(gcd(sol[u][i], a[v]));
            res[v] = max(res[v], sol[v][i]);
        }
        sol[v].push_back(dp[u]);
        sort(sol[v].begin(),sol[v].end());
        sol[v].erase(unique(sol[v].begin(),sol[v].end()), sol[v].end()); 
        dp[v] = gcd(dp[u], a[v]);
        dfs(v, u);
    }
}

// Reads the vertex values and the tree, seeds the root state, runs the DFS
// from vertex 1 and prints res[1..n] separated by spaces.
int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 0; i <= n; ++ i){
        head[i] = -1;
    }
    for(int i = 1; i <= n; ++ i) scanf("%d",&a[i]);
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    dp[1] = a[1];
    res[1] = a[1];
    sol[1].push_back(0);
    dfs(1, 0);
    for(int i = 1; i <= n; ++ i){
        // The last value ends the line; the literal was previously split by a raw line break.
        if(i == n) printf("%d\n",res[i]);
        else printf("%d ",res[i]);
    }
    return 0;
}

C_CF337D Book of Evil

// n: number of vertices, m: number of marked vertices, d: distance bound (all read in main).
int n, m, d;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// down[u]: largest distance from u to a marked vertex inside its subtree, -1 if there is none.
// up[u]: largest distance from u to a marked vertex reached through its parent, -1 if there is none.
int down[N], up[N];
// Number of vertices whose max(up, down) does not exceed d.
int res = 0;
// vis[x] == 1 for the m vertices listed in the input.
int vis[N];

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

// Bottom-up pass: down[u] becomes the largest distance from u to a marked
// vertex inside its subtree, or stays -1 when the subtree has no marked vertex.
void dfs1(int u, int pre){
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child == pre) continue;
        dfs1(child, u);
        if(down[child] == -1) continue;   // nothing marked below this child
        if(down[u] < down[child] + 1) down[u] = down[child] + 1;
    }
    // A marked vertex is at distance 0 from itself.
    if(vis[u] && down[u] < 0) down[u] = 0;
}

// Orders vertices by decreasing down[] value.
bool cmp(int x, int y){
    return down[y] < down[x];
}

void dfs(int u, int pre){
    vector<int> sol; sol.clear();
    if(max(up[u], down[u]) <= d) res ++;
    if(vis[u]) up[u] = max(up[u], 0);
    
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        sol.push_back(v);
    }
    sort(sol.begin(), sol.end(), cmp);
    
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        if(v == sol[0]){
            int maxx = -1;
            if(up[u] != -1) maxx = max(maxx, up[u]);
            if(sol.size() > 1 && down[sol[1]] != -1) maxx = max(maxx, down[sol[1]] + 1);
            if(maxx != -1) up[v] = maxx + 1;
        }
        else{
            int maxx = -1;
            if(up[u] != -1) maxx = max(maxx, up[u]);
            if(down[sol[0]] != -1) maxx = max(maxx, down[sol[0]] + 1);
            if(maxx != -1) up[v] = maxx + 1;           
        }
        dfs(v, u);
    }
}

int main()
{
    scanf("%d%d%d",&n,&m,&d);
    cnt = 0;
    for(int i = 0; i <= n; ++ i){
        head[i] = -1; up[i] = down[i] = -1; 
    }
    for(int i = 1; i <= m; ++ i){
        int x; scanf("%d",&x); vis[x] = 1;
    }
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    if(m == 0){
        printf("0
");
        return 0;
    }
    dfs1(1, 0);
    dfs(1, 0);
    
    printf("%d
",res);
    return 0;
}

D_CF813C The Tag Game

// n: number of vertices; B: start vertex read from the input.
int n, B;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// dep[u]: number of edges from u down to its deepest descendant;
// len[u]: number of edges from vertex 1 to u; fa[u]: parent of u (0 for vertex 1).
int dep[N], len[N], fa[N];

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

// Records the parent, the depth from the root (len) and the height of the
// subtree (dep) for every vertex below u.
void dfs(int u, int pre){
    fa[u] = pre;
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child != pre){
            len[child] = len[u] + 1;
            dfs(child, u);
            if(dep[u] < dep[child] + 1) dep[u] = dep[child] + 1;
        }
    }
}

// Reads the tree and start vertex B, then walks from B towards vertex 1.
// Every ancestor rt with len[rt] * 2 > len[B] contributes the candidate
// (len[rt] + dep[rt]) * 2; the largest candidate is printed.
int main()
{
    scanf("%d%d",&n,&B);
    cnt = 0;
    for(int i = 0; i <= n + 10; ++ i) head[i] = -1;
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    dfs(1, 0);
    int rt = B, res = 0;   // the unused local `num` was dropped
    while(rt != 1){
        if(len[rt] * 2 <= len[B]) break;
        res = max(res, (len[rt] + dep[rt]) * 2);
        rt = fa[rt];
    }
    // The format string was split by a raw line break; use an escaped newline.
    printf("%d\n",res);
    return 0;
}

E_CF219D Choosing Capital for Treeland

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list;
// c[e]: +1 for a half-edge following the input direction, -1 for the opposite half-edge.
int to[N << 1], nxt[N << 1], c[N << 1];
// dp[v]: value propagated from the root by dfs (dp[child] = dp[parent] + c[e]);
// num[u][1] / num[u][0]: number of +1 / -1 half-edges followed downwards inside u's subtree.
int dp[N], num[N][2];
// Vertices whose dp value equals the minimum.
vector<int> sol;

// Inserts the edge u -> v with label w and the reverse half-edge with label -w.
void add(int u, int v, int w){
    // half-edge u -> v
    to[cnt] = v;
    c[cnt] = w;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    c[cnt] = -w;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

void dfs1(int u,int pre){
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i], w = c[i];
        if(v == pre) continue;
        if(w == 1) num[u][1] ++;
        else num[u][0] ++;
        dfs1(v, u);
        num[u][1] += num[v][1];
        num[u][0] += num[v][0];
    }
}

void dfs(int u, int pre){
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i], w = c[i];
        if(v == pre) continue;
        dp[v] = dp[u] + w;
        dfs(v, u);
    }
}


// Reads the directed tree, computes dp for every vertex (dp[1] is the number
// of -1 half-edges below vertex 1), then prints the minimum dp value followed
// by all vertices that attain it in increasing order.
int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y, 1);
    }
    dfs1(1, 0);
    dp[1] = num[1][0];
    dfs(1, 0);
    int minn = INF;
    for(int i = 1; i <= n; ++ i) minn = min(minn, dp[i]);
    for(int i = 1; i <= n; ++ i){
        if(dp[i] == minn) sol.push_back(i);
    }
    int tt = sol.size();
    // Format strings were previously split by raw line breaks; use escaped newlines.
    printf("%d\n",minn);
    for(int i = 0; i < tt; ++ i){
        if(i == tt - 1) printf("%d\n",sol[i]);
        else printf("%d ",sol[i]);
    }
    return 0;
}

F_CF212E IT Restaurants

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// siz[u]: size of the subtree rooted at u.
int siz[N];
// Collected first components of the answer pairs (a, n - 1 - a).
vector<int> res;
// vis[a]: value a has already been recorded in res.
bool vis[N];
// dp[u][j] == 1 when j is a sum of sizes of some of the components left after removing u.
int dp[N][N];

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

void dfs(int u, int pre){
    vector<int> sol;
    siz[u] = 1;
    dp[u][0] = 1;
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        dfs(v, u);
        siz[u] += siz[v];
        sol.push_back(siz[v]);
    }
    if(pre != 0) sol.push_back(n - siz[u]);
    int tt = sol.size();
    for(int i = 0; i < tt; ++ i){
        for(int j = n - 1; j >= 0; -- j){
            if(dp[u][j]) dp[u][j + sol[i]] = 1;
        }
    }
    for(int i = 1; i < n - 1; ++ i){
        if(dp[u][i] && !vis[i]){
            vis[i] = true;
            vis[n - i - 1] = true;
            res.push_back(i);
            if(n - i - 1 != i) res.push_back(n - i - 1);
        }
    }
}

// Reads the tree, collects all feasible splits via dfs, and prints their
// count followed by each pair (a, n - 1 - a) in increasing order of a.
int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    dfs(1, 0);

    sort(res.begin(), res.end());
    int tt = res.size();
    // Format strings were previously split by raw line breaks; use escaped newlines.
    printf("%d\n",tt);
    for(int i = 0; i < tt; ++ i){
        printf("%d %d\n",res[i], n - 1 - res[i]);
    }
    return 0;
}

G_CF161D Distance in Tree

// n: number of vertices; k: target distance.
int n, k;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// dp[u][i]: number of vertices in the subtree of u at distance exactly i from u (0 <= i <= k).
ll dp[N][520];
// Number of vertex pairs at distance exactly k found so far.
ll res = 0;

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

// Counts pairs at distance k whose path bends at u: each new child subtree is
// paired with everything merged into u so far, then folded into dp[u].
void dfs(int u, int pre){
    dp[u][0] = 1;
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child == pre) continue;
        dfs(child, u);
        // A vertex `below` edges under the child pairs with one k-1-below edges from u.
        for(int below = 0; below < k; ++below){
            res += dp[child][below] * dp[u][k - 1 - below];
        }
        // Merge the child: its distances grow by one when measured from u.
        for(int dist = 1; dist <= k; ++dist){
            dp[u][dist] += dp[child][dist - 1];
        }
    }
}

int main()
{
    scanf("%d%d",&n,&k);
    cnt = 0;
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    dfs(1, 0);
    printf("%lld
",res);
    return 0;
}

H_CF1153D Serval and Rooted Tree

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// val[u]: operation flag read from the input (non-zero: take the minimum of the children's dp,
// zero: take their sum); dp[u]: value computed by dfs (1 for a leaf); res: number of leaves.
int val[N], dp[N], res = 0;

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

void dfs(int u, int pre){
    int tp = INF, flag = 0;
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        flag = 1;
        dfs(v, u);
        if(val[u]) tp = min(tp, dp[v]);
        else dp[u] += dp[v];
    }
    if(val[u]) dp[u] = tp;
    if(!flag){
        dp[u] = 1, res ++;
    }
}


int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 1; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i <= n; ++ i) scanf("%d",&val[i]);
    for(int i = 2; i <= n; ++ i){
        int x; scanf("%d",&x);
        add(x, i);
    }
    dfs(1, 0);
    printf("%d
",res - dp[1] + 1);
    return 0;
}

I_CF14D Two Paths

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// up[u]: length of the path from u that leaves through its parent, as derived in dfs2;
// dep[u]: number of edges from u down to its deepest descendant;
// val1[u]: scratch value set by dfs2 for the child currently being processed;
// val2[u]: longest path lying entirely inside the subtree of u.
int up[N], dep[N], val1[N], val2[N];
// Best product val2[v] * val1[u] over all tree edges (u, v).
int res = 0;
// An arm hanging off a vertex: its length and the vertex it starts at.
struct node{
    int val;   // length of the arm
    int si;    // vertex identifying the arm
};
// Orders arms by decreasing length.
bool cmp(node a, node b){
    return b.val < a.val;
}
// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}
// Bottom-up pass: dep[u] is the height of u's subtree and val2[u] the longest
// path inside it (either inherited from a child or formed by u's two longest arms).
void dfs1(int u, int pre){
    int longest = 0, second = 0;
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child == pre) continue;
        dfs1(child, u);
        if(val2[u] < val2[child]) val2[u] = val2[child];
        const int arm = dep[child] + 1;
        if(arm > longest){
            second = longest;
            longest = arm;
        }
        else if(arm > second){
            second = arm;
        }
    }
    dep[u] = longest;
    if(val2[u] < longest + second) val2[u] = longest + second;
}
// Top-down pass. For every child v of u it computes val1[u], the longest path
// in the part of the tree that contains u once edge (u, v) is removed, and
// updates res with val2[v] * val1[u].
//
// u   - current vertex
// pre - parent of u (0 for the root); val1[pre] still holds the value computed
//       for the edge (pre, u) when this call starts.
//
// Fix: for a child v that is not the deepest one, the sibling arm dep1 already
// includes the edge from u to that sibling (dep1 = dep[sibling] + 1), so
// up[v] = 1 + max(up[u], dep1). The previous "1 + dep1" counted that edge twice
// and inflated up[v], which could produce a product larger than any valid answer.
void dfs2(int u, int pre){
    vector<node> sol;
    int dep1 = 0, dep2 = 0, tdep = 0;          // two longest child arms of u, tdep owns dep1
    int max1 = val1[pre], max2 = 0, tp = 0;    // two best "path not through u" values, tp owns max1
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        if(dep[v] + 1 > dep1){
            dep2 = dep1;
            dep1 = dep[v] + 1;
            tdep = v;
        }
        else if(dep[v] + 1 > dep2) dep2 = dep[v] + 1;
        sol.push_back((node){dep[v] + 1, v});
        if(val2[v] > max1){
            max2 = max1;
            max1 = val2[v];
            tp = v;
        }
        else if(val2[v] > max2)  max2 = val2[v];
    }
    // The arm leaving u through its parent competes with the child arms.
    sol.push_back((node){up[u], u});
    sort(sol.begin(),sol.end(),cmp);
    int tt = sol.size();
    for(int i = head[u]; i != -1; i = nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        // Longest path starting at v and leaving through u, excluding v's own arm.
        if(v == tdep)  up[v] = 1 + max(up[u], dep2);
        else up[v] = 1 + max(up[u], dep1);
        if(tt > 2){
            // Longest path through u built from the two best arms other than v's.
            int tval = 0;
            if(v == sol[0].si) tval = sol[1].val + sol[2].val;
            else if(v == sol[1].si) tval = sol[0].val + sol[2].val;
            else tval = sol[0].val + sol[1].val;
            val1[u] = tval;
        }
        else val1[u] = up[u];
        // Paths that avoid u: inside another child's subtree or beyond the parent.
        if(v == tp) val1[u] = max(val1[u], max2);
        else val1[u] = max(val1[u], max1);
        res = max(res, val2[v] * val1[u]);
        dfs2(v, u);
    }
}
int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    dfs1(1, 0);
    dfs2(1, 0);
    printf("%d
",res);
    return 0;
}

J_CF1156D 0-1-Tree

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list; c[e]: label (0 or 1) of edge e.
int to[N << 1], nxt[N << 1], c[N << 1];
// Two disjoint-set forests, one per edge label op:
// root[op][x]: parent pointer of x; num[op][r]: size of the set whose representative is r.
int root[2][N], num[2][N];
// Endpoints and label of the i-th input edge.
int tx[N], ty[N], tz[N];
// Accumulated answer.
ll res = 0;

// Inserts the undirected edge {u, v} with label w as two half-edges.
void add(int u, int v, int w){
    // half-edge u -> v
    to[cnt] = v;
    c[cnt] = w;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    c[cnt] = w;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

// Returns the representative of x in forest `op`, compressing the path on the way back.
int Find(int op, int x){
    if(root[op][x] == x) return x;
    const int top = Find(op, root[op][x]);
    root[op][x] = top;
    return top;
}

// Merges the sets of x and y in forest `op`; the set of x is attached under
// the representative of y and its size is moved there.
void Union(int op, int x, int y){
    const int from = Find(op, x);
    const int into = Find(op, y);
    if(from == into) return;
    root[op][from] = into;
    num[op][into] += num[op][from];
    num[op][from] = 0;
}

void solve(){
    for(int i = 1; i <= n; ++ i){
        if(root[1][i] == i){
            res += 1ll * num[1][i] * (num[1][i] - 1);
        }
        if(root[0][i] == i){
            res += 1ll * num[0][i] * (num[0][i] - 1);
        }
    }
    // cout<<res<<endl;
    for(int u = 1; u <= n; ++ u){
        int num0 = 0, num1 = 0;
        for(int i = head[u]; i != -1; i = nxt[i]){
            int v = to[i], w = c[i];
            if(w){
                num1 = 1;
            }
            else{
                num0 = 1;
            }
        }
        if(num1 && num0){
            num1 = num[1][Find(1, u)] - 1;
            num0 = num[0][Find(0, u)] - 1;
            res += 1ll * num1 * num0;
        }
    }
}

int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 0; i <= n; ++ i){
        head[i] = -1; root[0][i] = root[1][i] = i;
        num[0][i] = num[1][i] = 1;
    }
    
    for(int i = 1; i < n; ++ i){
        scanf("%d%d%d",&tx[i],&ty[i],&tz[i]);
        add(tx[i], ty[i], tz[i]);
        if(tz[i]){
            Union(1, tx[i], ty[i]);
        }
        else{
            Union(0, tx[i], ty[i]);
        }
    }
    solve();
    printf("%lld
",res);
    return 0;
}

K_CF1092F Tree with Maximum Cost

// n: number of vertices.
int n;
// Forward-star adjacency list: head[u] is the latest edge of u (-1 if none), cnt the next free slot.
int head[N], cnt = 0;
// to[e]: endpoint of edge e; nxt[e]: next edge in the same list.
int to[N << 1], nxt[N << 1];
// a[v]: value of vertex v; dep[v]: number of edges from vertex 1 to v;
// dp[v]: sum of a[i] * dist(v, i) over all i; val[v]: sum of a[] over the subtree of v.
ll a[N], dep[N], dp[N], val[N];
// all: sum of every a[i]; res: largest dp value seen.
ll all, res;

// Inserts the undirected edge {u, v} as two half-edges.
void add(int u, int v){
    // half-edge u -> v
    to[cnt] = v;
    nxt[cnt] = head[u];
    head[u] = cnt;
    ++cnt;
    // half-edge v -> u
    to[cnt] = u;
    nxt[cnt] = head[v];
    head[v] = cnt;
    ++cnt;
}

// Fills dep[] (depth from the root) on the way down and val[] (sum of a[]
// over each subtree) on the way back up.
void dfs(int u, int pre){
    val[u] = a[u];
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child != pre){
            dep[child] = dep[u] + 1;
            dfs(child, u);
            val[u] += val[child];
        }
    }
}

// Re-rooting pass: stepping from u to a child brings val[child] of the weight
// one edge closer and the remaining all - val[child] one edge farther.
void dfs1(int u, int pre){
    if(res < dp[u]) res = dp[u];
    for(int e = head[u]; e != -1; e = nxt[e]){
        const int child = to[e];
        if(child == pre) continue;
        dp[child] = dp[u] + all - 2 * val[child];
        dfs1(child, u);
    }
}

int main()
{
    scanf("%d",&n);
    cnt = 0;
    for(int i = 0; i <= n; ++ i) head[i] = -1;
    for(int i = 1; i <= n; ++ i){
        scanf("%d",&a[i]);
        all += a[i];
    }
    for(int i = 1; i < n; ++ i){
        int x, y; scanf("%d%d",&x,&y);
        add(x, y);
    }
    dfs(1, 0);
    for(int i = 1; i <= n; ++ i){
        dp[1] += a[i] * dep[i];
    }
    dfs1(1, 0);
    printf("%lld
",res);
    return 0;
}
原文地址:https://www.cnblogs.com/A-sc/p/13860761.html