// mm: blocked dense matrix-multiplication benchmark.
//
// Computes P = M1 * M2, where M1 is r1 x c1 and M2 is c1 x c2, by iterating
// over Blocksize-sized tiles, then prints the elapsed wall-clock time of the
// multiply (and optionally all three matrices).
//
// Usage: mm r1 c1/r2 c2 blocksize seed(-1=time(0)) print(y|n)

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <string>
#include <vector>
#include <sys/time.h>
using namespace std;

// Generic iteration helpers; nothing in this file uses them.
#define VIT(i, v) for (i = 0; i < v.size(); i++)
#define IT(it, ds) for (it = ds.begin(); it != ds.end(); it++)
#define FUP(i, n) for (i = 0; i < n; i++)

typedef vector <int> IVec;
typedef vector <double> DVec;
typedef vector <string> SVec;

/* Prints the command-line synopsis, plus `s` on its own line when non-empty,
   to stderr and exits with status 1. */
void usage(const string &s)
{
  fprintf(stderr, "usage: mm r1 c1/r2 c2 blocksize seed(-1=time(0)) print(y|n)\n");
  if (s != "") fprintf(stderr, "%s\n", s.c_str());
  exit(1);
}

/* Operands, product and tiling parameters.  Storage layout:
     M1: r1 x c1, row-major      -- element (i,k) is M1[i*c1 + k]
     M2: c1 x c2, column-major   -- element (k,j) is M2[j*c1 + k], so the
                                    inner product walks both operands with
                                    stride 1
     P:  r1 x c2, row-major      -- element (i,j) is P[i*c2 + j]
   Multiply() accumulates into P, so P must be zero-filled beforehand. */
class MM {
  public:
    vector <double> M1;
    vector <double> M2;
    vector <double> P;
    int r1, c1, c2;
    int Print;         // non-zero: PrintAll() is called after the multiply
    int Blocksize;     // tile edge length; must be > 0

    MM() : r1(0), c1(0), c2(0), Print(0), Blocksize(1) {}
    void MultiplyBlock(int row1, int col1, int col2);
    void Multiply();
    void PrintAll();
};

/* Exclusive end of the tile that begins at `start`: start + bs clamped to
   `limit`.  Written as start + min(bs, limit - start) so the intermediate sum
   never exceeds `limit` and cannot overflow int.  If start >= limit the
   result is <= start, so a loop from start to it runs zero times. */
static int block_end(int start, int limit, int bs)
{
  return start + min(bs, limit - start);
}

/* Adds one tile's contribution to P: rows [row1, row1+Blocksize) of P,
   columns [col2, col2+Blocksize) of P, and the shared-dimension slice
   [col1, col1+Blocksize), each clamped to the matrix bounds. */
void MM::MultiplyBlock(int row1, int col1, int col2)
{
  int pr, pc, tmp;
  const int row_end  = block_end(row1, r1, Blocksize);
  const int col2_end = block_end(col2, c2, Blocksize);
  const int col1_end = block_end(col1, c1, Blocksize);

  for (pr = row1; pr < row_end; pr++) {
    /* size_t offsets: r1*c1 and r1*c2 can exceed INT_MAX for large inputs. */
    const size_t m1_row = static_cast<size_t>(pr) * c1;
    const size_t p_row  = static_cast<size_t>(pr) * c2;
    for (pc = col2; pc < col2_end; pc++) {
      const size_t m2_col = static_cast<size_t>(pc) * c1;
      /* Accumulate in a local instead of re-reading/writing P[] every step;
         the additions occur in the same order, so the result is unchanged. */
      double sum = P[p_row + pc];
      for (tmp = col1; tmp < col1_end; tmp++) {
        sum += M1[m1_row + tmp] * M2[m2_col + tmp];
      }
      P[p_row + pc] = sum;
    }
  }
}

/* Computes P += M1 * M2 tile by tile.  Requires Blocksize > 0.  Each loop
   advances to the end of the current (clamped) tile, so the final partial
   tile is handled and the index never overflows. */
