@@ -37,13 +37,8 @@ void test(queue &Q, int M, int N, int K)
3737 auto B = malloc_device<T>(ldb * N, Q);
3838 auto C = malloc_device<T>(ldc * N, Q);
3939
40- /* Fill A/B with random data */
4140 constexpr int rd_size = 1048576 ;
42- auto random_data = malloc_host<T>(rd_size, Q);
43- generate_random_data (rd_size, random_data);
44-
45- replicate_data (Q, A, lda * K, random_data, rd_size);
46- replicate_data (Q, B, ldb * N, random_data, rd_size);
41+ auto host_data = malloc_host<T>(rd_size, Q);
4742
4843 /* Measure time for a given number of GEMM calls */
4944 auto time_gemms = [=, &Q](int runs) -> double {
@@ -57,7 +52,36 @@ void test(queue &Q, int M, int N, int K)
5752 return duration<double >(end - start).count ();
5853 };
5954
60- /* Do a warmup call to initialize MKL and ensure kernels are JIT'ed if needed */
55+ /* Fill A/B with all ones to verify correctness */
56+ generate_ones (rd_size, host_data);
57+ replicate_data (Q, A, lda * K, host_data, rd_size);
58+ replicate_data (Q, B, ldb * N, host_data, rd_size);
59+
60+ /* Verify that the leading entries of C are correct */
61+ std::cout << " -> Verification..." ;
62+ (void ) time_gemms (1 );
63+ size_t elems = std::min (ldc * N, rd_size);
64+ Q.copy (C, host_data, elems).wait ();
65+ bool ok = true ;
66+ int linear_id = 0 ;
67+ for (size_t j = 0 ; j < N; j++) {
68+ for (size_t i = 0 ; i < M; i++) {
69+ linear_id = j*ldc + i;
70+ if (linear_id >= elems) break ;
71+ if (host_data[linear_id] != T (K)) {
72+ ok = false ;
73+ }
74+ }
75+ if (linear_id >= elems) break ;
76+ }
77+ std::cout << (ok ? " passes." : " FAILS!" ) << std::endl;
78+
79+ /* Fill A/B with random data */
80+ generate_random_data (rd_size, host_data);
81+ replicate_data (Q, A, lda * K, host_data, rd_size);
82+ replicate_data (Q, B, ldb * N, host_data, rd_size);
83+
84+ /* Do a warmup call with random data to initialize MKL and ensure kernels are JIT'ed if needed */
6185 std::cout << " -> Warmup...\n " ;
6286 (void ) time_gemms (1 );
6387
@@ -93,7 +117,7 @@ void test(queue &Q, int M, int N, int K)
93117 free (A, Q);
94118 free (B, Q);
95119 free (C, Q);
96- free (random_data , Q);
120+ free (host_data , Q);
97121}
98122
99123void usage (const char *pname)
0 commit comments