tqc大佬代码暂存

大家快来膜呀~~~

#include<bits/stdc++.h>

// Every `int` below this line is really `long long`; this is why main is
// declared `signed` and why "%lld" is used for all I/O.
#define int long long
// Modulus for all polynomial coefficient arithmetic.
const int P = 998244353;
// Base that NTT raises to (P - 1) / (2m) to obtain the root used at each stage.
const int G = 3;

using namespace std;

// Capacity of every global array; every transform length set by init must fit in it.
const int N=400005;

// n: number of coefficients read by main.
// f: input polynomial read by main.
// g: scratch buffer that getksm/getsqrt fill with the logarithm and clear afterwards.
int n,f[N],g[N];
// limit / rev: current transform length and bit-reversal table, both written by init.
// a, b: scratch of getln (derivative and inverse).   c: scratch of getinv.
// d, e: scratch of getexp (logarithm of the current estimate and the target polynomial).
int limit,rev[N],a[N],b[N],c[N],d[N],e[N];

// Modular exponentiation: returns x^y mod P by binary (square-and-multiply)
// exponentiation. y is consumed bit by bit from the least significant end.
inline int ksm ( int x , int y ) {
    int result = 1;
    // Each round: fold the current power in if the low bit is set, then
    // square the base and drop that bit.
    for ( ; y ; y >>= 1 , x = 1ll * x * x % P )
        if ( y & 1 )
            result = 1ll * result * x % P;
    return result % P;
}
// Prepares a transform of at least n points: sets the global `limit` to the
// smallest power of two >= n and fills rev[1 .. limit) with the bit-reversed
// index of each position (rev[0] is left as it is).
inline void init ( int n ) {
    int bits = 0;
    for ( limit = 1 ; limit < n ; limit <<= 1 )
        ++bits;
    for ( int i = 1 ; i < limit ; ++i ) {
        // The lowest bit of i becomes the highest bit of its reversal; the
        // remaining bits come from the already computed reversal of i / 2.
        const int top = ( i & 1 ) ? ( 1ll << ( bits - 1 ) ) : 0;
        rev[i] = ( rev[i >> 1] >> 1 ) | top;
    }
}
// In-place number-theoretic transform of A[0 .. limit) modulo P.
//   A     : coefficient / point-value buffer, overwritten with the result
//   limit : transform length; the global rev[] must have been built by init
//           for this same length
//   tp    : -1 requests the inverse transform, any other value the forward one
inline void NTT ( int * A , int limit , int tp ) {
    // Bit-reversal permutation (each pair is swapped exactly once).
    for ( int i = 0 ; i < limit ; i++ ) {
        const int partner = rev[i];
        if ( partner > i )
            swap ( A[i] , A[partner] );
    }
    // Butterfly passes over blocks of size 2, 4, 8, ...
    for ( int half = 1 ; half < limit ; half <<= 1 ) {
        const int span = half << 1;
        const int root = ksm ( G , ( P - 1 ) / span );
        for ( int base = 0 ; base < limit ; base += span ) {
            int twiddle = 1;
            for ( int k = 0 ; k < half ; k++ ) {
                int & lo = A[base + k];
                int & hi = A[base + k + half];
                const int u = lo % P;
                const int v = twiddle * hi % P;
                lo = ( u + v ) % P;
                hi = ( u - v + P ) % P;
                twiddle = 1ll * twiddle * root % P;
            }
        }
    }
    if ( tp != -1 )
        return;
    // Inverse transform: the forward pass was run with the same roots, so
    // reversing positions 1 .. limit-1 and scaling by 1 / limit undoes it.
    reverse ( A + 1 , A + limit );
    const int scale = ksm ( limit , P - 2 );
    for ( int i = 0 ; i < limit ; i++ )
        A[i] = 1ll * A[i] * scale % P;
}
// Multiplies two polynomials modulo P with a transform of at least `len`
// points. The product is written back into f; g is left in its transformed
// (point-value) form. Both buffers must be zero-padded up to the transform
// length chosen by init.
inline void mul ( int *f , int *g , int len ) {
    init ( len );
    NTT ( f , limit , 1 );
    NTT ( g , limit , 1 );
    // Point-wise product of the two spectra.
    for ( int *pf = f , *pg = g , *stop = f + limit ; pf != stop ; ++pf , ++pg )
        *pf = 1ll * *pf * *pg % P;
    NTT ( f , limit , -1 );
}
// Polynomial inverse: writes into g the first `len` coefficients of 1 / f
// (mod x^len, mod P), doubling the precision recursively with the Newton
// step g <- g * (2 - f * g).
// Uses the global `c` as scratch and changes `limit` / `rev` through init.
// g[1 ..] is read by the transform before this function ever writes it, so it
// has to be zero on entry (getln clears its buffer before calling).
void getinv ( int *f , int *g , int len ) {
    if ( len == 1 ) {
        g[0] = ksm ( f[0] , P - 2 );
        return;
    }
    const int half = ( len + 1 ) >> 1;
    getinv ( f , g , half );
    init ( len << 1 );
    // c = f truncated to `len` terms, zero-padded to the transform length.
    for ( int i = 0 ; i < limit ; i++ )
        c[i] = i < len ? f[i] : 0;
    NTT ( c , limit , 1 );
    NTT ( g , limit , 1 );
    for ( int i = 0 ; i < limit ; i++ ) {
        const int fg = 1ll * g[i] * c[i] % P;
        g[i] = ( 2ll - fg + P ) % P * g[i] % P;
    }
    NTT ( g , limit , -1 );
    // Drop everything at or above x^len so the next level starts clean.
    for ( int i = limit - 1 ; i >= len ; i-- )
        g[i] = 0;
}
// Formal derivative: g[j] = (j + 1) * f[j + 1] mod P for j < len - 1, and the
// top slot g[len - 1] is set to 0.
void getdev ( int *f , int *g , int len ) {
    for ( int j = 0 ; j + 1 < len ; j++ )
        g[j] = 1ll * ( j + 1 ) * f[j + 1] % P;
    g[len - 1] = 0;
}
// Formal antiderivative with zero constant term: g[i] = f[i - 1] / i mod P
// for 1 <= i < len, then g[0] = 0. Each division uses a modular inverse
// obtained from ksm.
void getinvdev ( int *f , int *g , int len ) {
    int i = 1;
    while ( i < len ) {
        const int inv_i = ksm ( i , P - 2 );
        g[i] = 1ll * f[i - 1] * inv_i % P;
        ++i;
    }
    g[0] = 0;
}
void getln ( int *f , int *g , int len ) {
    memset ( a , 0 , sizeof ( a ) );
    memset ( b , 0 , sizeof ( b ) );
    getdev ( f , a , len );
    getinv ( f , b , len );
    mul ( a , b , len << 1 );
    getinvdev ( a , g , len );
    return;
}
// Polynomial exponential: writes into g the first `len` coefficients of
// exp(f) (mod x^len, mod P), doubling the precision recursively with the
// Newton step g <- g * (1 - ln g + f).
//   f   : input polynomial (read only, at least `len` coefficients)
//   g   : output buffer; the base case sets g[0] = 1, so entries g[1 ..] must
//         be zero on entry because the first transform reads them unwritten
//   len : number of coefficients wanted (>= 1)
// Uses the globals `d` and `e` as scratch, plus everything getln uses.
void getexp ( int *f , int *g , int len ) {
    if ( len == 1ll ) {
        g[0] = 1;
        return;
    }
    getexp ( f , g , len + 1 >> 1 );
    init ( len << 1 );
    // Clear the scratch over the WHOLE transform window. The transforms below
    // read d and e up to `limit`, which exceeds 2 * len whenever 2 * len is
    // not a power of two; clearing only 2 * len entries would let residue
    // from an earlier, larger call leak into the result.
    for ( int i = 0 ; i < limit ; i++ )
        d[i] = e[i] = 0;
    // d = ln(g) to `len` terms. getln ends with mul ( ... , len << 1 ), so
    // `limit` and `rev` are again those of init ( len << 1 ) afterwards.
    getln ( g , d , len );
    for ( int i = 0 ; i < len ; i++ )
        e[i] = f[i];
    NTT ( g , limit , 1 );
    NTT ( d , limit , 1 );
    NTT ( e , limit , 1 );
    // In point-value form the constant polynomial 1 is 1 at every point.
    for ( int i = 0 ; i < limit ; i++ )
        g[i] = 1ll * ( 1 - d[i] + e[i] + P ) * g[i] % P;
    NTT( g , limit , -1 );
    // Truncate back to `len` terms for the next level.
    for ( int i = len ; i < limit ; i++ )
        g[i] = 0;
    return;
}
// ans: result buffer; main passes it to getsqrt and then prints it.
// k: declared here but not referenced by any code visible in this file
//    (getksm takes its own `k` parameter, which shadows this one).
int ans[N] , k;
// Polynomial power: writes the first `len` coefficients of f^k into ans,
// computed as exp ( k * ln f ) modulo P.
//   f   : input polynomial (at least `len` coefficients)
//   ans : output buffer; must be zero beyond index 0 on entry (see getexp)
//   k   : exponent; only its residue modulo P matters for the scaling step
//   len : number of coefficients to compute
// The global `g` holds the logarithm while this runs and is cleared on exit.
// getexp always produces constant term 1, so the result is only a true power
// when f[0]^k == 1. NOTE(review): presumably callers guarantee f[0] == 1.
void getksm ( int *f , int *ans , int k , int len ) {
    // Bring the exponent into [0, P) so the products below can neither
    // overflow long long nor yield negative residues.
    const int exponent = ( k % P + P ) % P;
    getln ( f , g , len );
    // Scale exactly the `len` coefficients that getln produced (the bound is
    // the parameter, not the global n).
    for ( int i = 0 ; i < len ; i++ )
        g[i] = 1ll * g[i] * exponent % P;
    getexp ( g , ans , len );
    // Leave the shared scratch clean for the next user.
    for ( int i = 0 ; i < limit ; i++ )
        g[i] = 0;
    return;
}
// Polynomial square root: writes the first `len` coefficients of sqrt ( f )
// into ans, computed as exp ( ln ( f ) / 2 ) modulo P.
//   f     : input polynomial (at least `len` coefficients)
//   ans   : output buffer; must be zero beyond index 0 on entry (see getexp)
//   prime : unused; kept so existing call sites keep compiling
//   len   : number of coefficients to compute
// The global `g` holds the logarithm while this runs and is cleared on exit.
// getexp always produces constant term 1, so the result squares back to f
// only when f[0] == 1.
void getsqrt ( int *f , int *ans , int prime , int len ) {
    const int half = ksm ( 2 , P - 2 );   // modular inverse of 2
    getln ( f , g , len );
    // Halve exactly the `len` coefficients that getln produced (the bound is
    // the parameter, not the global n).
    for ( int i = 0 ; i < len ; i++ )
        g[i] = g[i] * half % P;
    getexp ( g , ans , len );
    // Leave the shared scratch clean for the next user.
    for ( int i = 0 ; i < limit ; i++ )
        g[i] = 0;
    return;
}
signed main ( void ) {
    scanf ( "%lld" , &n );
    for ( int i = 0 ; i < n ; i++ ) 
        scanf ( "%lld" , &f[i] );
    getsqrt ( f , ans , 2 , n );
    for ( int i = 0 ; i < n ; i++ ) 
        printf ( "%lld " , ans[i] );
    return 0;
}
原文地址:https://www.cnblogs.com/TheRoadToTheGold/p/14197615.html