51nod 1677: treecnt

题目是求一棵n节点树中对于C(n,k)颗子树,每棵子树为在n个节点中选不同的k个节点作为树的边界点,这样的所有子树共包含多少条边。

问题可以转化一下,对每一条边,不同的子树中可能包含可能不包含这条边,显然,只有子树那k个节点在该边的两侧均有分布时该边才被包含在子树中。所有边的被包含次数的和,即为answer。对于一条边的被包含次数,设该边两侧分别有a,b个节点,那么,该边被包含的次数为C(a+b,k)-C(a,k)-C(b,k)(也可以借助母函数函数求C(a,i)*C(b,k-i),i从1到min{a,b,k-1},结果一样)。

//dfs写的太搓了,调了半天才好。。。

题目链接

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 
 4 typedef long long LL;
 5 const LL mod=1e9+7;
 6 const LL M=1e5+3;
 7 
 8 LL fac[100005];            //阶乘
 9 LL inv_of_fac[100005];        //阶乘的逆元
10 
11 LL qpow(LL x,LL n)
12 {
13     LL ret=1;
14     for(; n; n>>=1)
15     {
16         if(n&1) ret=ret*x%mod;
17         x=x*x%mod;
18     }
19     return ret;
20 }
21 void init()
22 {
23     fac[1]=1;
24     for(int i=2; i<=M; i++)
25         fac[i]=fac[i-1]*i%mod;
26     inv_of_fac[M]=qpow(fac[M],mod-2);
27     for(int i=M-1; i>=0; i--)
28         inv_of_fac[i]=inv_of_fac[i+1]*(i+1)%mod;
29 }
30 LL C(LL a,LL b)
31 {
32     if(b>a) return 0;
33     if(b==0) return 1;
34     return fac[a]*inv_of_fac[b]%mod*inv_of_fac[a-b]%mod;
35 }
36 /////////////////////////////////////////////////////////////
37 vector<int> adj[M];
38 int vis[M];
39 LL n,k,ans,du[M],hh;
40 void init1()
41 {
42     ans=0;
43     memset(vis,0,sizeof(vis));
44     memset(du,0,sizeof(du));
45     du[1]=n;
46     hh=C(n,k);
47     for(int i=1; i<=n; i++)
48         adj[i].clear();
49 }
50 LL dfs(int s)
51 {
52     if(adj[s].size()==1&&s!=1) return du[s]=1;
53     if(du[s]&&s!=1)    return du[s];
54     vis[s]=1;
55     LL ret,cnt=0;
56     for(int i=0; i<adj[s].size(); i++)
57     {
58         if(!vis[adj[s][i]])
59         {
60 //            printf("%d -> %d
",s,adj[s][i]);
61             cnt+=dfs(adj[s][i]);
62             ans=(ans+hh-C(dfs(adj[s][i]),k)-C(n-dfs(adj[s][i]),k))%mod;
63         }
64     }
65     return du[s]=cnt+1;
66 }
67 
68 int main()
69 {
70     init();
71     while(~scanf("%lld%lld",&n,&k))
72     {
73         init1();
74         for(int i=1; i<n; i++)
75         {
76             LL u,v;
77             scanf("%d%d",&u,&v);
78             adj[u].push_back(v);
79             adj[v].push_back(u);
80         }
81         dfs(1);
82 //        for(int i=1; i<=n; i++)
83 //            printf("%d:%lld=========
",i,du[i]);
84 //        for(int i=1; i<=n; i++)
85 //        {
86 //            printf("i=%d:
",i);
87 //            for(int j=0; j<adj[i].size(); j++)
88 //                printf("%d ",adj[i][j]);
89 //            puts("");
90 //        }
91         printf("%lld
",(ans+mod)%mod);
92     }
93 }

// 2017.8.15 更

回头翻一下之前自己写的博客,发现连个dfs都写这么挫,就算这样居然也有人看。重新改了一下代码贴在下面。

#include<bits/stdc++.h>
using namespace std;

typedef long long LL;
const LL mod=1e9+7;
const LL M=1e5+3;

LL fac[M+5];            //阶乘
LL inv_of_fac[M+5];        //阶乘的逆元

LL qpow(LL x,LL n)
{
    LL ret=1;
    for(; n; n>>=1)
    {
        if(n&1) ret=ret*x%mod;
        x=x*x%mod;
    }
    return ret;
}
void init()
{
    fac[1]=1;
    for(int i=2; i<=M; i++)
        fac[i]=fac[i-1]*i%mod;
    inv_of_fac[M]=qpow(fac[M],mod-2);
    for(int i=M-1; i>=0; i--)
        inv_of_fac[i]=inv_of_fac[i+1]*(i+1)%mod;
}
LL C(LL a,LL b)
{
    if(b>a) return 0;
    if(b==0) return 1;
    return fac[a]*inv_of_fac[b]%mod*inv_of_fac[a-b]%mod;
}
/////////////////////////////////////////////////////////////
vector<int> adj[M];
LL n,k,ans,hh;
void init1()
{
    ans=0;
    hh=C(n,k);
    for(int i=1; i<=n; i++)
        adj[i].clear();
}

LL dfs(int s,int pre)
{
    LL ret=1;
    for(int i=0; i<adj[s].size(); i++)
    {
        if(adj[s][i]==pre) continue;
        LL t=dfs(adj[s][i],s);
        ret+=t;
        ans=(ans+hh-C(t,k)-C(n-t,k))%mod;
    }
    return ret;
}

int main()
{
    init();
    while(~scanf("%lld%lld",&n,&k))
    {
        init1();
        for(int i=1; i<n; i++)
        {
            LL u,v;
            scanf("%d%d",&u,&v);
            adj[u].push_back(v);
            adj[v].push_back(u);
        }
        dfs(1,-1);
        printf("%lld
",(ans+mod)%mod);
    }
}
原文地址:https://www.cnblogs.com/Just--Do--It/p/6103326.html