hungary

更正:模数1000000007

/*
  最大匹配求p=1的情况能得30分 
  正解:树形DP,f[i][0/1]表示i节点向下连的那条边选或不选时的最大值 
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define maxn 100010
#define mod 1000000007
#define ll long long
using namespace std;
ll T,P,n,head[maxn],num,f[maxn][2],g[maxn][2],L,R[maxn],l,r[maxn],son[maxn];
struct node
{
    ll v,pre;
}e[maxn*2];
ll init()
{
    ll x=0,f=1;char s=getchar();
    while(s<'0'||s>'9'){if(s=='0')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    return x*f;
}
void Add(ll from,ll to)
{
    num++;e[num].v=to;
    e[num].pre=head[from];
    head[from]=num;
}
void Clear()
{
    num=0;
    memset(f,0,sizeof(f));
    memset(g,0,sizeof(g));
    memset(head,0,sizeof(head));
}
void DP(ll now,ll from)
{
    g[now][0]=1;
    ll mx,sum;
    for(int i=head[now];i;i=e[i].pre)
    {
        ll v=e[i].v;
        if(v==from)continue;
        DP(v,now);//x不连儿子 儿子们可连可不连 
        mx=max(f[v][1],f[v][0]);sum=0;
        if(mx==f[v][1])sum+=g[v][1];
        if(mx==f[v][0])sum+=g[v][0];
        g[now][0]=g[now][0]*sum%mod;
        f[now][0]+=mx;
    }
    //x连某个儿子 这个不选 其他的连或者不连 
    L=0;l=1;ll S=0;
    for(int i=head[now];i;i=e[i].pre) 
        if(e[i].v!=from)son[++S]=e[i].v;
    R[S+1]=0;r[S+1]=1;
    for(int i=S;i>=1;i--)//预处理一个后缀和,R[i]表示i的后缀匹配数,r[i]为方案数 
    {
        ll v=son[i];sum=0;
        mx=max(f[v][1],f[v][0]);
        if(mx==f[v][1])sum+=g[v][1];
        if(mx==f[v][0])sum+=g[v][0];
        R[i]=R[i+1]+mx;
        r[i]=r[i+1]*sum%mod;
    }
    for(int i=1;i<=S;i++)//L是前缀和,不断累加 
    {
        ll v=son[i];//枚举now连的是哪个儿子 
        mx=L+f[v][0]+R[i+1]+1;
        if(mx>f[now][1])
        {
            f[now][1]=mx;
            g[now][1]=l*g[v][0]%mod*r[i+1]%mod;
        }
        else if(mx==f[now][1])
            g[now][1]=(g[now][1]+l*g[v][0]%mod*r[i+1]%mod)%mod;
        sum=0;
        mx=max(f[v][1],f[v][0]);
        if(mx==f[v][1])sum+=g[v][1];
        if(mx==f[v][0])sum+=g[v][0];
        l=l*sum%mod;L+=mx;
    }
}
int main()
{
    freopen("hungary.in","r",stdin);
    freopen("hungary.out","w",stdout);
    T=init();P=init();
    while(T--)
    {
        n=init();
        ll u,v;Clear();
        for(int i=1;i<n;i++)
        {
            u=init();v=init();
            Add(u,v);Add(v,u);
        }
        DP(1,0);ll sum,mx;
        mx=max(f[1][0],f[1][1]);sum=0;
        if(mx==f[1][0])sum+=g[1][0],sum%=mod;
        if(mx==f[1][1])sum+=g[1][1],sum%=mod;
        if(P==1)cout<<mx<<endl;
        if(P==2)cout<<mx<<" "<<sum<<endl;
    }
    return 0;
}
View Code
原文地址:https://www.cnblogs.com/harden/p/5934980.html