拉格朗日插值

公式:$f(x)=sum_{i=1}^{n} y_{i} prod_{i eq j} frac{x-x_{j}}{x_{i}-x_{j}}$.    

这个式子正常算的话是 $O(n^2)$ 的,如果遇到 $x$ 是连续的情况可以优化到 $O(n log n)$.   

但是有些时候我们只知道 $f(x)$ 在 $x=k$ 时的点值是不够的,有时必须求出这个多项式每一位系数.    

多项式快速插值可以做到 $O(n log^2 n)$,但是快速插值非常非常难写,用处并不多.     

相比之下,有一种简易的写法可以在 $O(n^2)$ 的时间复杂度内通过 $n$ 个不同的点来还原一个 $n-1$ 次多项式.    

插值公式中 $prod_{j eq i} (x-x_{j})$ 是比较难求的,其他地方由于都是基于整数的运算,所以比较简单.     

先令 $f_{i,j}$ 表示考虑前 $i$ 个点 $(x,y)$,$x^j$ 前的系数.    

那么有转移:$f_{i,j}=f_{i-1,j-1}+f_{i-1,j} imes (-x_{i})$ 即分别表示当前位的贡献为 $x^1 / -x_{i}$.      

求出这个后,我们枚举 $i$,然后想 $O(n)$ 计算 $h(x)=prod_{i eq j} (x-x_{j})$.   

令 $k1[i],k2[i]$ 分别表示 $h(x)$ 的 $x^i$ 前的系数,强制让第 $i$ 位贡献 $x^1$ 时 $x^i$ 前的系数.    

由于有 $k2$ 这个强制贡献的状态,转移就比较简单:

$k1[i] leftarrow k2[i+1]$ 

$f_{n,i}=k1[i] imes (-x_{i}) +k2[i] Rightarrow k2[i]=f_{n,i}+k1[i] imes (x_{i})$.    

算出 $k2$ 后把 $y_{i}$ 及插值公式中分母的贡献乘上然后累加到答案数组中即可.      

应用:

求 $sum_{i=1}^{n} i^k$.    

这是一个关于 $n$ 的 $k+1$ 次多项式.  

所以可以取 $k+2$ 个点带进去,然后用拉格朗日插值法来求值.    

具体,$f(k)=sum_{i=1}^{n} y_{i} prod_{}^{i eq j}frac{k-x_{j}}{x_{i}-x_{j}}$    

由于点可以做到取 $x$ 连续的,所以提前预处理前缀/后缀积极可以做到 $O(n log n)$.  

code: 

