洛谷 P5276 模板题(uoi)

这题挺恶心的。

首先考虑一棵树的情况:用点分治加卷积统计答案。注意合并子树时要按深度从小到大合并,否则复杂度会退化。
我偷懒改为按子树大小(size)从小到大合并,复杂度应该仍然是 $O(n\log^2 n)$。

然后考虑万恶的环。
先随便删掉环上一条边,按照树统计一下答案。
然后考虑那些必须经过这条被删掉的环边、但又不绕整个环一圈的路径所贡献的答案。

考虑再钦定环上另一条边不经过(把它也删掉),这样环被切成两段,统计两端分别落在两段、且经过第一条被删边的路径的答案。

然后对两段分别递归做就行了。

最后加上经过整个环的答案。
时间复杂度 $O(n\log^2 n)$。

// luogu-judger-enable-o2
// luogu-judger-enable-o2
#include <bits/stdc++.h>

using namespace std;

// Polynomial as a coefficient vector: element i is the coefficient of x^i.
using poly = vector<int>;
using ll = long long;
// Scratch buffers reused by operator* on its NTT path.
poly a, b;
// Number of slots in the transform tables below (largest supported NTT length).
const int P = 1 << 17;
// NTT-friendly prime modulus and the generator used to build roots of unity.
const int M = 998244353;
const int G = 3;
// Bit-reversal permutation for the transform length currently in use (filled by operator*).
int rev[P];
// Twiddle table filled by init(): w[h + j] = (G^((M-1)/(2h)))^j for each power of two h.
int w[P];
namespace{
    int add(int x,int y){
        return (x+=y)>=M?x-M:x;
    }
    int sub(int x,int y){
        return (x-=y)<0?x+M:x;
    }
    int mul(int x,int y){
        return (ll)x*y%M;
    }
    int fp(int x,int y){
        int ret=1;
        for (; y; y>>=1,x=mul(x,x))
            if (y&1) ret=mul(ret,x);
        return ret;
    }
}
// inv2[e] = (1/2)^e mod M, filled by init(); INTT reads inv2[log2(len)] to divide by len.
int inv2[30];
// Precompute NTT tables for transform lengths up to `len`
// (expected to be a power of two no larger than P, since w[] has P slots):
//  * for each power of two h < len: w[h + j] = r^j for 0 <= j < h, where r = G^((M-1)/(2h));
//  * inv2[e] for every e with 2^e <= len (inv2[0] and inv2[1] are always set).
void init(int len){
    for (int half=1; half<len; half<<=1){
        int *row=w+half;            // this level's twiddles occupy w[half, 2*half)
        row[0]=1;
        if (half==1) continue;      // a level of size 1 only needs r^0
        const int step=fp(G,(M-1)/(half<<1));
        row[1]=step;
        for (int j=2; j<half; ++j) row[j]=mul(row[j-1],step);
    }
    inv2[0]=1;
    inv2[1]=(M+1)/2;                // inverse of 2 modulo the odd prime M (= 499122177)
    int level=2;
    for (int span=4; span<=len; span<<=1, ++level)
        inv2[level]=mul(inv2[level-1],inv2[1]);
}

void NTT(int *a,int len){
    for (int i=0; i<len; ++i) if (i<rev[i]) swap(a[i],a[rev[i]]);
    for (int i=1; i<len; i<<=1){
        for (int j=0; j<len; j+=(i<<1)){
            int *l=a+j,*b=l+i,*ww=w+i;
            for (int k=0; k<i; ++k){
                int y=mul(*b,*(ww++));
                (*b)=(*l)-y;
                (*b)+=((*b)>>31)&M;
                ++b;
                (*l)+=y-M;
                (*l)+=((*l)>>31)&M;
                ++l;
            }
        } 
    }
} 

// In-place inverse NTT of a[0..len). `bit` must equal log2(len) so that inv2[bit] == 1/len.
// Uses the identity: inverse transform = forward transform of a with a[1..len) reversed,
// scaled by 1/len.
void INTT(int *a,int len,int bit){
    reverse(a+1,a+len);
    NTT(a,len);
    const int scale=inv2[bit];
    transform(a,a+len,a,[scale](int v){ return mul(v,scale); });
}