void MM::Multiply()
{
  int row1, col1, col2;

  for (row1 = 0; row1 < r1; row1 = block_end(row1, r1, Blocksize)) {
    for (col2 = 0; col2 < c2; col2 = block_end(col2, c2, Blocksize)) {
      for (col1 = 0; col1 < c1; col1 = block_end(col1, c1, Blocksize)) {
        MultiplyBlock(row1, col1, col2);
      }
    }
  }
}

/* Prints M1, M2 and P to stdout in logical row/column order (M2 is printed
   as c1 x c2 even though it is stored column-major). */
void MM::PrintAll()
{
  int i, j;

  printf("M1: %d x %d\n\n", r1, c1);
  for (i = 0; i < r1; i++) {
    for (j = 0; j < c1; j++) printf(" %6.4lf", M1[static_cast<size_t>(i) * c1 + j]);
    printf("\n");
  }
  printf("\n");

  printf("M2: %d x %d\n\n", c1, c2);
  for (i = 0; i < c1; i++) {
    for (j = 0; j < c2; j++) printf(" %6.4lf", M2[static_cast<size_t>(j) * c1 + i]);
    printf("\n");
  }
  printf("\n");

  printf("P: %d x %d\n\n", r1, c2);
  for (i = 0; i < r1; i++) {
    for (j = 0; j < c2; j++) printf(" %6.4lf", P[static_cast<size_t>(i) * c2 + j]);
    printf("\n");
  }
}

/* Parses a strictly positive int from `arg`, or calls usage(err).
   The return value is compared against 1 rather than 0: for an empty
   argument sscanf returns EOF, which a `== 0` test would let through. */
static int parse_positive(const char *arg, const char *err)
{
  int v = 0;
  if (sscanf(arg, "%d", &v) != 1 || v <= 0) usage(err);
  return v;
}

/* Wall-clock time in seconds (gettimeofday, microsecond resolution). */
static double now_seconds()
{
  struct timeval tv;
  gettimeofday(&tv, NULL);
  return tv.tv_sec + tv.tv_usec / 1000000.0;
}

int main(int argc, char **argv)
{
  MM M;
  int r1, c1, c2, bs;
  size_t i, j;
  string s;
  long seed;
  double t0, t1;

  if (argc != 7) usage("");

  r1 = parse_positive(argv[1], "Bad r1");
  c1 = parse_positive(argv[2], "Bad c1/r2");
  c2 = parse_positive(argv[3], "Bad c2");
  bs = parse_positive(argv[4], "Bad bs");

  M.r1 = r1;
  M.c1 = c1;
  M.c2 = c2;
  M.Blocksize = bs;

  /* An unparsable seed deliberately falls back to 0; -1 means "use time(0)". */
  seed = 0;
  sscanf(argv[5], "%ld", &seed);
  if (seed == -1) seed = time(0);
  srand48(seed);

  s = argv[6];
  if (s == "y") {
    M.Print = 1;
  } else if (s == "n") {
    M.Print = 0;
  } else usage("Bad print");

  /* Element counts in size_t: r1*c1 etc. can overflow int. */
  const size_t n_m1 = static_cast<size_t>(r1) * c1;
  const size_t n_m2 = static_cast<size_t>(c1) * c2;
  const size_t n_p  = static_cast<size_t>(r1) * c2;

  M.M1.resize(n_m1);
  M.M2.resize(n_m2);
  M.P.resize(n_p, 0);

  /* Fill order (M1 linearly, then M2 by logical row) fixes which random
     values land where for a given seed. */
  for (i = 0; i < n_m1; i++) M.M1[i] = drand48()*2.0;
  for (i = 0; i < static_cast<size_t>(c1); i++) {
    for (j = 0; j < static_cast<size_t>(c2); j++) M.M2[j*c1+i] = drand48()*2.0;
  }

  t0 = now_seconds();
  M.Multiply();
  t1 = now_seconds();

  if (M.Print) M.PrintAll();
  printf("Time: %.4lf\n", t1-t0);
  return 0;
}