@@ -13,8 +13,8 @@ struct MassSpringEnergy<T, dim>::Impl
1313 DeviceBuffer<T> device_l2, device_k;
1414 DeviceBuffer<int > device_e;
1515 int N;
16- std::vector <T> host_grad ;
17- SparseMatrix<T> host_hess ;
16+ DeviceBuffer <T> device_grad ;
17+ DeviceTripletMatrix<T, 1 > device_hess ;
1818};
// Default constructor: explicitly defaulted out-of-line for every <T, dim> instantiation.
1919template <typename T, int dim>
2020MassSpringEnergy<T, dim>::MassSpringEnergy() = default ;
@@ -40,15 +40,13 @@ MassSpringEnergy<T, dim>::MassSpringEnergy(const std::vector<T> &x, const std::v
4040 pimpl_->device_e .copy_from (e);
4141 pimpl_->device_l2 .copy_from (l2);
4242 pimpl_->device_k .copy_from (k);
43- pimpl_->host_grad = std::vector<T>(pimpl_->N * dim);
4443 int size = e.size () / 2 ;
45- pimpl_->host_hess = SparseMatrix<T>(size * dim * dim);
4644}
4745
// update_x(): replaces the contents of pimpl_->device_x with x.
// In this diff the parameter changes from a host std::vector<T> to a
// DeviceBuffer<T>, and the copy now goes through device_x.view() instead of
// calling copy_from() on the buffer itself.
// NOTE(review): x is only handed to copy_from() here, so it could presumably
// be taken by const reference -- confirm against the declaration in the header.
// NOTE(review): whether copy_from() on a view resizes device_x or expects the
// sizes to match already is not visible here -- verify against the muda API.
4846template <typename T, int dim>
49- void MassSpringEnergy<T, dim>::update_x(const std::vector <T> &x)
47+ void MassSpringEnergy<T, dim>::update_x(DeviceBuffer <T> &x)
5048{
51- pimpl_->device_x .copy_from (x);
49+ pimpl_->device_x .view (). copy_from (x);
5250}
5351
5452template <typename T, int dim>
@@ -94,15 +92,14 @@ T MassSpringEnergy<T, dim>::val()
9492} // Calculate the energy
9593
// grad(): evaluates the energy gradient on the device and returns it as a
// DeviceBuffer<T> of pimpl_->N * dim entries.
// One kernel iteration is launched per index pair stored in device_e
// (device_e.size() / 2 iterations); the kernel body is partly elided by the diff.
// The result is written into the member pimpl_->device_grad so that the
// returned reference remains valid after the call. Returning a reference to a
// function-local DeviceBuffer would dangle as soon as grad() returns.
9694template <typename T, int dim>
97- std::vector <T> &MassSpringEnergy<T, dim>::grad()
95+ DeviceBuffer <T> &MassSpringEnergy<T, dim>::grad()
9896{
9997 auto &device_x = pimpl_->device_x ;
10098 auto &device_e = pimpl_->device_e ;
10199 auto &device_l2 = pimpl_->device_l2 ;
102100 auto &device_k = pimpl_->device_k ;
103101 auto N = pimpl_->device_e .size () / 2 ;
// Bind to the persistent member instead of a local buffer. It is sized here
// because no other visible code sizes it, and zeroed on every call so that any
// accumulation inside the (elided) kernel body starts from zero.
104102 auto &device_grad = pimpl_->device_grad ; device_grad.resize (pimpl_->N * dim); device_grad.fill (T (0 ));
105- auto &host_grad = pimpl_->host_grad ;
106103 ParallelFor (256 ).apply (N, [device_x = device_x.cviewer (), device_e = device_e.cviewer (), device_l2 = device_l2.cviewer (), device_k = device_k.cviewer (), device_grad = device_grad.viewer ()] __device__ (int i) mutable
107104 {
108105 int idx1= device_e (2 * i); // First node index
@@ -120,23 +117,22 @@ std::vector<T> &MassSpringEnergy<T, dim>::grad()
120117
121118 } })
122119 .wait ();
123- device_grad.copy_to (host_grad);
124- return host_grad;
// device_grad refers to pimpl_->device_grad, so this reference outlives the call.
120+ return device_grad;
125121}
126122
// hess(): assembles the energy Hessian on the device as a triplet (row, col,
// value) matrix and returns a reference to it.
// One kernel iteration is launched per index pair stored in device_e
// (device_e.size() / 2 iterations); each iteration writes row indices, column
// indices and values at offsets derived from indStart (computed in the part of
// the kernel body elided by the diff).
// The triplets are written into the member pimpl_->device_hess so that the
// returned reference remains valid after the call.
// NOTE(review): the matrix row/column dimensions are not set anywhere in the
// visible code -- confirm whether a reshape is required by the consumer.
127123template <typename T, int dim>
128- SparseMatrix<T> MassSpringEnergy<T, dim>::hess()
124+ DeviceTripletMatrix<T, 1 > & MassSpringEnergy<T, dim>::hess()
129125{
130126 auto &device_x = pimpl_->device_x ;
131127 auto &device_e = pimpl_->device_e ;
132128 auto &device_l2 = pimpl_->device_l2 ;
133129 auto &device_k = pimpl_->device_k ;
134130 auto N = device_e.size () / 2 ;
135- auto &host_hess = pimpl_->host_hess ;
136- DeviceBuffer<T> device_hess (N * dim * dim * 4 );
137- DeviceBuffer< int > device_hess_row_idx (N * dim * dim * 4 );
138- DeviceBuffer< int > device_hess_col_idx (N * dim * dim * 4 );
139- ParallelFor (256 ).apply (N, [device_x = device_x.cviewer (), device_e = device_e.cviewer (), device_l2 = device_l2.cviewer (), device_k = device_k.cviewer (), device_hess = device_hess .viewer (), device_hess_row_idx = device_hess_row_idx.viewer (), device_hess_col_idx = device_hess_col_idx.viewer (), N] __device__ (int i) mutable
// Bind by reference: a by-value `auto` would copy the matrix and the function
// would then return a reference to that local copy. The triplet storage is
// sized to N * dim * dim * 4 entries, the same capacity the removed
// value/row/col buffers used, so the kernel writes stay in bounds.
131+ auto &device_hess = pimpl_->device_hess ; device_hess.resize_triplets (N * dim * dim * 4 );
132+ auto device_hess_row_idx = device_hess. row_indices ( );
133+ auto device_hess_col_idx = device_hess. col_indices ( );
134+ auto device_hess_val = device_hess. values ( );
135+ ParallelFor (256 ).apply (N, [device_x = device_x.cviewer (), device_e = device_e.cviewer (), device_l2 = device_l2.cviewer (), device_k = device_k.cviewer (), device_hess_val = device_hess_val .viewer (), device_hess_row_idx = device_hess_row_idx.viewer (), device_hess_col_idx = device_hess_col_idx.viewer (), N] __device__ (int i) mutable
140136 {
141137 int idx[2 ] = {device_e (2 * i), device_e (2 * i + 1 )}; // First node index
142138 T diff = 0 ;
@@ -164,14 +160,11 @@ SparseMatrix<T> MassSpringEnergy<T, dim>::hess()
164160 for (int d2 = 0 ; d2 < dim; d2++){
165161 device_hess_row_idx (indStart + d1 * dim + d2)= idx[ni]*dim + d1;
166162 device_hess_col_idx (indStart + d1 * dim + d2)= idx[nj] * dim + d2;
167- device_hess (indStart + d1 * dim + d2)= H_local (ni * dim + d1, nj * dim + d2);
163+ device_hess_val (indStart + d1 * dim + d2)= H_local (ni * dim + d1, nj * dim + d2);
168164 }
169165 } })
170166 .wait ();
171- device_hess.copy_to (host_hess.set_val_buffer ());
172- device_hess_row_idx.copy_to (host_hess.set_row_buffer ());
173- device_hess_col_idx.copy_to (host_hess.set_col_buffer ());
174- return host_hess;
// device_hess refers to pimpl_->device_hess, so this reference outlives the call.
167+ return device_hess;
175168
176169} // Calculate the Hessian of the energy
177170
0 commit comments