// Product of two polynomials modulo M.
// Returns a poly of size u.size() + v.size() - 1, or an empty poly if either factor is empty.
// Small products use schoolbook multiplication. Larger ones use NTT through the global
// scratch buffers `a`/`b` and the global `rev` table, so this is not reentrant.
// NOTE(review): the NTT path needs u.size() + v.size() - 1 <= P; longer products would
// overrun rev[]/w[].
poly operator *(const poly &u,const poly &v){
    // Without this guard u.size() + v.size() - 1 underflows (size_t) when both are empty,
    // and poly(outSize) would request an enormous allocation.
    if (u.empty() || v.empty()) return poly();
    const size_t outSize=u.size()+v.size()-1;
    // Schoolbook when |u|*|v| is within a constant factor of |u|+|v| (one side is tiny).
    if ((ll)u.size()*v.size()<=(u.size()+v.size())*30){
        poly ret(outSize);
        for (size_t i=0; i<u.size(); ++i)
            for (size_t j=0; j<v.size(); ++j)
                ret[i+j]=add(ret[i+j],mul(u[i],v[j]));
        return ret;
    }
    a=u;
    b=v;
    // Smallest power of two covering the product length, and its log2.
    int len=1;
    int bit=0;
    while ((size_t)len<outSize){
        len<<=1;
        ++bit;
    }
    a.resize(len);
    b.resize(len);
    for (int i=0; i<len; ++i) rev[i]=rev[i>>1]>>1|((i&1)?len>>1:0);
    NTT(a.data(),len);
    NTT(b.data(),len);
    for (int i=0; i<len; ++i) a[i]=mul(a[i],b[i]);
    INTT(a.data(),len,bit);
    a.resize(outSize);
    return a;
}
// Coefficient-wise sum modulo M; the result is as long as the longer operand,
// with missing coefficients treated as 0.
poly operator +(const poly &u,const poly &v){
    poly ret=u;
    if (ret.size()<v.size()) ret.resize(v.size(),0);
    for (size_t i=0; i<ret.size(); ++i)
        ret[i]=add(ret[i],i<v.size()?v[i]:0);
    return ret;
}
// In-place coefficient-wise sum: u grows to at least v.size(), then u[i] = add(u[i], v[i]).
void operator +=(poly &u,const poly &v){
    u.resize(max(u.size(),v.size()));
    auto dst=u.begin();
    for (const int coef:v){
        *dst=add(*dst,coef);
        ++dst;
    }
}

ostream& operator <<(ostream& out,const poly &a){
    for (auto i:a) out<<i<<" ";
    return out<<endl;
}

void test(){
    poly a({1,2}),b({2,3,2333});
    a=a*b;
    cerr<<a;
} 

// Vertex count and edge count of the input graph (read in main).
int n;
int m;
// Bound for per-vertex arrays; vertices are numbered from 1.
const int N=100010;
// Length histogram of counted paths: ans[d] accumulates the count for length d (mod M).
poly ans;
namespace solve1{
    // Centroid-decomposition path counter for the spanning tree built in main().
    // main(fa) adds to ans[d] the number of unordered vertex pairs at tree distance d,
    // and sets ans[0] = n.

    vector<int> e[N];   // tree adjacency; links back to already-processed centroids are erased
    int sz[N];          // subtree sizes from the latest Dfs
    int tmp[N];         // tmp[x]: size of x's largest child subtree (from the latest Getrt)
    int rt;             // best centroid candidate found so far

    // Subtree sizes of the component containing x, rooted at x.
    void Dfs(int x,int fa){
        sz[x]=1;
        for (auto i:e[x])
            if (i!=fa){
                Dfs(i,x);
                sz[x]+=sz[i];
            }
    }

    // Size of the largest piece left after deleting x from a component of y vertices.
    int calc(int y,int x){
        return max(y-sz[x],tmp[x]);
    }