#include <cstdio>  
#include <vector>  
#include <cstring>
#include <algorithm>  
#define N 1000009  
#define ll long long 
#define mod 1000000007
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;   
int f[N]; 
int ifac[N],fac[N],pre[N],suf[N],inv[N],n,K;  
int qpow(int x,int y) {  
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod)  
        if(y&1) tmp=(ll)tmp*x%mod;   
    return tmp;    
} 
int INV(int x) { return qpow(x,mod-2); }         
void init() {  
    ifac[0]=fac[0]=inv[1]=1;  
    for(int i=2;i<N;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;     
    inv[0]=1;  
    for(int i=1;i<N;++i) {
        fac[i]=(ll)fac[i-1]*i%mod;
        ifac[i]=(ll)ifac[i-1]*inv[i]%mod;   
    }
    pre[0]=1,suf[n+1]=1;  
    for(int i=1;i<=n;++i)  pre[i]=(ll)pre[i-1]*(K-i+mod)%mod;      
    for(int i=n;i>=1;--i)  suf[i]=(ll)suf[i+1]*(K-i+mod)%mod;             
} 
int sol() { 
    int ans=0; 
    for(int i=1;i<=n;++i) {  
        int a1=(ll)ifac[i-1]*ifac[n-i]%mod; 
        if((n-i)&1) a1=(ll)a1*(mod-1)%mod;    
        int a2=(ll)pre[i-1]*suf[i+1]%mod;   
        (ans+=(ll)f[i]*a1%mod*a2%mod)%=mod; 
    }  
    return ans;  
}
int main() {          
    // setIO("input");   
    scanf("%d%d",&K,&n),n+=2;                           
    init();      
    for(int i=1;i<=n;++i) { 
        f[i]=(ll)(f[i-1]+qpow(i,n-2))%mod;        
    }    
    printf("%d
",sol());  
    return 0; 
}

  

还原多项式系数

#include <cstdio>  
#include <cstring>
#include <algorithm>  
#define N 2008   
#define ll long long
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;     
int f[N][N],k1[N],k2[N],s[N],n;    
struct point {  
    int x,y;  
    point(int x=0,int y=0):x(x),y(y){}  
}a[N];  
int ADD(int x,int y) { 
    return (ll)(x+y)%mod; 
} 
int DEC(int x,int y) { 
    return (ll)(x-y+mod)%mod; 
} 
int MUL(int x,int y) { 
    return (ll)x*y%mod;   
}
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=MUL(x,x)) 
        if(y&1) { 
            tmp=MUL(tmp,x);  
        } 
    return tmp;   
}
int get_inv(int x) { return qpow(x,mod-2); }  
void init() {  
    f[0][0]=1;   
    for(int i=1;i<=n;++i) {  
        for(int j=1;j<=i;++j) {   
            f[i][j]=ADD(f[i-1][j-1],MUL(mod-a[i].x,f[i-1][j]));        
        }
        f[i][0]=MUL(f[i-1][0],mod-a[i].x);               
    }
}           
int main() {  
    // setIO("input");
    int X;   
    scanf("%d%d",&n,&X);  
    for(int i=1;i<=n;++i) {
        scanf("%d%d",&a[i].x,&a[i].y);  
    }
    init();               
    for(int i=1;i<=n;++i) {   
        for(int j=0;j<=n;++j) k2[j]=f[n][j];   
        for(int j=n-1;j>=0;--j) { 
            k1[j]=k2[j+1];   
            k2[j]=ADD(k2[j],MUL(k1[j],a[i].x));           
        }           
        int inv=1;  
        for(int j=1;j<=n;++j) 
            if(i!=j) {
                inv=(ll)inv*(a[i].x-a[j].x+mod)%mod;   
            }
        inv=get_inv(inv);   
        for(int j=0;j<=n-1;++j) { 
            (s[j]+=(ll)inv*a[i].y%mod*k1[j]%mod)%=mod;   
        }
    }
    int ans=0;      
    for(int i=n-1;i>=0;--i) { 
        ans=(ll)((ll)ans*X%mod+s[i])%mod;    
    }  
    printf("%d
",ans); 
    return 0;
}

  

例题

CF917D Stranger Trees

给你一颗树,求 $n$ 个点有多少个生成树满足该生成树与给定树有 $k$ 条边是重合的.    

题解:

先对完全图构建矩阵,然后将原树上的边 $(x,y)$ 在矩阵中的边权标记成 $x^1$,其余边权为 $1$.  

矩阵树定理求的是所有生成树边权乘积之和,那么要是可以对含 $x$ 的矩阵求行列式的话可以直接得出答案.   

但是复杂度太高,而且难写(写不了)    

所以用 $n$ 个不同的整数来替换那个 $x^1$,然后跑出来 $n$ 个结果,用拉格朗日插值还原出多项式的系数即可.    

#include <cstdio> 
#include <vector>
#include <cstring>
#include <algorithm>      
#define N 103
#define ll long long
#define mod 1000000007
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;     
int n;   
int A[N],B[N]; 
int f[N][N],k1[N],k2[N],ans[N];
int deg[N][N],con[N][N],a[N][N];  
struct point {
    int x,y; 
    point(int x=0,int y=0):x(x),y(y){} 
}p[N]; 
int qpow(int x,int y) {
    int tmp=1;
    for(;y;y>>=1,x=(ll)x*x%mod)
        if(y&1) {  
            tmp=(ll)tmp*x%mod; 
        }
    return tmp;
}
int get_inv(int x) {
    return qpow(x,mod-2); 
}    
int ADD(int x,int y) {
    return (ll)(x+y)%mod;
}
int DEC(int x,int y) {
    return (ll)(x-y+mod)%mod; 
} 
int MUL(int x,int y) {
    return (ll)x*y%mod;  
}
int gauss() { 
    int ans=1; 
    for(int i=1;i<n;++i) {
        for(int j=i+1;j<n;++j) {   
            while(a[j][i]) {
                int t=a[i][i]/a[j][i]; 
                for(int k=i;k<n;++k) {
                    a[i][k]=DEC(a[i][k],MUL(t,a[j][k]));         
                }
                swap(a[j],a[i]);  
                ans=(ll)ans*(mod-1)%mod;   
            }
        }
        if(!a[i][i]) {
            return 0; 
        }
    }
    for(int i=1;i<n;++i) {
        ans=(ll)ans*a[i][i]%mod;   
    }
    return ans;  
}        
int cal(int val) { 
    for(int i=1;i<=n;++i) { 
        for(int j=1;j<=n;++j) {
            a[i][j]=mod-1;
        }
    }  
    for(int i=1;i<=n;++i) {
        a[i][i]=n-1;   
    }
    for(int i=1;i<n;++i) { 
        int x=A[i],y=B[i];                       
        a[x][x]=(ll)(DEC(a[x][x],1)+val)%mod;   
        a[y][y]=(ll)(DEC(a[y][y],1)+val)%mod;   
        a[x][y]=(ll)(a[x][y]+1-val+mod)%mod; 
        a[y][x]=(ll)(a[y][x]+1-val+mod)%mod;
    }
    return gauss();  
}
void init() { 
    f[0][0]=1;  
    for(int i=1;i<=n;++i) {
        for(int j=1;j<=i;++j)
            f[i][j]=ADD(f[i-1][j-1],MUL(f[i-1][j],mod-p[i].x));    
        f[i][0]=(ll)f[i-1][0]*(mod-p[i].x)%mod;  
    }
}
int main() { 
    // setIO("input");  
    scanf("%d",&n); 
    int x,y,z;
    for(int i=1;i<n;++i) {           
        scanf("%d%d",&A[i],&B[i]); 
    }
    for(int i=1;i<=n;++i) { 
        p[i].x=i;  
        p[i].y=cal(i); 
    }    
    init(); 
    for(int i=1;i<=n;++i) { 
        int inv=1; 
        for(int j=1;j<=n;++j) {    
            if(i!=j) inv=(ll)inv*(p[i].x-p[j].x+mod)%mod;  
        }
        inv=get_inv(inv);      
        for(int j=0;j<=n;++j) {
            k2[j]=f[n][j];
        }
        for(int j=n-1;j>=0;--j) {
            k1[j]=k2[j+1];    
            k2[j]=ADD(k2[j],MUL(p[i].x,k1[j]));  
        }                
        for(int j=0;j<=n-1;++j) {
            ans[j]=ADD(ans[j],(ll)k1[j]*inv%mod*p[i].y%mod);  
        }
    }    
    for(int i=0;i<n;++i) {
        printf("%d ",ans[i]);  
    }
    return 0; 
}

  

 LuoguP4463 [集训队互测2012] calc

朴素的 DP 非常好列:$f[i][j]$ 表示选了 $i$ 个数,且值域为 $[1,j]$ 的总价值和.    

那么有 $f[i][j]=f[i-1][j-1] imes j+f[i][j-1]$,直接算的话复杂度是 $O(nD)$ 的.   

但是我们可以猜测这是一个关于 $j$ 的 $g_{i}$ 次多项式.    

有一个结论:对于 $n$ 次多项式 $h(x)$,满足 $h(x)-h(x-1)$ 是 $n-1$ 次多项式.   

那么有 $f[i][j]-f[i][j-1]=f[i-1][j-1] imes j$.    

将 $g$ 带入,有 $g_{i}-1=g_{i-1}+1$.    

即 $g_{i}=g_{i-1}+2$,说明这是一个关于 $j$ 的 $2 imes i$ 次多项式.    

那么我们就求出 $f[n][1...2n+1]$ 后将值带入,然后拉格朗日插值来插一下就行了.   

code: 

#include <cstdio>  
#include <cstring>
#include <algorithm> 
#define N 2002
#define ll long long 
#define setIO(s) freopen(s".in","r",stdin)
using namespace std; 
int D,n,mod,tot,f[N][N],fac[N];  
void init() {
    fac[0]=1;  
    for(int i=1;i<N;++i) {
        fac[i]=(ll)fac[i-1]*i%mod;   
    }
}
struct point {
    int x,y;  
    point(int x=0,int y=0):x(x),y(y){}  
}a[N];  
int qpow(int x,int y) {
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod)  
        if(y&1) tmp=(ll)tmp*x%mod; 
    return tmp; 
}  
int get_inv(int x) {
    return qpow(x,mod-2);   
}
int calc() {
    int ans=0;   
    for(int i=1;i<=tot;++i) {  
        int inv=1,up=1;    
        for(int j=1;j<=tot;++j) {
            if(i==j) continue;     
            up=(ll)up*(D-a[j].x+mod)%mod;    
            inv=(ll)inv*(a[i].x-a[j].x+mod)%mod;   
        }
        inv=get_inv(inv);   
        (ans+=(ll)a[i].y*up%mod*inv%mod)%=mod;  
    }
    return ans;   
}
int main() {          
    // setIO("input");    
    scanf("%d%d%d",&D,&n,&mod);  
    init(); 
    for(int i=0;i<=2*n+1;++i) f[0][i]=1;   
    for(int i=1;i<=n;++i) {
        for(int j=1;j<=2*n+1;++j) {
            f[i][j]=(ll)(f[i][j-1]+(ll)f[i-1][j-1]*j%mod)%mod;  
        }
    } 
    for(int i=1;i<=2*n+1;++i) {
        a[++tot]=point(i,f[n][i]);    
    }
    printf("%d
",(ll)calc()*fac[n]%mod);   
    return 0;   
}

  

原文地址:https://www.cnblogs.com/guangheli/p/13329921.html