P6773 [NOI2020]命运

整体DP

很明显计算答案需要用容斥计算,如果暴力容斥的话,就是枚举哪些路径不符合条件,在这些路径的并集中的边都不能取,其他边任意取,设当前取了$i$条路径,那么对答案的贡献是$(-1)^i2^{n-1-Union}$

但是可以发现这个路径是自下往上的,可以考虑树形DP,设$dp[i][j][k]$表示在i的子树内,已选$k$条路径的上面那个端点的深度的最小值为$j$的方案数,特别的如果在这个子树内不选任何一条路径的话,$j$为$maxde$

但其实在容斥的时候,我们并不关心这个$k$的具体取值,只要关心其奇偶性即可,那么可以去掉这一维,在设初始值的时候将$-1$乘到值里面,在之后乘起来的时候,$-1$会帮助数值自动变号

考虑一个儿子一个儿子更新$dp[x][i]$,分情况讨论

如果$ileqslant de[x]$

$dp[x][i]=sum_{j=i+1}^{maxde}2dp[u][j]dp[x][i]+sum_{j=i+1}^{de[x]}dp[u][j]dp[x][i]+sum_{j=i}^{maxde}dp[u][i]dp[x][j]$

第一部分表示$u$这棵子树内的最浅的那个不能选的路径是小于当前$x$的深度,那么$u->x$这条边是可以任意选择;第二部分就表示$u$中最浅的点比$i$深,那么当前最浅的那个节点依然是$i$;第三部分表示$u$中最浅的点比$i$浅,那么当前最浅的点需要更新

如果$i>de[x]$

$dp[x][i]=sum_{j=i+1}^{maxde}2dp[u][j]dp[x][i]+sum_{j=i}^{maxde}2dp[u][i]dp[x][j]$

这也是类似的

可以发现第二维不为$0$的取值是较少的,只有在较前的点才会变多,那么就用线段树合并维护这个DP(跟[PKUWC2018]Minimax类似)

但是这个DP的方程需要分类讨论,就很难直观的进行维护修改,考虑如何用一个式子来表示这个DP方程

可以注意到第一个式子里前面两个部分是可以衔接在一起的,只不过第一部分有一个$2$的系数,但下标都是大于$de[x]$的,第二条方程也是一样,都是下标大于$de[x]$的时候$dp[u]$需要$*2$
那么在线段树合并之前,就把这段区间乘$2$即可
那么方程就是$dp[x][i]=sum_{j=i+1}^{maxde}dp[x][i]dp[u][j]+sum_{j=i+1}^{maxde}dp[x][j]dp[u][i]+dp[x][i]dp[u][i]$
再考虑新增路径的情况,对于一个节点可能会有多条下端点是这个节点的路径,设有$size$条,取第$i$浅的路径的话,那么比它浅的路径都不能选,比它深的路径随便选,那么在DP中需要赋值为钦定选这条路径和比它深的路径随便选的容斥系数之和
$sum_{j=0}^{size-i}(-1)^{j+1}C_{size-i}^{j}$
$=(-1)sum_{j=0}^{size-i}(-1)^{j}C_{size-i}^{j}$
$=(-1)(1-1)^{size-i}$
那么只有$i=size$的时候这个值不为$0$,那么只要记录最深的的那条路径更新即可
注意可以一条路径都不选,那么在$maxde$处要设为$1$
 
