风暴之眼

题目背景

通过月岛,帝王蟹和天体探测仪,你成功拼合了三个天体科技,接下来你要做的,就是来到风暴之眼的中心,准备那个神秘实验的最后一步。

最终的真相近在咫尺,你能否成功通过这场考验呢?

题目描述

天体风暴中的气象瞬息万变。

风暴中的道路构成一棵 n 个结点的无根树,第 i个结点有初始权值 (w_i)(w_i)为 0 或 1)和类型 (t_i)

结点的类型分为两种:( exttt{AND})$ 型结点和$ exttt{OR}$ 型结点。

(对于 exttt{AND} 型结点,每一秒结束后它的权值将变为它与它所有邻居上一秒权值的 exttt{AND} 和)

(对于 exttt{OR} 型结点,每一秒结束后它的权值将变为它与它所有邻居上一秒权值的 exttt{OR} 和)

现在,已知从某一时刻起,所有结点的权值都不再发生任何变化,将此时点 i 的权值称为 (a_i)

现不知每个点的初始权值和类型,只知道最终每个点的权值 (a_i),求出有多少种可能的初始权值和类型的组合,答案对 998244353 取模。

输入格式

第一行,一个整数 n,表示树的结点数。

第二行,n 个整数 (a_1, a_2, ldots , a_n),表示每个结点最终的权值。

接下来 n-1 行,每行两个整数 x,y,描述无根树中的一条边。

输出格式

输出一行一个整数,表示可能的初始权值和类型的组合数量。

输入输出样例

输入 #1复制

2
0 0
1 2

输出 #1复制

6

说明/提示

【样例 1 解释】

有如下六种初始权值和类型的组合:

  1. (((w_1, t_1), (w_2, t_2)) = ((0, exttt{AND}), (0, exttt{AND})))
  2. (((w_1, t_1), (w_2, t_2)) = ((0, exttt{AND}), (0, exttt{OR})))
  3. (((w_1, t_1), (w_2, t_2)) = ((0, exttt{OR}), (0, exttt{AND})))
  4. (((w_1, t_1), (w_2, t_2)) = ((0, exttt{OR}), (0, exttt{OR})))
  5. (((w_1, t_1), (w_2, t_2)) = ((1, exttt{AND}), (0, exttt{AND})))
  6. (((w_1, t_1), (w_2, t_2)) = ((0, exttt{AND}), (1, exttt{AND})))

【数据范围】

本题采用捆绑测试。

对于 100% 的数据,$$2 le n le 2 imes {10}^5,1 le x, y le n,a_i in { 0, 1 }$$,保证输入构成一棵树。

子任务编号 (nleq) 特殊限制
1 10
2 20
3 1000
4 ({10}^5) y=x+1
5 ({10}^5) (a_i=0)
6 (2 imes {10}^5)

题目分析

首先是:OR节点变为1则停止变化,AND节点变为0则停止变化

然后是变化结果的条件:

1.and 点组成的连通块,当且仅当初始时连通块内部全是 1,与其相邻的一圈 or 点初值也全是 1的时候,最终会全变为 1;否则最终会全变为 0

2.or 点组成的连通块,当且仅当初始时连通块内部全是 0,与其相邻的一圈 and 点初值也全是 0的时候,最终会全变为 0;否则最终会全变为 1

考虑到树上的联通块条件,进行树形DP,令(f[i][j][k][t])表示(i)节点,类型(j),初始化值设为(k)(t)是否满足条件以上变化条件

注意:DP转移时需要考虑一些特殊情况,比如 (x)作为 (y) 的父亲满足了$ y$的限制

具体的转移分析见注释

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long ll;
struct Node
{
    int next,to;
}edge[400005];
int head[200005],num=0;
int n,a[200005],Mod=998244353;
/**
 * 第2维0表示and,1表示or
 * 第3维表示初始值
 * 第4维表示是否符合and/or联通块条件
 * **/
