/*
 * mm: times a naive dense matrix multiply.
 *
 *   usage: mm r1 c1/r2 c2 seed(-1=time(0)) print(y|n)
 *
 * Builds an r1 x c1 matrix M1 and a c1 x c2 matrix M2 whose entries are
 * drand48() * 2.0 (seeded with `seed`, or time(0) when seed is -1), times
 * the computation of P = M1 * M2 with gettimeofday(), optionally prints all
 * three matrices, and finally prints the elapsed wall-clock seconds.
 *
 * NOTE(review): the original #include names and template arguments were lost
 * (angle-bracket text was stripped). The headers below are the ones this file
 * needs; element types are reconstructed from usage (printing with %lf and
 * filling from drand48() imply double). IVec/SVec element types are a
 * presumption from their names -- confirm against any code that used them.
 */
#include <cstdio>
#include <cstdlib>     // exit
#include <stdlib.h>    // srand48 / drand48 (POSIX)
#include <ctime>       // time
#include <string>
#include <vector>
#include <sys/time.h>  // gettimeofday

// Loop helper macros kept for compatibility; nothing in this file uses them.
#define VIT(i, v) for (i = 0; i < v.size(); i++)
#define IT(it, ds) for (it = ds.begin(); it != ds.end(); it++)
#define FUP(i, n) for (i = 0; i < n; i++)

// Shorthand container types (element types reconstructed; see note above).
typedef std::vector<int> IVec;
typedef std::vector<double> DVec;
typedef std::vector<std::string> SVec;

// Prints the usage line, plus `s` on its own line when non-empty, to stderr
// and terminates the process with exit status 1.
void usage(const std::string &s)
{
  fprintf(stderr, "usage: mm r1 c1/r2 c2 seed(-1=time(0)) print(y|n)\n");
  if (s != "") fprintf(stderr, "%s\n", s.c_str());
  exit(1);
}

// Holds the two operands and the product for one timed multiply.
class MM {
  public:
    MM() : Print(0) {}
    std::vector<DVec> M1;   // r1 x c1 left operand
    std::vector<DVec> M2;   // c1 x c2 right operand
    std::vector<DVec> P;    // r1 x c2 product; Multiply() adds into it, so zero-fill first
    int Print;              // nonzero: main() calls PrintAll() after timing
    void Multiply();
    void PrintAll();
};

// Accumulates M1 * M2 into P using the straightforward i-j-k loop order.
// P must already be sized (rows of M1) x (columns of M2) and zero-filled;
// the kernel adds into P rather than overwriting it so the timed loop is the
// same one the benchmark has always measured. Assumes every row of M1 has at
// least M2.size() entries (main() sizes it that way).
void MM::Multiply()
{
  const size_t rows = P.size();
  const size_t cols = rows ? P[0].size() : 0;   // guard: P[0] on an empty P is UB
  const size_t inner = M2.size();
  size_t i, j, k;

  for (i = 0; i < rows; i++) {
    for (j = 0; j < cols; j++) {
      for (k = 0; k < inner; k++) P[i][j] += (M1[i][k] * M2[k][j]);
    }
  }
}

// Prints "<name>: R x C", a blank line, then each row of m with every entry
// formatted as " %6.4lf". C is taken from the first row (0 if m is empty).
static void PrintMatrix(const char *name, const std::vector<DVec> &m)
{
  size_t i, j;

  // size_t has no portable "%ld" match; cast explicitly for %lu.
  printf("%s: %lu x %lu\n\n", name,
         static_cast<unsigned long>(m.size()),
         static_cast<unsigned long>(m.empty() ? 0 : m[0].size()));
  for (i = 0; i < m.size(); i++) {
    for (j = 0; j < m[i].size(); j++) printf(" %6.4lf", m[i][j]);
    printf("\n");
  }
}

// Prints M1, M2 and P in that order, separated by blank lines.
void MM::PrintAll()
{
  PrintMatrix("M1", M1);
  printf("\n");
  PrintMatrix("M2", M2);
  printf("\n");
  PrintMatrix("P", P);
}

// Returns the current wall-clock time in seconds, from gettimeofday().
static double WallSeconds()
{
  struct timeval tv;

  gettimeofday(&tv, NULL);
  return tv.tv_sec + tv.tv_usec / 1000000.0;
}

// Overwrites every entry of m, row by row, with drand48() * 2.0.
static void FillRandom(std::vector<DVec> &m)
{
  size_t i, j;

  for (i = 0; i < m.size(); i++) {
    for (j = 0; j < m[i].size(); j++) m[i][j] = drand48() * 2.0;
  }
}

int main(int argc, char **argv)
{
  MM M;             // value object: no leaked heap allocation
  int r1, c1, c2;
  std::string s;
  long seed;
  double t0, t1;

  if (argc != 6) usage("");

  // sscanf returns EOF (not 0) for an empty argument, which previously let an
  // uninitialized dimension through; require exactly one conversion.
  if (sscanf(argv[1], "%d", &r1) != 1 || r1 <= 0) usage("Bad r1");
  if (sscanf(argv[2], "%d", &c1) != 1 || c1 <= 0) usage("Bad c1/r2");
  if (sscanf(argv[3], "%d", &c2) != 1 || c2 <= 0) usage("Bad c2");

  // An unparseable seed deliberately falls back to 0; -1 means "use time(0)".
  seed = 0;
  sscanf(argv[4], "%ld", &seed);
  if (seed == -1) seed = time(0);
  srand48(seed);

  s = argv[5];
  if (s == "y") {
    M.Print = 1;
  } else if (s == "n") {
    M.Print = 0;
  } else usage("Bad print");

  // M1: r1 x c1, M2: c1 x c2, P: r1 x c2 zero-filled (Multiply adds into P).
  M.M1.assign(r1, DVec(c1));
  M.M2.assign(c1, DVec(c2));
  M.P.assign(r1, DVec(c2, 0.0));

  // Fill order (all of M1, then all of M2) fixes which drand48 values land where.
  FillRandom(M.M1);
  FillRandom(M.M2);

  t0 = WallSeconds();
  M.Multiply();
  t1 = WallSeconds();

  if (M.Print) M.PrintAll();
  printf("Time: %.4lf\n", t1-t0);
  return 0;
}