NOI2020命运(线段树合并优化树形DP)

(考场时只想到暴力容斥,24分滚了)

题目:https://loj.ac/problem/3340#submit_code

DP方程怎么来就不写了,本文重点分析如何用线段树合并及正确性

DP方程为:令dp[u][h]表示u的子树中上端点已经处理好了,下段点在子树中的,上段点最深为h的方案数,特别的,如果h==0,则代表已经处理好了

每当遇到一个儿子v时都更新一遍dp数组

dp’[u][h] = dp[v][h] * sum[u][h-1] + dp[u][h] * (sum[v][h] + sum[v][dep[u]); 

考虑线段树合并(以下子树)

我们可以便遍历边更新各式的值

我们开n颗线段树,下标 i 维护dp[u][i],维护区间和,维护区间乘积

不妨令S1 = sum[u][h-1],S2 = (sum[v][h] + sum[v][dep[u][)

在线段树合并的同时更新S1,S2

我们只要讨论4种情况

(I) (!U && !V)很显然这种情况可以直接返回

(II)  (!U || !V)

即合并中有其中一颗树没有子树了,反映在DP方程其实就是这一整块都没值了,都不合法

这就意味着U或者V(没有值的那个)的sum在这段区间不会再变了,变成常量,于是就可以愉快地区间乘法了

(III)(l == r)

这就意味遍历到叶子节点,这是最简单的,直接套公式就可以了

(IV)

otherwise

即既有左子树又有右子树

那我们只需要遍历左子树完遍历右子树即可

为什么是对的呢

因为我们处理左子树的答案,S1,S2已经被左子树的前缀和更新了,在处理右子树的答案时,同时考虑了左边的贡献

类似于CDQ分治的思想

然后就做完了(本蒟觉得还是一道很妙的题)

代码如下

/*命运*/ 
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
#define mod 998244353
const int maxn = 5e5 + 10;
int Add(int x,int y){
    x += y;
    return (x >= mod)?x - mod:x;
}
int rt[maxn],h[maxn];
struct SegmentTree{
    int lc,rc;
    ll lzy;
    ll sum;
    #define lc(p)    t[p].lc
    #define rc(p)    t[p].rc
    #define sum(p)    t[p].sum
    #define lzy(p)    t[p].lzy
}t[maxn<<5];
int cnt = 0;
void pushup(int p){
    sum(p) = (sum(lc(p)) + sum(rc(p))) % mod;
}
void pushdown(int p){
    if(lzy(p) == 1)        return;
    lzy(lc(p)) = lzy(lc(p)) * lzy(p) % mod;
    sum(lc(p)) = lzy(p) * sum(lc(p)) % mod;
    lzy(rc(p)) = lzy(rc(p)) * lzy(p) % mod;
    sum(rc(p)) = lzy(p) * sum(rc(p)) % mod;
    lzy(p) = 1;
}
void Ins(int &p,int l,int r,int pos,int v){
    if(!p)    p = ++cnt;
    t[p].lzy = t[p].sum = v;
    if(l == r){
        return;
    }
    int mid = (l + r) >> 1;
    if(pos <= mid)        Ins(lc(p),l,mid,pos,v);
    else    Ins(rc(p),mid+1,r,pos,v);
    //pushup(p);
}
ll query(int p,int l,int r,int a,int b){
    if(a <= l && b >= r){
        return sum(p);
    }
    pushdown(p);
    ll ans = 0;
    int mid = (l + r) >> 1;
    if(a <= mid)    ans = Add(ans,query(lc(p),l,mid,a,b));
    if(b > mid)        ans = Add(ans,query(rc(p),mid+1,r,a,b));
    return ans;
}
int merge(int u,int v,int l,int r,ll &S1,ll &S2){
    if(!u && !v)    return 0;
    if(!u || !v){
        if(!u){
            S2 = Add(S2,sum(v));
            sum(v) = sum(v) * S1 % mod;
            lzy(v) = lzy(v) * S1 % mod;
            return v;
        }
        S1 = Add(S1,sum(u));
        sum(u) = sum(u) * S2 % mod,lzy(u) = lzy(u) * S2 % mod;
        return u;
    }
    if(l == r){
        S2 = Add(S2,sum(v));
        ll idu = sum(u);
        sum(u) = Add(S1 * sum(v) % mod,S2 * sum(u) % mod);
        S1 = Add(S1,idu);
        return u;
    }
    pushdown(u),pushdown(v);
    int mid = (l + r) >> 1;
    lc(u) = merge(lc(u),lc(v),l,mid,S1,S2);
    rc(u) = merge(rc(u),rc(v),mid+1,r,S1,S2);
    pushup(u);
    return u;
}
int read(){
    char c = getchar();
    int x = 0;
    while(c < '0' || c > '9')    c = getchar();
    while(c >= '0' && c <= '9')        x = x * 10 + c - 48,c = getchar();
    return x;
}
struct Edge{
    int nxt,point;
}edge[maxn*2];
int tot = 0;
int dep[maxn];
int n;
int head[maxn];
void add(int x,int y){
    edge[++tot].nxt = head[x];
    edge[tot].point = y;
    head[x] = tot;
}
void Dfs(int x,int fa){
    dep[x] = dep[fa] + 1;
    for(int i = head[x]; i ; i = edge[i].nxt){
        int y = edge[i].point;
        if(y == fa)        continue;
        Dfs(y,x);
    }
}
void TreeDP(int u,int fa){
    Ins(rt[u],0,n,h[u],1);
    for(int i = head[u]; i ; i = edge[i].nxt){
        int v = edge[i].point;
        if(v == fa)        continue;
        TreeDP(v,u);
        ll S1 = 0,S2 = query(rt[v],0,n,0,dep[u]);
        rt[u] = merge(rt[u],rt[v],0,n,S1,S2);
    }
}
int main()
{
    freopen("destiny.in","r",stdin) ;
    freopen("destiny.out","w",stdout);
    n = read();
    for(int i = 1; i < n; ++i){
        int x = read(),y = read();
        add(x,y);    add(y,x);
    }
    Dfs(1,0);
    int m = read();
    for(int i = 1; i <= m; ++i){
        int u = read(),v = read();
        if(!h[v])    h[v] = dep[u];
        else    h[v] = max(h[v],dep[u]);
    }
    TreeDP(1,0);
    printf("%lld
",query(rt[1],0,n,0,0));
    return 0;
}
原文地址:https://www.cnblogs.com/y-dove/p/13538821.html