[数论]中国剩余定理小结

中国剩余定理(CRT)可以说是必学的一个东西啦,主要是用来求线性同余方程组的算法。

x≡a1(mod m1)  x≡a2(mod m2) ... x≡an(mod mn)    在这n个同余方程下解出x

普通的CRT只能求mi两两互质的情况,如若mi不两两互质就得用到扩展CRT。(其实感觉CRT跟扩展CRT原理差别挺大的)

CRT:https://www.cnblogs.com/zwfymqz/p/8425019.html

扩展CRT:https://www.cnblogs.com/zwfymqz/p/8425019.html

然后直接上模板:

中国剩余定理

#include<iostream>
#include<cstdio>
using namespace std;
long long n,m[100],a[100],d;

long long exgcd(long long a,long long b,long long &x,long long &y) {
    if (b==0) { x=1; y=0; return a; }
        else {
            long long tmp=exgcd(b,a%b,y,x);
            y-=x*(a/b); return tmp;
        }
}

long long work() {
    long long res=0,x,y;
    long long lcm=1;
    for (int i=1;i<=n;i++) lcm=lcm*m[i];
    
    for (int i=1;i<=n;i++) {
        long long M=lcm/m[i];
        exgcd(M,m[i],x,y);
        x=(x%m[i]+m[i])%m[i];
        res=(res+(long long)(a[i]*x*M))%lcm;
    }
    return res;
}

int main()
{
    scanf("%d",&n);
    for (int i=1;i<=n;i++) {
        scanf("%lld%lld",&m[i],&a[i]);
    }
    cout<<work();
    return 0;
} 
View Code

扩展中国剩余定理

#include<iostream>
#include<cstdio>
using namespace std;
const int N=1000000+10;
typedef long long LL;
int n;
LL m[N],a[N];

LL exgcd(LL a,LL b,LL &x,LL &y) {
    if (b==0) { x=1; y=0; return a; }
        else {
            LL tmp=exgcd(b,a%b,y,x);
            y-=x*(a/b); return tmp;
        }
}

long long work() {
    LL lcm=m[1],X=a[1],t,x,y;
    for (int i=2;i<=n;i++) {
        LL b=(a[i]-X%m[i]+m[i])%m[i];
        LL d=exgcd(lcm,m[i],x,y);   //解这个方程出来t的特解x t=(b/d)*x 
        if (b%d) return -1;
        t=(b/d)*x%m[i];
        X=(X+t*lcm);   //那么X(k)=X(k-1)+tm 
        lcm=lcm*m[i]/d;
        
        X=(X%lcm+lcm)%lcm;
    }
    return X;
}

int main()
{
    scanf("%d",&n);
    for (int i=1;i<=n;i++) scanf("%lld%lld",&m[i],&a[i]);
    cout<<work()<<endl;
    return 0;
} 
View Code

中国剩余定理的题目当然就是解同余方程组啦,这就要求你能看出来是同余方程组的模型。当然CRT可能不会单独考会结合其他知识点一起考,毕竟只有CRT也是很干瘪的啦。

题目练习:

POJ-1066

CRT裸题,容易看出是n个同余方程。

#include<iostream>
#include<cstdio>
using namespace std;
const long long m[4]={23,28,33};
int n,a[4],d;

int exgcd(int a,int b,int &x,int &y) {
    if (b==0) { x=1; y=0; return a; }
        else {
            int tmp=exgcd(b,a%b,y,x);
            y-=x*(a/b); return tmp;
        }
}

long long work() {
    long long res=0; int x,y;
    long long lcm=m[0]*m[1]*m[2];
    for (int i=0;i<3;i++) {
        int M=lcm/m[i];
        exgcd(M,m[i],x,y);
        x=(x%m[i]+m[i])%m[i];
        res=(res+(long long)(a[i]*x*M))%lcm;
    }
    return res;
}

