@@ -88,8 +88,8 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
8888 = opencl_kernels::categorical_logit_glm.get_option (" LOCAL_SIZE_" );
8989 const int wgs = (N_instances + local_size - 1 ) / local_size;
9090
91- bool need_alpha_derivative = is_autodiff_v<T_alpha>;
92- bool need_beta_derivative = is_autodiff_v<T_beta>;
91+ constexpr bool need_alpha_derivative = is_autodiff_v<T_alpha>;
92+ constexpr bool need_beta_derivative = is_autodiff_v<T_beta>;
9393
9494 matrix_cl<double > logp_cl (wgs, 1 );
9595 matrix_cl<double > exp_lin_cl (N_instances, N_classes);
@@ -127,13 +127,13 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
127127 if constexpr (is_y_vector) {
128128 partials<0 >(ops_partials)
129129 = indexing (beta_val, col_index (x.rows (), x.cols ()),
130- rowwise_broadcast (forward_as<matrix_cl< int >>( y_val) - 1 ))
130+ rowwise_broadcast (y_val - 1 ))
131131 - elt_multiply (exp_lin_cl * transpose (beta_val),
132132 rowwise_broadcast (inv_sum_exp_lin_cl));
133133 } else {
134134 partials<0 >(ops_partials)
135135 = indexing (beta_val, col_index (x.rows (), x.cols ()),
136- forward_as< int >( y_val) - 1 )
136+ y_val - 1 )
137137 - elt_multiply (exp_lin_cl * transpose (beta_val),
138138 rowwise_broadcast (inv_sum_exp_lin_cl));
139139 }
@@ -152,7 +152,7 @@ return_type_t<T_x, T_alpha, T_beta> categorical_logit_glm_lpmf(
152152 try {
153153 opencl_kernels::categorical_logit_glm_beta_derivative (
154154 cl::NDRange (local_size * N_attributes), cl::NDRange (local_size),
155- forward_as<arena_matrix_cl< double >>( partials<2 >(ops_partials) ),
155+ partials<2 >(ops_partials),
156156 temp, y_val_cl, x_val, N_instances, N_attributes, N_classes,
157157 is_y_vector);
158158 } catch (const cl::Error& e) {
0 commit comments