[vijos 1770]大内密探

描述

在古老的皇宫中,有N个房间以及N-1条双向通道,每条通道连接着两个不同的房间,所有的房间都能互相到达。皇宫中有许多的宝物,所以需要若干个大内密探来守护。一个房间被守护当切仅当该房间内有一名大内密探或者与该房间直接相邻的房间内有大内密探。

现在身为大内密探零零七的你想知道要把整个皇宫守护好至少需要多少名大内密探以及有多少种安排密探的方案。两种方案不同当且仅当某个房间在一种方案有密探而在另一个方案内没有密探。

格式

输入格式

第一行一个正整数N.(1<=N<=100000)
后面N-1行,每行两个正整数a和b,表示房间a和房间b之间有一条无向通道。

房间的编号从1到N

输出格式

第一行输出正整数K,表示最少安排的大内密探。

第二行输出整数S,表示有多少种方案安排最少的密探,由于结果可能较大,请输出方案数mod 1000000007的余数。

样例1

样例输入1[复制]
7
2 1
3 1
4 2
5 1
6 2
7 6
样例输出1[复制]
3
4
 
首先这是两个子问题
第一问是比较基本的树形dp
f[i][0] i的子树中,i被覆盖但不取i的方案数
f[i][1] i被覆盖,没有其他限制的方案数
f[i][2] i点不取且i不被儿子覆盖的方案数
第二问 g[i][0~3]表示上面三个对应的方案数
第一问的转移
f[i][1]=sum_{each~son_i}min(f[j][0],f[j][1],f[j][2])
f[i][2]=sum_{each~son}f[j][0]
然后为了方便g的计算,我们一会再讨论f[i][0]的转移
那么如何来搞这个g呢。。
首先g[i][1,2]随便加法乘法原理算算就好了,但是g[i][0]比较蛋疼
首先一般我们会这样算f[i][0]
f[i][0]=min_{eachson}f[j][1]+sum_{eachson}min(f[k][0],f[k][1])
这样会在转移g的时候出现重复状态
然后我们发现可以对后面一坨维护一下前缀和和后缀和避免重复
实现细节看代码
#include<map>
#include<stack>
#include<queue>
#include<cstdio>
#include<string>
#include<vector>
#include<cstring>
#include<complex>
#include<iostream>
#include<assert.h>
#include<algorithm>
using namespace std;
using namespace std;
#define pb push_back
#define inf 1001001001
#define infll 1001001001001001001LL
#define FOR0(i,n) for(int (i)=0;(i)<(n);++(i))
#define FOR1(i,n) for(int (i)=1;(i)<=(n);++(i))
#define mp make_pair
#define pii pair<int,int>
#define ll long long
#define ld double
#define vi vector<int>
#define SZ(x) ((int)((x).size()))
#define fi first
#define se second
#define RI(n) int (n); scanf("%d",&(n));
#define RI2(n,m) int (n),(m); scanf("%d %d",&(n),&(m));
#define RI3(n,m,k) int (n),(m),(k); scanf("%d %d %d",&(n),&(m),&(k));
template<typename T,typename TT> ostream& operator<<(ostream &s,pair<T,TT> t) {return s<<"("<<t.first<<","<<t.second<<")";}
template<typename T> ostream& operator<<(ostream &s,vector<T> t){FOR0(i,sz(t))s<<t[i]<<" ";return s; }
#define dbg(vari) cerr<<#vari<<" = "<<(vari)<<endl
#define all(t) t.begin(),t.end()
#define FEACH(i,t) for (typeof(t.begin()) i=t.begin(); i!=t.end(); i++)
#define TESTS RI(testow)while(testow--)
#define FORZ(i,a,b) for(int (i)=(a);(i)<=(b);++i)
#define FORD(i,a,b) for(int (i)=(a); (i)>=(b);--i)
#define gmax(a,b) (a)=max((a),(b))
#define gmin(a,b) (a)=min((a),(b))
#define ios0 ios_base::sync_with_stdio(0)
#define Ri register int
#define gc getchar()
#define il inline
il int read(){
    bool f=true;
    Ri x=0;char ch;
    while(!isdigit(ch=gc))
        if(ch=='-')f=false;
    while(isdigit(ch)){
        x=(x<<1)+(x<<3)+ch-'0';
        ch=gc;
    }
    return f?x:-x;
}
#define gi read()
#define FO(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout);
#define childs(x,i) for(int i=last[x]; i; i=e[i].next)
const int N=100005,mod=1000000007;
int last[N],cnt,n,l,f[N][3],r[N],st[N];
ll d[N][3],suml,sumr[N];
struct edge{
    int to,next;
}e[230000];
ll mul(ll a,ll b){
    return ((a%mod)*(b%mod))%mod;
}
ll Plus(ll a,ll b){
    return (a%mod+b%mod)%mod;
}
void insert(int u, int v) {
    e[++cnt].next=last[u];last[u]=cnt;e[cnt].to=v;
    e[++cnt].next=last[v];last[v]=cnt;e[cnt].to=u;
}

void dfs(int x,int fa) {
    int t1=1,t2=0,s1=1,s2=1,ch;
    childs(x,i) 
        if((ch=e[i].to)!=fa) {
            ll T=0;
            dfs(ch,x);
            int mn=min(min(f[ch][0],f[ch][1]),f[ch][2]);
            FOR0(j,3)
                if(f[ch][j]==mn) T+=d[ch][j];
            s1=mul(s1,T);
            T=0;
            s2=mul(s2,d[ch][0]);
             t1+=mn;
            t2+=f[ch][0];
        }
    f[x][1]=t1;
    f[x][2]=t2;
    d[x][1]=s1;
    d[x][2]=s2;
    int sz=0;
    childs(x,i) if(e[i].to!=fa) st[++sz]=e[i].to;
    r[sz+1]=0; 
    sumr[sz+1]=1;suml=1;l=0;
    FORD(i,sz,1) {
        ch=st[i];
        ll T=0;
        int mn=min(f[ch][0],f[ch][1]); 
        FOR0(j,2)if(f[ch][j]==mn) T+=d[ch][j];
        r[i]=r[i+1]+mn;
        sumr[i]=mul(sumr[i+1],T);
    }
    f[x][0]=N;
    FOR1(i,sz){
        int fyb=l+f[st[i]][1]+r[i+1];
        if(fyb<f[x][0])    f[x][0]=fyb,d[x][0]=mul(d[st[i]][1],mul(suml,sumr[i+1]));
        else if(fyb==f[x][0])     d[x][0]=Plus(d[x][0],mul(d[st[i]][1],mul(suml,sumr[i+1])));
        if(f[st[i]][0]==N) break;
        l+=f[st[i]][0];
        suml=mul(suml,d[st[i]][0]);
    }
}
int main() {
    RI(n);
    FOR1(i,n-1)
        insert(gi,gi);
    int root=1;
    dfs(root,-1);
    int ans1=min(f[root][1],f[root][0]),ans2=0;    
    if(ans1==f[root][0]) 
        ans2=Plus(ans2,d[root][0]);
    if(ans1==f[root][1]) 
        ans2=Plus(ans2,d[root][1]);
    printf("%d
%d
",ans1,ans2);
    return 0;
}
原文地址:https://www.cnblogs.com/chouti/p/5804044.html