@@ -415,13 +415,17 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
415415 ptrs, events);
416416 mem_free_thread.detach ();
417417#else
418- std::int64_t m_int64 = n;
419- std::int64_t n_int64 = n;
420- std::int64_t lda_int64 = lda;
421- std::int64_t group_sizes = batch_size;
418+ std::int64_t *m_int64 = new std::int64_t ;
419+ std::int64_t *n_int64 = new std::int64_t ;
420+ std::int64_t *lda_int64 = new std::int64_t ;
421+ std::int64_t *group_sizes = new std::int64_t ;
422+ *m_int64 = n;
423+ *n_int64 = n;
424+ *lda_int64 = lda;
425+ *group_sizes = batch_size;
422426 std::int64_t scratchpad_size =
423427 oneapi::mkl::lapack::getrf_batch_scratchpad_size<Ty>(
424- exec_queue, & m_int64, & n_int64, & lda_int64, 1 , & group_sizes);
428+ exec_queue, m_int64, n_int64, lda_int64, 1 , group_sizes);
425429
426430 Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
427431 std::int64_t *ipiv_int64 =
@@ -433,9 +437,9 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
433437 for (std::int64_t i = 0 ; i < batch_size; ++i)
434438 ipiv_int64_ptr[i] = ipiv_int64 + n * i;
435439
436- oneapi::mkl::lapack::getrf_batch (
437- exec_queue, &m_int64, &n_int64, (Ty **)a_shared, & lda_int64,
438- ipiv_int64_ptr, 1 , & group_sizes, scratchpad, scratchpad_size);
440+ oneapi::mkl::lapack::getrf_batch (exec_queue, m_int64, n_int64,
441+ (Ty **)a_shared, lda_int64, ipiv_int64_ptr ,
442+ 1 , group_sizes, scratchpad, scratchpad_size);
439443
440444 sycl::event e = exec_queue.submit ([&](sycl::handler &cgh) {
441445 cgh.parallel_for <
@@ -445,6 +449,15 @@ inline void getrf_batch_wrapper(sycl::queue &exec_queue, int n, T *a[], int lda,
445449 });
446450 });
447451
452+ exec_queue.submit ([&](sycl::handler &cgh) {
453+ cgh.depends_on (e);
454+ cgh.host_task ([=] {
455+ delete m_int64;
456+ delete n_int64;
457+ delete lda_int64;
458+ delete group_sizes;
459+ });
460+ });
448461 std::vector<void *> ptrs{scratchpad, ipiv_int64, ipiv_int64_ptr, a_shared};
449462 ::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
450463#endif
@@ -535,15 +548,22 @@ inline void getrs_batch_wrapper(sycl::queue &exec_queue,
535548 ptrs, events);
536549 mem_free_thread.detach ();
537550#else
538- std::int64_t n_int64 = n;
539- std::int64_t nrhs_int64 = nrhs;
540- std::int64_t lda_int64 = lda;
541- std::int64_t ldb_int64 = ldb;
542- std::int64_t group_sizes = batch_size;
551+ std::int64_t *n_int64 = new std::int64_t ;
552+ std::int64_t *nrhs_int64 = new std::int64_t ;
553+ std::int64_t *lda_int64 = new std::int64_t ;
554+ std::int64_t *ldb_int64 = new std::int64_t ;
555+ std::int64_t *group_sizes = new std::int64_t ;
556+ oneapi::mkl::transpose *trans_array = new oneapi::mkl::transpose;
557+ *n_int64 = n;
558+ *nrhs_int64 = nrhs;
559+ *lda_int64 = lda;
560+ *ldb_int64 = ldb;
561+ *group_sizes = batch_size;
562+ *trans_array = trans;
543563 std::int64_t scratchpad_size =
544564 oneapi::mkl::lapack::getrs_batch_scratchpad_size<Ty>(
545- exec_queue, &trans, & n_int64, & nrhs_int64, & lda_int64, & ldb_int64, 1 ,
546- & group_sizes);
565+ exec_queue, trans_array, n_int64, nrhs_int64, lda_int64, ldb_int64, 1 ,
566+ group_sizes);
547567
548568 Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
549569 std::int64_t *ipiv_int64 =
@@ -569,10 +589,21 @@ inline void getrs_batch_wrapper(sycl::queue &exec_queue,
569589 ipiv_int64_ptr[i] = ipiv_int64 + n * i;
570590
571591 sycl::event e = oneapi::mkl::lapack::getrs_batch (
572- exec_queue, &trans, & n_int64, & nrhs_int64, (Ty **)a_shared, & lda_int64,
573- ipiv_int64_ptr, (Ty **)b_shared, & ldb_int64, 1 , & group_sizes, scratchpad,
592+ exec_queue, trans_array, n_int64, nrhs_int64, (Ty **)a_shared, lda_int64,
593+ ipiv_int64_ptr, (Ty **)b_shared, ldb_int64, 1 , group_sizes, scratchpad,
574594 scratchpad_size);
575595
596+ exec_queue.submit ([&](sycl::handler &cgh) {
597+ cgh.depends_on (e);
598+ cgh.host_task ([=] {
599+ delete n_int64;
600+ delete nrhs_int64;
601+ delete lda_int64;
602+ delete ldb_int64;
603+ delete group_sizes;
604+ delete trans_array;
605+ });
606+ });
576607 std::vector<void *> ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared,
577608 b_shared};
578609 ::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
@@ -659,12 +690,15 @@ inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, const T *a[],
659690 ptrs, events);
660691 mem_free_thread.detach ();
661692#else
662- std::int64_t n_int64 = n;
663- std::int64_t ldb_int64 = ldb;
664- std::int64_t group_sizes = batch_size;
693+ std::int64_t *n_int64 = new std::int64_t ;
694+ std::int64_t *ldb_int64 = new std::int64_t ;
695+ std::int64_t *group_sizes = new std::int64_t ;
696+ *n_int64 = n;
697+ *ldb_int64 = ldb;
698+ *group_sizes = batch_size;
665699 std::int64_t scratchpad_size =
666700 oneapi::mkl::lapack::getri_batch_scratchpad_size<Ty>(
667- exec_queue, & n_int64, & ldb_int64, 1 , & group_sizes);
701+ exec_queue, n_int64, ldb_int64, 1 , group_sizes);
668702
669703 Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
670704 std::int64_t *ipiv_int64 =
@@ -695,9 +729,17 @@ inline void getri_batch_wrapper(sycl::queue &exec_queue, int n, const T *a[],
695729 }
696730
697731 sycl::event e = oneapi::mkl::lapack::getri_batch (
698- exec_queue, & n_int64, (Ty **)b_shared, & ldb_int64, ipiv_int64_ptr, 1 ,
699- & group_sizes, scratchpad, scratchpad_size);
732+ exec_queue, n_int64, (Ty **)b_shared, ldb_int64, ipiv_int64_ptr, 1 ,
733+ group_sizes, scratchpad, scratchpad_size);
700734
735+ exec_queue.submit ([&](sycl::handler &cgh) {
736+ cgh.depends_on (e);
737+ cgh.host_task ([=] {
738+ delete n_int64;
739+ delete ldb_int64;
740+ delete group_sizes;
741+ });
742+ });
701743 std::vector<void *> ptrs{scratchpad, ipiv_int64_ptr, ipiv_int64, a_shared,
702744 b_shared};
703745 ::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
@@ -780,13 +822,17 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[],
780822 mem_free_thread_a.detach ();
781823 mem_free_thread_tau.detach ();
782824#else
783- std::int64_t m_int64 = n;
784- std::int64_t n_int64 = n;
785- std::int64_t lda_int64 = lda;
786- std::int64_t group_sizes = batch_size;
825+ std::int64_t *m_int64 = new std::int64_t ;
826+ std::int64_t *n_int64 = new std::int64_t ;
827+ std::int64_t *lda_int64 = new std::int64_t ;
828+ std::int64_t *group_sizes = new std::int64_t ;
829+ *m_int64 = n;
830+ *n_int64 = n;
831+ *lda_int64 = lda;
832+ *group_sizes = batch_size;
787833 std::int64_t scratchpad_size =
788834 oneapi::mkl::lapack::geqrf_batch_scratchpad_size<Ty>(
789- exec_queue, & m_int64, & n_int64, & lda_int64, 1 , & group_sizes);
835+ exec_queue, m_int64, n_int64, lda_int64, 1 , group_sizes);
790836
791837 Ty *scratchpad = sycl::malloc_device<Ty>(scratchpad_size, exec_queue);
792838 T **a_shared = sycl::malloc_shared<T *>(batch_size, exec_queue);
@@ -795,9 +841,18 @@ inline void geqrf_batch_wrapper(sycl::queue exec_queue, int m, int n, T *a[],
795841 exec_queue.memcpy (tau_shared, tau, batch_size * sizeof (T *)).wait ();
796842
797843 sycl::event e = oneapi::mkl::lapack::geqrf_batch (
798- exec_queue, & m_int64, & n_int64, (Ty **)a_shared, & lda_int64,
799- (Ty **)tau_shared, 1 , & group_sizes, scratchpad, scratchpad_size);
844+ exec_queue, m_int64, n_int64, (Ty **)a_shared, lda_int64,
845+ (Ty **)tau_shared, 1 , group_sizes, scratchpad, scratchpad_size);
800846
847+ exec_queue.submit ([&](sycl::handler &cgh) {
848+ cgh.depends_on (e);
849+ cgh.host_task ([=] {
850+ delete m_int64;
851+ delete n_int64;
852+ delete lda_int64;
853+ delete group_sizes;
854+ });
855+ });
801856 std::vector<void *> ptrs{scratchpad, a_shared, tau_shared};
802857 ::dpct::cs::enqueue_free (ptrs, {e}, exec_queue);
803858#endif
0 commit comments