int main()
{
    int T=0;
    while (scanf("%d%d%d%d",&a[0],&a[1],&a[2],&d)==4 && d!=-1) {
        long long tmp=work();
        tmp=(tmp-d+21252)%21252;
        if (tmp==0) tmp=21252;
        printf("Case %d: the next triple peak occurs in %d days.
",++T,tmp);
    }
    return 0;
} 
View Code

洛谷P4777

扩展CRT测模板裸题,但是注意这题乘法有可能会超出long long。要不用__int128,要不用快速乘。

#include<bits/stdc++.h>
#define LL __int128
using namespace std;
const int N=1e6+10;
int n;
LL m[N],a[N];

LL exgcd(LL a,LL b,LL &x,LL &y) {
    if (b==0) { x=1; y=0; return a; }
        else {
            LL tmp=exgcd(b,a%b,y,x);
            y-=x*(a/b); return tmp;
        }
}

long long work() {
    LL lcm=m[1],X=a[1],t,x,y;
    for (int i=2;i<=n;i++) {
        LL b=(a[i]-X%m[i]+m[i])%m[i];
        LL d=exgcd(lcm,m[i],x,y);   //解这个方程出来t的特解x t=(b/d)*x 
        if (b%d) return -1;
        t=(b/d)*x%m[i];
        X=(X+t*lcm);   //那么X(k)=X(k-1)+tm 
        lcm=lcm*m[i]/d;
        
        X=(X%lcm+lcm)%lcm;
    }
    return X;
}

int main()
{
    scanf("%d",&n);
    for (int i=1;i<=n;i++) scanf("%lld%lld",&m[i],&a[i]);
    cout<<work()<<endl;
    return 0;
} 
View Code

HDU-1573

这题主要是要想明白,其实线性同余方程组是有无数个解的,然后CRT求出来的就是最小解。

这里直接给出结论,CRT和拓展CRT解出来的方程通解都是 x+k*lcm(k€Z)。

所以ans=(N-x)/lcm+1。这里有一个坑点就是因为答案算的是正整数,所以最小解等于0的话就ans-1.

#include<bits/stdc++.h>
using namespace std;
const int N=20+10;
typedef long long LL;
LL n,l,m[N],a[N];

LL exgcd(LL a,LL b,LL &x,LL &y) {
    if (b==0) { x=1; y=0; return a; }
        else {
            LL tmp=exgcd(b,a%b,y,x);
            y-=x*(a/b); return tmp;
        }
}

void exCRT() {
    LL lcm=m[1],X=a[1],t=0,x=0,y=0;
    for (int i=2;i<=n;i++) {
        LL b=(a[i]-X%m[i]+m[i])%m[i];
        LL d=exgcd(lcm,m[i],x,y);   //解这个方程出来t的特解x t=(b/d)*x 
        if (b%d) { puts("0"); return; }  //无解 
        t=(b/d)*x%m[i];
        X=(X+t*lcm);   //那么X(k)=X(k-1)+tm 
        lcm=lcm*m[i]/d;
        
        X=(X%lcm+lcm)%lcm;
    }
    if (X>l) { puts("0"); return; }
    int ans=(l-X)/lcm+1;
    if (X==0) ans--;
    printf("%lld
",ans);
}

int main()
{
    int T; cin>>T;
    while (T--) {
        memset(m,0,sizeof(m));
        memset(a,0,sizeof(a));
        scanf("%lld%lld",&l,&n);
        for (int i=1;i<=n;i++) scanf("%lld",&m[i]);
        for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
        exCRT();
    }
    return 0;
} 
View Code

HDU-1951

这道题是真的好,把许多知识点结合起来考了,一定要做一做想明白。

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int MOD=999911659;
typedef long long LL;
int n,q,cnt=0;
LL a[4],m[4],ys[100000],jc[100000],inv[100000];

void get_ys() {
    cnt=0;
    for (int i=1;i*i<=n;i++)
        if (n%i==0) {
            ys[++cnt]=i;
            if (i*i!=n) ys[++cnt]=n/i;
        }
    sort(ys+1,ys+cnt+1);    
}

LL power(LL x,LL p,LL Mod) {
    LL res=1;
    while (p) {
        if (p&1) res=(res*x)%Mod;
        p>>=1;
        x=(x*x)%Mod; 
    }
    return res;
}

LL exgcd(LL a,LL b,LL &x,LL &y) {
    if (b==0) { x=1; y=0; return a; }
        else {
            LL tmp=exgcd(b,a%b,y,x);
            y-=x*(a/b); return tmp;
        }
}

LL C(LL a,LL b,LL Mod) {
    if(a<b) return 0;
    if(a==b || !b) return 1;
    return (jc[a]*inv[b]*inv[a-b])%Mod;
} 

LL Lucas(LL a,LL b,LL Mod) {
    if (!a || !b) return 1;
    return C(a%Mod,b%Mod,Mod)*Lucas(a/Mod,b/Mod,Mod);
}

LL slove(LL Mod) {
    LL t=1;
    for (int i=1;i<=Mod;i++) {
        t=(t*i)%Mod;
        jc[i]=t;
        inv[i]=power(t,Mod-2,Mod);
    }
    LL res=0;
    for (int i=1;i<=cnt;i++) {
        res=(res+Lucas(n,ys[i],Mod))%Mod;
    }
    return res;
}

int main()
{
    scanf("%d%d",&n,&q);
    if (q%MOD==0) { cout<<0; return 0; }
    get_ys();
    a[0]=slove(2); m[0]=2;
    a[1]=slove(3); m[1]=3;
    a[2]=slove(4679); m[2]=4679;
    a[3]=slove(35617); m[3]=35617;
    
    LL ans=0,lcm=1,x,y,M;
    for (int i=0;i<=3;i++) lcm=lcm*m[i];
    for (int i=0;i<=3;i++) {
        M=lcm/m[i];
        exgcd(M,m[i],x,y);
        x=(x%m[i]+m[i])%m[i];
        ans=(ans+(LL)(a[i]*x%lcm*M))%lcm;
    }
    cout<<power(q,ans,MOD)<<endl; 
    return 0;
} 
View Code
原文地址:https://www.cnblogs.com/clno1/p/10921845.html