Evensgn 剪树枝

转载请注明出处:

http://www.cnblogs.com/hzoi-wangxh/p/7738629.html 

Evensgn 剪树枝

时间限制:1s 空间限制:128MB

题目描述

繁华中学有一棵苹果树。苹果树有n 个节点(也就是苹果),n − 1 条边(也就

是树枝)。调皮的Evensgn 爬到苹果树上。他发现这棵苹果树上的苹果有两种:一

种是黑苹果,一种是红苹果。Evensgn想要剪掉 k 条树枝,将整棵树分成k + 1 个

部分。他想要保证每个部分里面有且仅有一个黑苹果。请问他一共有多少种剪树枝

的方案?

输入格式

第一行一个数字n,表示苹果树的节点(苹果)个数。

第二行一共n − 1 个数字p0, p1, p2, p3, ..., pn−2,pi表示第 i + 1 个节点和pi 节

点之间有一条边。注意,点的编号是0 到 n − 1。

第三行一共n 个数字 x0, x1, x2, x3, ..., xn−1。如果xi 是 1,表示i 号节点是黑

苹果;如果xi 是 0,表示i 号节点是红苹果。

输出格式

输出一个数字,表示总方案数。答案对109 + 7 取模。

样例输入1

3

0 0

0 1 1

6

样例输出1

2

样例输入2

6

0 1 1 0 4

1 1 0 0 1 0

样例输出2

1

样例输入3

10

0 1 2 1 4 4 4 0 8

0 0 0 1 0 1 1 0 0 1

样例输出3

27

数据范围

对于30% 的数据,1 ≤n ≤ 10。

对于60% 的数据,1 ≤n ≤ 100。

对于80% 的数据,1 ≤n ≤ 1000。

对于100% 的数据,1 ≤n ≤ 105

对于所有数据点,都有0 ≤ pi ≤n − 1,xi = 0 或xi = 1。

特别地,60%中、80% 中、100%中各有一个点,树的形态是一条链。


题解:

    其实是一个树规。
    设f[i][j],f表示方案数,i表示以i为根节点的子树,j为0或1,0表示这棵子树的黑苹果数量等于砍的刀数,1代表砍的刀数比黑苹果数量少1.
    为什么设这两种关系?我们可以想一下,整棵树有k个黑苹果,需要砍k-1刀,分成k个部分。如果把其中的一部分单独提出来,发现黑苹果数量比这一段砍的刀数少1,那么其他部分肯定是砍的刀数等于黑苹果数量。
    接下来就是状态转移了。我们可以先跑一遍dfs,找出以i节点为根的子树中共有多少个黑苹果。如果为零,那这一段就不用搜了。因为这一棵树中反正也不能砍,对结果没有影响。
    首先我们每向上走一层,把少砍一刀的情况加入砍全的情况,f[v][0]+=f[v][1](v为i的合法儿子)。接下来分两种情况,如果i节点为红,f[i][0]=∏f[v][0](v为i的合法儿子)。设sum=∏f[v][0],f[i][1]=Σ(sum/f[v][0]*f[v][1])。如果节点为黑,我们只考虑i的儿子,所以只存在f[i][1]=∏f[v][0].
    最后输出f[1][1]。
    注意时刻取模,sum/f[v][0]是用逆元。

附上代码

#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<vector>
using namespace std;
struct tree{
    int u,v,next;
}l[301000];
long long f[101000][5],mod=1000000007;
int lian[101000],e=0,n,fa[101000],size[101000],a[101000];
void bian(int,int);
void dfs(int);
void dp(int);
long long ksm(long long,long long);
int main()
{
    scanf("%d",&n);
    for(int i=2;i<=n;i++)
    {
        int x;
        scanf("%d",&x);
        x+=1;
        bian(x,i);
        bian(i,x);
    }
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&a[i]);
    }
    dfs(1);
    dp(1);
    printf("%lld",f[1][1]);
    return 0;
}
void bian(int x,int y)
{
    e++;
    l[e].u=x;
    l[e].v=y;
    l[e].next=lian[x];
    lian[x]=e;
}
void dfs(int x)
{
    if(a[x]!=0)
        size[x]+=1;
    for(int i=lian[x];i;i=l[i].next)
    {
        int v=l[i].v;
        if(v!=fa[x])
        {
            fa[v]=x;
            dfs(v);
            size[x]+=size[v];
        }
    }
}
void dp(int x)
{
    int num=0;
    vector<int> ve;
    if(a[x]==0)
    {
        long long sum=1;
        for(int i=lian[x];i;i=l[i].next)
        {
            int v=l[i].v;
            if(v==fa[x])
                continue;
            if(size[v]==0)
                continue;
            dp(v);
            num++;
            ve.push_back(v);
            f[v][0]+=f[v][1];
            sum*=f[v][0];
            sum%=mod;
        }
        f[x][0]=sum;
        for(int i=0;i<num;i++)
        {
            long long k=sum*ksm(f[ve[i]][0],mod-2)%mod;
            k*=f[ve[i]][1];
            k%=mod;
            f[x][1]+=k;
            f[x][1]%=mod;
        }
    }
    else
    {
        long long sum=1;
        for(int i=lian[x];i;i=l[i].next)
        {
            int v=l[i].v;
            if(v==fa[x])
                continue;
            if(size[v]==0)
                continue;
            dp(v);
            num++;
            ve.push_back(v);
            f[v][0]+=f[v][1];
            sum*=f[v][0];
            sum%=mod;
        }
        f[x][1]=sum;
    }
}
long long ksm(long long x,long long y)
{
    long long ans=1,z=x;
    while(y)
    {
        if((y&1)==1)
        {
            ans*=z;
            ans%=mod;
        }
        y=y>>1;
        z*=z;
        z%=mod;
    }
    return ans;
}
原文地址:https://www.cnblogs.com/hzoi-wangxh/p/7738629.html