@@ -88,8 +88,8 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
8888 = opencl_kernels::categorical_logit_glm.get_option (" LOCAL_SIZE_" );
8989 const int wgs = (N_instances + local_size - 1 ) / local_size;
9090
91- bool need_alpha_derivative = is_autodiff_v<T_alpha>;
92- bool need_beta_derivative = is_autodiff_v<T_beta>;
91+ constexpr bool need_alpha_derivative = is_autodiff_v<T_alpha>;
92+ constexpr bool need_beta_derivative = is_autodiff_v<T_beta>;
9393
9494 matrix_cl<double > logp_cl (wgs, 1 );
9595 matrix_cl<double > exp_lin_cl (N_instances, N_classes);
@@ -127,13 +127,12 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
127127 if constexpr (is_y_vector) {
128128 partials<0 >(ops_partials)
129129 = indexing (beta_val, col_index (x.rows (), x.cols ()),
130- rowwise_broadcast (forward_as<matrix_cl< int >>( y_val) - 1 ))
130+ rowwise_broadcast (y_val - 1 ))
131131 - elt_multiply (exp_lin_cl * transpose (beta_val),
132132 rowwise_broadcast (inv_sum_exp_lin_cl));
133133 } else {
134134 partials<0 >(ops_partials)
135- = indexing (beta_val, col_index (x.rows (), x.cols ()),
136- forward_as<int >(y_val) - 1 )
135+ = indexing (beta_val, col_index (x.rows (), x.cols ()), y_val - 1 )
137136 - elt_multiply (exp_lin_cl * transpose (beta_val),
138137 rowwise_broadcast (inv_sum_exp_lin_cl));
139138 }
@@ -152,9 +151,8 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
152151 try {
153152 opencl_kernels::categorical_logit_glm_beta_derivative (
154153 cl::NDRange (local_size * N_attributes), cl::NDRange (local_size),
155- forward_as<arena_matrix_cl<double >>(partials<2 >(ops_partials)),
156- temp, y_val_cl, x_val, N_instances, N_attributes, N_classes,
157- is_y_vector);
154+ partials<2 >(ops_partials), temp, y_val_cl, x_val, N_instances,
155+ N_attributes, N_classes, is_y_vector);
158156 } catch (const cl::Error& e) {
159157 check_opencl_error (function, e);
160158 }
0 commit comments