HDU6446 Tree and Permutation(树、推公式)

题意:

给定一棵 N 个点的带边权树。对 1~N 的任意一个全排列,排列中每对相邻的数字 a 和 b 的贡献是树上顶点 a 到顶点 b 的路径长度(路径上的边权之和),一个排列的贡献是它所有相邻对的贡献之和。求全部 N! 个排列的贡献总和,结果对 1e9+7 取模。

思路:

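对每一条边,删去它后树分成两部分:一侧有 x 个点,另一侧有 y 个点(x+y=n),该边权值为 w。有序点对 (a,b) 的路径经过这条边,当且仅当 a、b 分属两侧,这样的有序对共有 2xy 个。对每个固定的有序对 (a,b),把“a 紧接着 b”看成一个整体,恰有 (n-1)! 个排列满足 b 紧跟在 a 之后。所以答案为 $\displaystyle \sum_{e} 2\,x\,y\,w\,(n-1)! = \sum_{e} 2\,x\,(n-x)\,w\,(n-1)!$,其中求和遍历树上所有边 e。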

其中把树以 1 为根,每条边的 x 取为该边下方(子节点一侧)的子树大小;一次 DFS 即可求出所有子树的节点个数。

比赛时没想到如何在 DFS 中同时记录边权和子树节点个数,只能乱调一通,最后勉强卡着时间和空间限制通过。赛后看了别人的代码,学到了新方法:在邻接表里同时保存邻点和边权,DFS 时遇到通往父节点的那条边,就把它的边权记到当前节点上。

比赛代码:

842MS 14012K
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#include<functional>
    
// Generic contest template: shorthand macros, type aliases and constants.
// Of the macros/aliases/constants below, this program uses only pb, mem,
// ll, mod and maxn; the others are unused template boilerplate.
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1
#define lowbit(x) ((x)&(-x)) 

using namespace std;

typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;

const db eps = 1e-6;
const int mod = 1e9+7;   // answers are reported modulo this prime
const int maxn = 2e5+2;  // capacity of the per-vertex / per-edge arrays
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);
// Adjacency lists of the tree (vertices are 1-based).
std::vector<int> v[maxn];
// w[i]: weight of the edge joining vertex i to its parent (filled in main).
long long w[maxn];
// f[i]: number of vertices in the subtree rooted at i (filled by dfs).
int f[maxn];
// Running answer for the current test case.
long long ans = 0;
// cnt[i] = i! modulo 1e9+7 (filled by init).
long long cnt[maxn];
// Computes subtree sizes of the tree stored in v[], rooted at x.
//   x  : root of the traversal
//   fa : parent of x (-1 when x is the global root)
// Afterwards f[u] holds the number of vertices in u's subtree for every
// vertex u reached from x without passing through fa; returns f[x].
// Uses an explicit stack instead of recursion so that a path-shaped tree
// with ~1e5 vertices cannot overflow the call stack.
ll dfs(int x, int fa){
    struct Frame { int node; int parent; int next; };  // next = index of the next neighbour to examine
    std::vector<Frame> stk;
    stk.push_back({x, fa, 0});
    f[x] = 1;
    while(!stk.empty()){
        Frame &cur = stk.back();
        const int u = cur.node;
        if(cur.next < (int)v[u].size()){
            const int to = v[u][cur.next++];
            if(to != cur.parent){
                f[to] = 1;                  // the child counts itself
                stk.push_back({to, u, 0});  // may invalidate `cur`; it is not used afterwards
            }
        }else{
            // u is finished: fold its subtree size into its parent.
            const int sub = f[u];
            stk.pop_back();
            if(!stk.empty()) f[stk.back().node] += sub;
        }
    }
    return f[x];
}
void init(){
    cnt[1] = 1;
    for(int i = 2; i < 100000+1; i++){
        cnt[i] = cnt[i-1]*i;
        cnt[i]%=mod;
    }
    return;
}
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters before it; the character that terminates the number
// is consumed. Returns 0 if input ends before any digit is found, instead of
// spinning forever on EOF.
inline int read(){
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while(ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    if(ch == EOF) return 0;
    int num = 0;
    while(ch >= '0' && ch <= '9'){
        num = num * 10 + (ch - '0');
        ch = getchar();
    }
    return num;
}
// One input edge: endpoints u, v and weight w, kept in input order.
struct Edge{
    int u, v, w;
}edge[maxn];
// Next free slot in edge[]; slots are 1-based (main resets this to 1).
int top = 1;
// Stores the edge (u, v) with weight w at edge[top] and advances top.
void addedge(int u, int v, int w){
    Edge &slot = edge[top];
    slot.u = u;
    slot.v = v;
    slot.w = w;
    ++top;
}

int main() {
    int n;
    init();
    while(scanf("%d", &n)!=EOF){
        top = 1;
        ans = 0;
        mem(f, 0);
        for(int i = 1; i <= n; i++){
            v[i].clear();
        }
        for(int i = 1; i <= n-1; i++){
            int x, y;
            x=read();
            y=read();
            v[x].pb(y);
            v[y].pb(x);
            int c;
            c = read();
            addedge(x, y, c);
            //mp[x][y] = mp[y][x] = c;
        }
        //ddfs(1, -1);
        
        dfs(1, -1);
        f[1] = 0;
        for(int i = 1; i <= n-1; i++){
            int x = edge[i].u;
            int y = edge[i].v;
            int ww = edge[i].w;
            //printf("%d %d %d %d
", x, f[x], y, f[y]);
            if(x==1){
                w[y] = ww;
                continue;
            }
            else if(y==1){
                w[x] = ww;
                continue;
            }
            if(f[x] < f[y]){
                w[x] = ww;
                continue;
            }
            else{
                w[y] = ww;
                continue;
            }
        }
        
         // for(int i = 1; i <= n; i++){
         //     printf("%d %d
", i, w[i]);
         // }
        for(int i = 2; i <= n; i++){
            ll tmp = 1;
            tmp = 2*f[i];
            tmp%=mod;
            tmp *= (n-f[i]);
            tmp%=mod;
            tmp *= cnt[n-1];
            tmp %= mod;
            tmp *= (ll)w[i];
            ans += tmp;
            ans %= mod;
        }
        printf("%I64d
", ans);
    }    

    return 0;
}

赛后代码:

499MS 18972K
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<string>
#include<stack>
#include<queue>
#include<deque>
#include<set>
#include<vector>
#include<map>
#include<functional>
    
// Generic contest template: shorthand macros, type aliases and constants.
// Of the macros/aliases/constants below, this program uses only pb, ll,
// mod and maxn; the others are unused template boilerplate.
#define fst first
#define sc second
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lc root<<1
#define rc root<<1|1
#define lowbit(x) ((x)&(-x)) 

using namespace std;

typedef double db;
typedef long double ldb;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PI;
typedef pair<ll,ll> PLL;

const db eps = 1e-6;
const int mod = 1e9+7;   // answers are reported modulo this prime
const int maxn = 2e5+2;  // capacity of the per-vertex arrays
const int maxm = 2e6+100;
const int inf = 0x3f3f3f3f;
const db pi = acos(-1.0);

// d[i] = i! modulo 1e9+7, precomputed in main.
long long d[maxn];
// Adjacency-list entry: neighbour u reached through an edge of weight w.
struct Edge{
    int u;        // neighbouring vertex
    long long w;  // weight of the connecting edge
};
// Adjacency lists of the tree (vertices are 1-based).
std::vector<Edge> v[maxn];
// w[i]: weight of the edge from vertex i to its parent (filled by dfs).
long long w[maxn];
// f[i]: size of the subtree rooted at vertex i (filled by dfs).
int f[maxn];
int dfs(int x, int fa){
    int sz = v[x].size();
    int ans = 1;
    for(int i = 0; i < sz; i++){
        if(v[x][i].u==fa)w[x]=v[x][i].w;
        else ans+=dfs(v[x][i].u, x);
    }
    if(sz==1&&fa!=-1)return f[x] = 1;
    return f[x] = ans;
}
int main() {
    int n;
    d[1] = 1;
    for(int i = 2; i < 100000 + 100; i++){
        d[i] = d[i-1] *i;
        d[i] %= mod;
    }    
    while(~scanf("%d", &n)){
        for(int i = 1; i <= n; i++)v[i].clear();
            //mem(f, 0);
        for(int i = 0; i < n-1; i++){
            int x, y;
            ll l;
            scanf("%d %d %I64d", &x, &y, &l);
            Edge t1, t2;
            t1.u=y;t2.u=x;
            t1.w=t2.w=l;
            v[x].pb(t1);
            v[y].pb(t2);
        }
        ll ans = 0;
        dfs(1, -1);
        for(int i = 2; i <= n; i++){
            ll tmp = d[n-1];
            tmp %= mod;
            tmp *= (ll)2*f[i]*(n-f[i]);
            tmp %= mod;
            tmp *= w[i];
            tmp %= mod;
            ans += tmp;
            ans %= mod;
        }
        printf("%I64d
", ans);
    }
    return 0;
}
原文地址:https://www.cnblogs.com/wrjlinkkkkkk/p/9537922.html