@@ -73,7 +73,18 @@ void run()
7373 double total_time = 0.0 ;
7474
7575 namespace mkl_rng = oneapi::mkl::rng;
76- mkl_rng::mcg59 engine (
76+ #if USE_PHILOX
77+ using EngineTypeHost = mkl_rng::philox4x32x10;
78+ using EngineTypeDevice = mkl_rng::device::philox4x32x10<VEC_SIZE>;
79+ #elif USE_MRG
80+ using EngineTypeHost = mkl_rng::mrg32k3a;
81+ using EngineTypeDevice = mkl_rng::device::mrg32k3a<VEC_SIZE>;
82+ #else
83+ using EngineTypeHost = mkl_rng::mcg59;
84+ using EngineTypeDevice = mkl_rng::device::mcg59<VEC_SIZE>;
85+ #endif
86+
87+ EngineTypeHost engine (
7788#if !INIT_ON_HOST
7889 my_queue,
7990#else
@@ -86,18 +97,10 @@ void run()
8697 auto rng_event_3 = mkl_rng::generate (mkl_rng::uniform<DataType>(1.0 , 5.0 ), engine, num_options, h_option_years_ptr);
8798
8899 std::size_t n_states = global_size;
89- using EngineType =
90- #if USE_PHILOX
91- mkl_rng::device::philox4x32x10<VEC_SIZE>;
92- #elif USE_MRG
93- mkl_rng::device::mrg32k3a<VEC_SIZE>;
94- #else
95- mkl_rng::device::mcg59<VEC_SIZE>;
96- #endif
97100
98101 // initialization needs only on first step
99102 auto deleter = [my_queue](auto * ptr) {sycl::free (ptr, my_queue);};
100- auto rng_states_uptr = std::unique_ptr<EngineType , decltype (deleter)>(sycl::malloc_device<EngineType >(n_states, my_queue), deleter);
103+ auto rng_states_uptr = std::unique_ptr<EngineTypeDevice , decltype (deleter)>(sycl::malloc_device<EngineTypeDevice >(n_states, my_queue), deleter);
101104 auto * rng_states = rng_states_uptr.get ();
102105
103106 my_queue.parallel_for <k_initialize_state<DataType>>(
@@ -107,9 +110,9 @@ void run()
107110 auto id = idx[0 ];
108111#if USE_MRG
109112 constexpr std::uint32_t seed = 12345u ;
110- rng_states[id] = EngineType ({ seed, seed, seed, seed, seed, seed }, { 0 , (4096 * id) });
113+ rng_states[id] = EngineTypeDevice ({ seed, seed, seed, seed, seed, seed }, { 0 , (4096 * id) });
111114#else
112- rng_states[id] = EngineType (rand_seed, id * ITEMS_PER_WORK_ITEM * VEC_SIZE * block_n);
115+ rng_states[id] = EngineTypeDevice (rand_seed, id * ITEMS_PER_WORK_ITEM * VEC_SIZE * block_n);
113116#endif
114117 })
115118 .wait_and_throw ();
0 commit comments