愤怒的小 N 题解报告

阴间题

第二道联赛引入多项式的题……

首先看到题目让求得序列是非常有规律的,考虑一个 dp

(dp_{i,j,0/1}) 表示前 (2^i) 关的 (x)(j) 次方之和,01 表示是否是奖励关

然后容易得到方程式

[dp[i][j][ai]=dp[i-1][j][ai] + sum_{u=0}^jdbinom{j}{u} imes 2^{u(i-1)} imes dp[i-1][j][ai operatorname{xor} 1] ]

然后打个表可以发现有如下性质

(i>j,dp[i][j][0]=dp[i][j][1])

然后考虑换种算法,就是 (dp[i][j] = sum_{u=0}^{2^i-1}u^j)

然后我们知道自然数的 k 次幂是一个 k+1 次多项式,所以可以插值 (O(k)) 求出每一个 (dp[i][j]),但这毫无意义。

考虑一个多项式函数的前缀和也是一个多项式函数,所以我们可以直接考虑对前缀进行一个插值。

(operatorname{Largrange}) 插值多带几个点进去也无所谓,反正不会算错。

所以 (O(k)) 预处理 (dp[i]) 表示前i位的多项式和,预处理七八百个就差不多了。

然后对于不满足上面那个多项式的直接 (dp) 暴力计算即可。

复杂度 (O(n+k^3))

我也不知道为什么我常数大的离谱,以至于必须开 O2 才能过。

#include<bits/stdc++.h>

using namespace std;

// #define INF 1<<30
#define int long long

template<typename _T>
inline void read(_T &x)
{
    x= 0 ;int f =1;char s=getchar();
    while(s<'0' ||s>'9') {f =1;if(s == '-')f =- 1;s=getchar();}
    while('0'<=s&&s<='9'){x = (x<<3) + (x<<1) + s - '0';s= getchar();}
    x*=f;
}

const int mod = 1e9 + 7;
const int np = 5e5 + 5;
int dp[505][505][2];
// dp[i][j][0/1] 表示 0 到 2^(i-1)-1 中 a/b 关卡的 j 次方和
// dp[i][j][0] = dp[i-1][j][0] + sum_{u=0}^jdbinom{j}{u}A^udp[i-1][j-u][1];

char s[np];
int k,tmp;
int a[np];
// int fac[np],inv[np];
int x[np],y[np];
int _2[np];//,qla[np];
int c[705][705];
inline int power(int a,int b)
{
    int res(1);
    while(b)
    {
        if(b & 1) res = res * a,res %= mod;
        a = a * a;
        a %= mod;
        b >>= 1;
    }
    return res;
}

// inline int c(int n,int m)
// {
//     return (((fac[n] * inv[n-m]) % mod) * inv[m])%mod;
// }

inline int largrange(int kth)
{
    int ans = 0;
    for(int i=1;i <= 505;i ++)
    {
        int ell_up = y[i];
        int ell_down = 1;
        for(int j=1;j <= 505;j ++)
        {
            if(i == j) continue;
            ell_up = ell_up *(kth - x[j]),ell_up %= mod;
            ell_down = ell_down * (x[i]-x[j]) % mod;
        }
        ans = ans + ell_up * power(ell_down,mod-2);
        ans %= mod;
    }
    return ans;
}

inline void solve()
{
    int presum = 0;
    for(int i=0;i <= 700;i ++)    
    {
        x[i] = i,y[i] = presum;
        // printf("%lld %lld
",x[i],y[i]);
        int sum = 0;
        for(int j=k-1;j >= 0;j--) sum = sum *i+a[j],sum %= mod;
        presum += sum;
        presum %= mod;
    }
}

signed main()
{
    c[0][0] = 1;
    for(int i=1;i <= 505;i ++)
    {
        c[i][0]=1;
        for(int j=1;j <= 505;j ++)
        {
            int &d = c[i][j];
            d = c[i-1][j] + c[i-1][j-1];
            d -= d>=mod?mod:0;
        }
    }
    scanf("%s",s);
    int len = strlen(s); 
    reverse(s,s+len);
    _2[0] = 1;   
    for(int i=1;i <= 5e5;i ++) _2[i] = _2[i-1] * 2,_2[i] -= _2[i]>=mod?mod:0;//mod;
    // for(int i=1;i <= len;i ++) n[i] = s[i]-'0';
    // read(n);
    read(k);
    for(int i=0;i < k;i ++) read(a[i]);
    dp[0][0][0] = 1;
    // for(int i=0;i < k; i++) dp[0][i][0] = 0;

    for(int i=1;i <=k ;i ++)
    {
        for(int j=0;j < k;j ++)
        {
            // for(int ai=0;ai <= 1;ai ++)
            // {
                int ai=0;
                int &d = dp[i][j][ai];
                d = dp[i-1][j][ai];// + 
                for(int u=0;u <= j;u ++)
                {
                    d += (((c[j][u] * _2[(i-1)*u])%mod) * dp[i-1][j-u][ai^1])%mod;
                    d -= d>=mod?mod:0;
                }
                //  printf("dp[%lld][%lld][%lld] = %lld
",i,j,ai,dp[i][j][ai]);
            // }
                ai= 1;
                int &d_ = dp[i][j][ai];
                d_ = dp[i-1][j][ai];// + 
                for(int u=0;u <= j;u ++)
                {
                    d_ += (((c[j][u] * _2[(i-1)*u])%mod) * dp[i-1][j-u][ai^1])%mod;
                    d_ -= d_>=mod?mod:0;
                }
        }
    }

    solve();

    int Ans=0;
    int la = 0;
//    printf("%lld %lld
",dp[2][3][0],dp[2][3][1]);
    int a_= 1;
    for(int q = len-1;q >=0 ;q --)
    {
        if(s[q] == '1')
        {	
            if(q > k){
                la += _2[q];
                la -= la >= mod?mod:0;
                a_^= 1;
                continue;
            }
            for(int i=0;i <= k; i++)
            {
                int tmp=0;//[i] = 0;
                int qla=1;
                for(int u=0;u <= i;u ++,qla *= la,qla%=mod)
                {
                	int op = a_==1?dp[q][i-u][1]:dp[q][i-u][0];
                    tmp += (((c[i][u] * qla)%mod) * op)%mod;//power(la,u) * 
                	tmp -= tmp >= mod?mod:0;
				}
                Ans += (a[i]*tmp)%mod;
                Ans -= Ans >= mod?mod:0;
            }
            la += _2[q];            
            la -= la >= mod?mod:0;
            a_ ^= 1;
        }
    }

    // printf("%lld
",Ans);
    int qsum=0;
    for(int i=len-1;i >k;i--) if(s[i]=='1') qsum += _2[i],qsum -= qsum >= mod?mod:0;
    Ans += (largrange(qsum)*power(2,mod-2))%mod;
    Ans %= mod;
    printf("%lld",Ans);
}
原文地址:https://www.cnblogs.com/-Iris-/p/15340248.html