int f[200005][2][2][2];
void add(int u,int v){
    num++;
    edge[num].next=head[u];
    head[u]=num;
    edge[num].to=v;
}
void dfs(int x,int fa){
    if (a[x]==0) {
        f[x][0][0][1]=1;
        f[x][0][1][0]=1;
        f[x][1][0][1]=1;
    }else{
        f[x][1][1][1]=1;
        f[x][1][0][0]=1;
        f[x][0][1][1]=1;
    }
    for (int i=head[x];i;i=edge[i].next){
        int v=edge[i].to;
        if (v==fa) continue;
        dfs(v,x);
        int temp[2][2][2]={0};
        if (a[x]==0&&a[v]==0){
            //and连通块,x和邻居至少一个0
            for (int v1=0;v1<2;v1++)
                for (int j=0;j<2;j++)
                    for (int v2=0;v2<2;v2++)
                        for (int k=0;k<2;k++){
                            temp[0][v1][j|k]+=1ll*f[x][0][v1][j]*f[v][0][v2][k]%Mod;
                            temp[0][v1][j|k]%=Mod;
                        }
            //x是and节点,v是or节点,都是0(要考虑父节点对子节点条件约束)
            temp[0][0][1]+=1ll*f[x][0][0][1]*f[v][1][0][1]%Mod;
            temp[0][0][1]%=Mod;
            //or连通块,全部是0
            //x是or节点,v是and,都是0(要考虑父节点对子节点条件约束)
            temp[1][0][1]+=(1ll*f[x][1][0][1]*f[v][1][0][1]%Mod+1ll*f[x][1][0][1]*f[v][0][0][1]%Mod)%Mod;
            temp[1][0][1]%=Mod;
        }
         if (a[x]==0&&a[v]==1){
            //x是and节点,v是or节点
            for (int v1=0;v1<2;v1++)
                for (int j=0;j<2;j++)
                    for (int v2=0;v2<2;v2++)
                        for (int k=0;k<2;k++)
                        if (k||v1){
                            temp[0][v1][j|(!v2)]+=1ll*f[x][0][v1][j]*f[v][1][v2][k]%Mod;
                            temp[0][v1][j|(!v2)]%=Mod;
                        }
        }
        if (a[x]==1&&a[v]==0){
            //x是or节点,v是and节点
            for (int v1=0;v1<2;v1++)
                for (int j=0;j<2;j++)
                    for (int v2=0;v2<2;v2++)
                        for (int k=0;k<2;k++)
                        if (k||(!v1)){
                            temp[1][v1][j|v2]+=1ll*f[x][1][v1][j]*f[v][0][v2][k]%Mod;
                            temp[1][v1][j|v2]%=Mod;
                        }
        }
        if (a[x]==1&&a[v]==1){
            //or连通块,x和邻居至少一个1
            for (int v1=0;v1<2;v1++)
                for (int j=0;j<2;j++)
                    for (int v2=0;v2<2;v2++)
                        for (int k=0;k<2;k++){
                            temp[1][v1][j|k]+=1ll*f[x][1][v1][j]*f[v][1][v2][k]%Mod;
                            temp[1][v1][j|k]%=Mod;
                        }
            //x是or节点,v是and节点,都是1(要考虑父节点对子节点条件约束)
            temp[1][1][1]+=1ll*f[x][1][1][1]*f[v][0][1][1]%Mod;
            temp[1][1][1]%=Mod;
            //and连通块,全部是0
            //x是and节点,v是or,都是1(要考虑父节点对子节点条件约束)
            temp[0][1][1]+=(1ll*f[x][0][1][1]*f[v][0][1][1]%Mod+1ll*f[x][0][1][1]*f[v][1][1][1]%Mod)%Mod;
            temp[0][1][1]%=Mod;
        }
        memcpy(f[x],temp,sizeof(temp));
    }
}
int main(){
    int u,v;
    ll ans;
    cin>>n;
    for (int i=1;i<=n;i++){
        scanf("%d",&a[i]);
    }
    for (int i=0;i<n-1;i++){
        scanf("%d%d",&u,&v);
        add(u,v);add(v,u);
    }
    dfs(1,0);
    ans=((f[1][0][0][1]+f[1][0][1][1])%Mod+(f[1][1][0][1]+f[1][1][1][1])%Mod)%Mod;
    cout<<ans;
}
原文地址:https://www.cnblogs.com/Y-E-T-I/p/15129341.html