POJ 3735 Training little cats

题目链接: http://poj.org/problem?id=3735

    一开始的想法是建立两个(n+1)*(n+1)的矩阵A,对角线均为1,其他元素初始为0,另一个(n+1)*1的矩阵B,对于元素(b i1) ∈ B,1<= i <=n(矩阵元素坐标在讨论中均以1开始),表示第 i 只猫的花生个数,初始均为0,b(n+1)1=1。

    对于g i 操作,将(a i(n+1)) 加1。((aij) A)

    对于e i 操作,将第 i 行全置为零。

    对于s i j 操作,将第 i 行和第 j 行交换。

    这样,修改后的矩阵A对其求m次幂,再与矩阵B相乘得到一个(n+1)*1的矩阵,第 i 行(1<= i <=n)表示第n只猫最后的花生个数。

    对于题目样列,算式如下,

    仔细思考,上述矩阵可以简化,重新定义(n+1)*(n+1)的矩阵A,对角线均为1,对于(a i1) ∈ A,1<= i <=n表示第n只猫的花生个数,初始均为0

    对于g i 操作,将(a i 1) 加1。

    对于e i 操作,将第 i 行全置为零。

    对于s i j 操作,将第 i 行和第 j 行交换。

    这个方法与前面的是等价的,对A求m次幂,(a i1) ∈ Am,1<= i <=n表示第 i 只猫的花生个数。

    程序实现后一直WA,参考了别人的代码,发现他有对矩阵乘法进行优化。

1     for(int i=0;i<=n;i++){
2         for(int k=0;k<=n;k++){
3             if(a.v[i][k])
4             for(int j=0;j<=n;j++)
5                 if(b.v[k][j])
6                 c.v[i][j]+=a.v[i][k]*b.v[k][j];
7         }
8     }

    n最大可达100,矩阵的每个元素必须用long long存储,运算量还是非常大的。但矩阵是稀疏的,大部分元素是0,修改矩阵乘法后,程序运行时间减少到了60ms。

 1 #include <iostream>
 2 #include <algorithm>
 3 #include <map>
 4 #include <vector>
 5 #include <functional>
 6 #include <string>
 7 #include <cstring>
 8 #include <queue>
 9 #include <set>
10 #include <cmath>
11 #include <cstdio>
12 using namespace std;
13 #define IOS ios_base::sync_with_stdio(false)
14 typedef long long LL;
15 const int INF = 0x3f3f3f3f;
16 
17 const int maxn=105;
18 int n;
19 typedef struct matrix{
20     LL v[maxn][maxn];
21     void init(){memset(v,0,sizeof(v));}
22 }M;
23 M a,res;
24 M mul(const M &a,const M &b)
25 {
26     M c; c.init();
27     //稀疏矩阵
28     for(int i=0;i<=n;i++){
29         for(int k=0;k<=n;k++){
30             if(a.v[i][k])
31             for(int j=0;j<=n;j++)
32                 if(b.v[k][j])
33                 c.v[i][j]+=a.v[i][k]*b.v[k][j];
34         }
35     }
36     return c;
37 }
38 M power(M x,LL p)
39 {
40     M tmp; tmp.init();
41     for(int i=0;i<=n;i++)
42         tmp.v[i][i]=1;
43     while(p){
44         if(p&1) tmp=mul(x,tmp);
45         x=mul(x,x);
46         p>>=1;
47     }
48     return tmp;
49 }
50 
51 int main()
52 {
53     int m,k,p,q;
54     char ch;
55     while(~scanf("%d%d%d",&n,&m,&k)&&n+m+k){
56         a.init();
57         for(int i=0;i<=n;i++) a.v[i][i]=1;
58         while(k--){
59             scanf(" %c%d",&ch,&p);
60             if(ch=='g'){
61                 a.v[p][0]++;
62             }else if(ch=='e'){
63                 for(int i=0;i<=n;i++) a.v[p][i]=0;
64             }else{
65                 scanf("%d",&q);
66                 for(int i=0;i<=n;i++) swap(a.v[p][i],a.v[q][i]);
67             }
68         }
69         res=power(a,m);
70         for(int i=1;i<n;i++) printf("%lld ",res.v[i][0]);
71         printf("%lld
",res.v[n][0]);
72     }
73 }
原文地址:https://www.cnblogs.com/cumulonimbus/p/5698880.html