[HAOI2018]苹果树

嘟嘟嘟


这种计数大题就留给南方的计数神仙们做吧……


刚开始我一直想枚举点,考虑新加一个点在根节点的左右子树,以及左右子树大小怎么分配,但是这样太难计算新的点带来的贡献了。
后来lba又提示我枚举边,考虑每一条边的贡献。
这确实是一个好主意,枚举边的同时考虑边两侧的点数,但可怕的是我一直把他当成无根树来做,也就是忽略了树上打父子关系,导致少算了好多形态。
于是题解吵朝我挥了挥手。


既然是有根树,那么我们枚举每一个点,然后枚举的是这个点和他父亲的连边,这样就能不重不漏并且有顺序的枚举所有边了。


考虑点(i)这条边的贡献,就是(size _ i * (n - size_i))
(size_i)不确定,但根据题意是可以(O(n))枚举的。
我们枚举(size_i),算出当(size_i)一定时,这个子树以及子树外有多少种形态。


先考虑子树内:不算标号有(size_i !)种形态,因为第一个点只有一种连接方法,第二个点有两种,第三个点有三种……所以(size_i)个点就(size_i !)种。当形态固定时,考虑标号:因为在(i)子树内只可能是标号比(i)大的点,所以有(C_{n - i} ^ {size_i - 1})种。那么子树内的所有形态就是(size_i ! * C_{n - i} ^ {size_i - 1})


接下来我们考虑子树外:在生成点(i)的子树之前有(i!)种方式,然后我们考虑剩下的(n - i - size_i)个点的生成方式,为:((i - 1) * i * (i + 1) * ldots * (n - i - size_i - 1))
所以子树外的点的生成方式就是(i! *(i - 1) * i * (i + 1) * ldots * (n - i - size_i - 1) = (i - 1) * i * (n - size_i - 1)!)
那么答案就出来啦:

[ans = sum _ {i = 1} ^{n} sum _ {size = 1} ^ {n - i + 1} size _ i * (n - size_i) * size_i ! * C_{n - i} ^ {size_i - 1} * (i - 1) * i * (n - size_i - 1)! ]

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 2e3 + 5;
inline ll read()
{
  ll ans = 0;
  char ch = getchar(), last = ' ';
  while(!isdigit(ch)) last = ch, ch = getchar();
  while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
inline void write(ll x)
{
  if(x < 0) x = -x, putchar('-');
  if(x >= 10) write(x / 10);
  putchar(x % 10 + '0');
}

int n, mod;

In ll inc(ll a, ll b) {return a + b >= mod ? a + b - mod : a + b;}

ll fac[maxn], C[maxn][maxn];
In void init()
{
  fac[0] = 1;
  for(int i = 1; i <= n; ++i) fac[i] = fac[i - 1] * i % mod;
  for(int i = 0; i <= n; ++i) C[i][0] = 1;
  for(int i = 1; i <= n; ++i)
    for(int j = 1; j <= i; ++j)
      C[i][j] = inc(C[i - 1][j - 1], C[i - 1][j]);
}

int main()
{
  n = read(), mod = read();
  init();
  ll ans = 0;
  for(int i = 2; i <= n; ++i)
    for(int j = 1; j <= n - i + 1; ++j)
      ans = inc(ans, fac[j] * C[n - i][j - 1] % mod * j % mod * (n - j) % mod * fac[n - j - 1] % mod * i % mod * (i - 1) % mod);
  write(ans), enter;
  return 0;
}
原文地址:https://www.cnblogs.com/mrclr/p/10611965.html