#include <bits/stdc++.h>
#define mod 998244353
using namespace std;
const int N=5*1e5+100;
int n,m,w,de[N],maxde,last[N],root[N],cnt;
int tot,first[N],nxt[N*2],point[N*2];
struct node
{
    int ls,rs;
    long long dp,tag;
}sh[N*40];
inline void add(long long &a,long long b){a=(a+b);((a>mod)?a-=mod:a=a);}
inline void del(long long &a,long long b){a=(a-b+mod)%mod;}
inline void mul(long long &a,long long b){a=(a*b)%mod;}
inline bool cmp(int a,int b){return(de[a]<de[b]);}
inline int read()
{
    int f=1,x=0;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    return x*f;
}
inline void add_edge(int x,int y)
{
    tot++;
    nxt[tot]=first[x];
    first[x]=tot;
    point[tot]=y;
}
void dfs(int x,int fa)
{
    for (int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==fa) continue;
        de[u]=de[x]+1;
        dfs(u,x);
    }
}
inline void pushup(int x)
{
    sh[x].dp=(sh[sh[x].ls].dp+sh[sh[x].rs].dp)%mod;
}
inline void pushdown(int x)
{
    if (sh[x].tag==1) return;
    if (sh[x].ls) mul(sh[sh[x].ls].dp,sh[x].tag),mul(sh[sh[x].ls].tag,sh[x].tag);
    if (sh[x].rs) mul(sh[sh[x].rs].dp,sh[x].tag),mul(sh[sh[x].rs].tag,sh[x].tag);
    sh[x].tag=1;
}
int insert(int x,int l,int r,int wh,int v)
{
    if (!x) x=++cnt;
    sh[x].tag=1;
    if (l==r)
    {
        sh[x].dp=v;
        return x;
    }
    int mid=(l+r)>>1;
    if (wh<=mid) sh[x].ls=insert(sh[x].ls,l,mid,wh,v);
    else sh[x].rs=insert(sh[x].rs,mid+1,r,wh,v);
    pushup(x);
    return x;
}
void change(int x,int l,int r,int ll,int rr)
{
    if (ll<=l && rr>=r)
    {
        mul(sh[x].dp,2);mul(sh[x].tag,2);
        return;
    }
    int mid=(l+r)>>1;
    pushdown(x);
    if (ll<=mid) change(sh[x].ls,l,mid,ll,rr);
    if (rr>mid) change(sh[x].rs,mid+1,r,ll,rr);
    pushup(x);
}
int merge(int a,int b,int l,int r,long long sa,long long sb,long long pa,long long pb)
{
    if (!a)
    {
        mul(sh[b].dp,(sa-pa+mod)%mod);
        mul(sh[b].tag,(sa-pa+mod)%mod);
        return b;
    }
    if (!b)
    {
        mul(sh[a].dp,(sb-pb+mod)%mod);
        mul(sh[a].tag,(sb-pb+mod)%mod);
        return a;
    }
    if (l==r)
    {
        add(pa,sh[a].dp);add(pb,sh[b].dp);
        sh[a].dp=(sh[a].dp*(sb-pb+mod)%mod+sh[b].dp*(sa-pa+mod)%mod+sh[a].dp*sh[b].dp%mod)%mod;
        return a;
    }
    int mid=(l+r)>>1;long long ta=pa,tb=pb;
    pushdown(a);pushdown(b);
    add(ta,sh[sh[a].ls].dp);add(tb,sh[sh[b].ls].dp);
    sh[a].ls=merge(sh[a].ls,sh[b].ls,l,mid,sa,sb,pa,pb);
    sh[a].rs=merge(sh[a].rs,sh[b].rs,mid+1,r,sa,sb,ta,tb);
    pushup(a);
    return a;
}
void dfs1(int x,int fa)
{
    if (last[x]) root[x]=insert(root[x],1,maxde,last[x],mod-1);
    root[x]=insert(root[x],1,maxde,maxde,1);
    for (int i=first[x];i!=-1;i=nxt[i])
    {
        int u=point[i];
        if (u==fa) continue;
        dfs1(u,x);
        change(root[u],1,maxde,de[x]+1,maxde);
        root[x]=merge(root[x],root[u],1,maxde,sh[root[x]].dp,sh[root[u]].dp,0,0);
    }
}
int main()
{
    tot=-1;
    memset(first,-1,sizeof(first));
    memset(nxt,-1,sizeof(nxt));
    n=read();
    for (int i=1;i<n;i++)
    {
        int u=read(),v=read();
        add_edge(u,v);add_edge(v,u);
    }
    de[1]=1;
    dfs(1,1);
    for (int i=1;i<=n;i++) maxde=max(maxde,de[i]);
    maxde++;
    m=read();
    for (int i=1;i<=m;i++)
    {
        int u=read(),v=read();
        last[v]=max(last[v],de[u]);
    }
    dfs1(1,1);
    printf("%lld
",sh[root[1]].dp);
}
View Code
原文地址:https://www.cnblogs.com/huangchenyan/p/13533965.html