    // Scan the component and keep in rt the vertex minimizing calc(totsize, .), i.e. a centroid.
    void Getrt(int x,int fa,const int totsize){
        tmp[x]=0;
        for (auto i:e[x])
            if (i!=fa){
                tmp[x]=max(sz[i],tmp[x]);
                Getrt(i,x,totsize);
            }
        if (calc(totsize,x)<calc(totsize,rt)) rt=x;
    }

    // Depth histogram: ++a[d] for every vertex in x's subtree (away from fa), x at depth nowdis.
    void Getdeep(int x,int fa,poly &a,int nowdis){
        ++a[nowdis];
        for (auto i:e[x])
            if (i!=fa) Getdeep(i,x,a,nowdis+1);
    }

    // Handle the component containing x: find its centroid, count every path through the
    // centroid by convolving depth histograms, then detach the centroid and recurse.
    void df(int x){
        Dfs(x,0);
        rt=x;
        const int total=sz[x];
        Getrt(x,0,total);
        // sz[] was computed with x as the root. For the centroid's parent-side neighbour,
        // replace it with the size of the piece that neighbour ends up in.
        for (auto i:e[rt])
            if (sz[i]>sz[rt]) sz[i]=total-sz[rt];
        // Merge pieces in increasing size so each product is at most ~2x the current piece.
        sort(e[rt].begin(),e[rt].end(),[&](int p,int q){
            return sz[p]<sz[q];
        });
        poly c,b(1,1);  // b: depth histogram of the centroid plus the pieces merged so far
        for (auto i:e[rt]){
            c.assign(sz[i]+1,0);
            Getdeep(i,rt,c,1);
            ans+=b*c;   // pairs with one end in earlier pieces (or the centroid), one end here
            b+=c;
        }
        const int centroid=rt;  // rt is overwritten by the recursive calls below
        for (auto i:e[centroid]){
            e[i].erase(find(e[i].begin(),e[i].end(),centroid));
            df(i);
        }
    }

    // Build the tree from parent links fa[1..n] (fa[i] == 0: no parent), then count paths.
    void main(int *fa){
        for (int i=1; i<=n; ++i)
            if (fa[i]){
                e[fa[i]].push_back(i);
                e[i].push_back(fa[i]);
            }
        df(1);
        // A single vertex forms no product, so ans may still be empty; writing ans[0]
        // would then be out of bounds.
        if (ans.empty()) ans.resize(1);
        ans[0]=n;   // one length-0 path per vertex
    }
}
// Visit marks: noloop sets 1, main's findloop later sets 2.
int vis[N];
// fa[v]: parent of v in the DFS tree built by noloop (0 if none).
int fa[N];
// k: largest path length reported by Output; f: nonzero to also print per-length counts.
int k;
int f;
// Adjacency lists of the input graph.
vector<int> g[N];
// DFS from x over g that marks vis[] = 1 and records each newly reached vertex's DFS parent
// in fa[]. An explicit stack replaces recursion; neighbours are tried in the same order,
// so fa[] comes out identical.
void noloop(int x){
    vector<pair<int,size_t>> stk;   // (vertex, index of the next neighbour to try)
    vis[x]=1;
    stk.push_back(make_pair(x,(size_t)0));
    while (!stk.empty()){
        const int u=stk.back().first;
        size_t &cursor=stk.back().second;
        if (cursor==g[u].size()){
            stk.pop_back();
            continue;
        }
        const int v=g[u][cursor++];
        if (vis[v]) continue;
        vis[v]=1;
        fa[v]=u;
        stk.push_back(make_pair(v,(size_t)0));   // invalidates `cursor`; not used afterwards
    }
}
// Resize `a` to lengths 0..k and print the sum of those counts (mod M) on one line.
// When f is nonzero, also print every count on the following line.
void Output(poly &a,int k,int f){
    a.resize(k+1);
    const int total=accumulate(a.begin(),a.end(),0,add);
    cout<<total<<endl;
    if (f) cout<<a;
}

int main(){
    init(1<<17);
    ios::sync_with_stdio(0);
    cin.tie(0);
    // Skip the first input line (the original had a raw newline inside the character
    // literal, which does not compile). NOTE(review): presumably a test-case/subtask id;
    // confirm against the problem's input format.
    cin.ignore(233,'\n');
    cin>>n>>m>>k>>f;
    for (int i=1; i<=m; ++i){
        int x,y;
        cin>>x>>y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    // First count every path of the DFS spanning tree rooted at vertex 1.
    noloop(1);
    solve1::main(fa);
    if (m==n-1){
        Output(ans,k,f);
        return 0;
    }
    // Otherwise the graph is assumed to contain exactly one cycle.
    // findloop runs a DFS from vertex 1 and stops at the first edge to an already-visited
    // vertex. At that point s holds the DFS stack and pp is that edge's endpoint.
    vector<int> s;
    int pp=0;
    function<void(int,int)> findloop=[&](int x,int from){
        vis[x]=2;
        s.push_back(x);
        for (auto i:g[x])
            if (i!=from){
                if (vis[i]!=2) findloop(i,x);
                else pp=i;
                if (pp) return;   // cycle closed: leave the stack in s as is
            }
        s.pop_back();
    };
    findloop(1,0);
    // Keep only the cycle: s lists the cycle vertices in order, and s.front() -- s.back()
    // is the edge that closed it.
    s.erase(s.begin(),find(s.begin(),s.end(),pp));
    // Delete that closing edge from g.
    g[s.front()].erase(find(g[s.front()].begin(),g[s.front()].end(),s.back()));
    g[s.back()].erase(find(g[s.back()].begin(),g[s.back()].end(),s.front()));
    // ddd: ++c[dist] for every vertex of x's current component, where dist is the distance
    // from x. c grows as needed.
    function<void(int,int,poly&,int)> ddd=[&](int x,int parent,poly &c,int dis){
        if (dis>=(int)c.size()) c.resize(dis+1);
        ++c[dis];
        for (auto j:g[x])
            if (j!=parent) ddd(j,x,c,dis+1);
    };
    // Fakeadd: u[i + len] += v[i], i.e. add v shifted up by len positions.
    auto Fakeadd=[&](poly &u,const poly &v,int len){
        if (u.size()<v.size()+len) u.resize(v.size()+len);
        for (int i=0; i<(int)v.size(); ++i) u[i+len]=add(u[i+len],v[i]);
    };
    // Number of cycle edges between positions x <= y of s.
    auto waylength=[&](int x,int y){
        return y-x;
    };
    // solve(l, r, nowlen): the ends of segment s[l..r] are joined outside the segment by a
    // cycle route of nowlen edges. Cut edge s[mid] -- s[mid+1], count paths between the two
    // halves that use that outer route (length dist-to-s[l] + nowlen + dist-to-s[r]),
    // then recurse on each half with its own outer-route length.
    function<void(int,int,int)> solve=[&](int l,int r,int nowlen){
        if (l==r) return;
        int mid=(l+r)>>1;
        g[s[mid]].erase(find(g[s[mid]].begin(),g[s[mid]].end(),s[mid+1]));
        g[s[mid+1]].erase(find(g[s[mid+1]].begin(),g[s[mid+1]].end(),s[mid]));
        poly c,d;
        ddd(s[l],0,c,0);
        ddd(s[r],0,d,0);
        Fakeadd(ans,c*d,nowlen);
        solve(l,mid,waylength(mid,r)+nowlen);
        solve(mid+1,r,waylength(l,mid+1)+nowlen);
    };
    solve(0,(int)s.size()-1,1);
    // Every cycle edge is cut by now, so ddd(i) covers only the tree hanging off cycle
    // vertex i. Each non-cycle vertex there gets a contribution at depth + |cycle|
    // (the "goes around the whole cycle" term from the write-up).
    for (auto i:s){
        poly c;
        ddd(i,0,c,0);
        c[0]=0;   // exclude the cycle vertex itself
        Fakeadd(ans,c,s.size());
    }
    // One extra count at length |cycle| (presumably the cycle itself).
    ans[s.size()]=add(ans[s.size()],1);
    Output(ans,k,f);
}
原文地址:https://www.cnblogs.com/Yuhuger/p/10621956.html