Skip to content

Commit 87f32bc

Browse files
committed
5,run
1 parent d06b5bb commit 87f32bc

6 files changed

Lines changed: 106 additions & 76 deletions

File tree

simulators/5_mov_dirichlet/include/SpringEnergy.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ template <typename T, int dim>
1010
class SpringEnergy
1111
{
1212
public:
13-
SpringEnergy(const std::vector<T> &x, const std::vector<T> &m, const std::vector<int> &DBC, const std::vector<T> &DBC_v,const std::vector<T> &DBC_limit, T k,T h);
13+
SpringEnergy(const std::vector<T> &x, const std::vector<T> &m, const std::vector<int> &DBC, T k);
1414
SpringEnergy();
1515
~SpringEnergy();
1616
SpringEnergy(SpringEnergy &&rhs);
1717
SpringEnergy(const SpringEnergy &rhs);
1818
SpringEnergy &operator=(SpringEnergy &&rhs);
1919

2020
void update_x(const DeviceBuffer<T> &x);
21-
void update_DBC_target();
21+
void update_DBC_target(const std::vector<T> &DBC_target);
22+
void update_k(T new_k);
2223
T val(); // Calculate the value of the energy
2324
const DeviceBuffer<T> &grad(); // Calculate the gradient of the energy
2425
const DeviceTripletMatrix<T, 1> &hess(); // Calculate the Hessian matrix of the energy

simulators/5_mov_dirichlet/include/uti.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ template <typename T>
2323
T min_vector(const DeviceBuffer<T> &a);
2424

2525
template <typename T, int dim>
26-
void search_dir(const DeviceBuffer<T> &grad, const DeviceTripletMatrix<T, 1> &hess, DeviceBuffer<T> &dir, const DeviceBuffer<int> &DBC);
26+
void search_dir(const DeviceBuffer<T> &grad, const DeviceTripletMatrix<T, 1> &hess, DeviceBuffer<T> &dir, const DeviceBuffer<int> &DBC, const std::vector<T> &DBC_target, std::vector<int> &DBC_satified);
2727

2828
template <typename T>
2929
void display_vec(const DeviceBuffer<T> &vec);
3030

3131
template <typename T, int dim>
32-
void set_DBC(DeviceBuffer<T> &grad, DeviceCSRMatrix<T> &hess, const DeviceBuffer<int> &DBC);
32+
void set_DBC(DeviceBuffer<T> &grad, DeviceCSRMatrix<T> &hess, const DeviceBuffer<int> &DBC, const std::vector<T> &DBC_target, std::vector<int> &DBC_satified);

simulators/5_mov_dirichlet/src/SpringEnergy.cu

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ struct SpringEnergy<T, dim>::Impl
1212
DeviceBuffer<T> device_x;
1313
DeviceBuffer<T> device_m;
1414
DeviceBuffer<int> device_DBC;
15-
DeviceBuffer<T> device_DBC_target,device_DBC_v,device_DBC_limit;
15+
DeviceBuffer<T> device_DBC_target;
1616
DeviceBuffer<T> device_grad;
1717
DeviceTripletMatrix<T, 1> device_hess;
18-
T k,h;
18+
T k, h;
1919
int N;
2020
};
2121

@@ -36,18 +36,15 @@ SpringEnergy<T, dim>::SpringEnergy(const SpringEnergy<T, dim> &rhs)
3636
: pimpl_{std::make_unique<Impl>(*rhs.pimpl_)} {}
3737

3838
template <typename T, int dim>
39-
SpringEnergy<T, dim>::SpringEnergy(const std::vector<T> &x, const std::vector<T> &m, const std::vector<int> &DBC, const std::vector<T> &DBC_v, const std::vector<T> &DBC_limit,T k,T h)
39+
SpringEnergy<T, dim>::SpringEnergy(const std::vector<T> &x, const std::vector<T> &m, const std::vector<int> &DBC, T k)
4040
: pimpl_{std::make_unique<Impl>()}
4141
{
4242
pimpl_->N = x.size() / dim;
4343
pimpl_->device_x.copy_from(x);
4444
pimpl_->device_m.copy_from(m);
4545
pimpl_->device_DBC.copy_from(DBC);
46-
pimpl_->device_DBC_v.copy_from(DBC_v);
47-
pimpl_->device_DBC_limit.copy_from(DBC_limit);
4846
pimpl_->device_DBC_target.resize(DBC.size() * dim);
4947
pimpl_->k = k;
50-
pimpl_->h = h;
5148
pimpl_->device_grad.resize(pimpl_->N * dim);
5249
pimpl_->device_hess.resize_triplets(pimpl_->N * dim * dim);
5350
pimpl_->device_hess.reshape(x.size(), x.size());
@@ -60,42 +57,14 @@ void SpringEnergy<T, dim>::update_x(const DeviceBuffer<T> &x)
6057
}
6158

6259
template <typename T, int dim>
63-
void SpringEnergy<T, dim>::update_DBC_target()
60+
void SpringEnergy<T, dim>::update_DBC_target(const std::vector<T> &DBC_target)
6461
{
65-
auto &device_x = pimpl_->device_x;
66-
auto &device_DBC = pimpl_->device_DBC;
67-
auto &device_DBC_target = pimpl_->device_DBC_target;
68-
auto h = pimpl_->h;
69-
auto &device_DBC_v = pimpl_->device_DBC_v;
70-
auto &device_DBC_limit = pimpl_->device_DBC_limit;
71-
int N = device_DBC.size();
72-
device_DBC_target.fill(0);
73-
74-
ParallelFor(256).apply(N, [device_x = device_x.cviewer(), device_DBC = device_DBC.cviewer(), device_DBC_target = device_DBC_target.viewer(), device_DBC_v = device_DBC_v.cviewer(), h, device_DBC_limit = device_DBC_limit.cviewer()] __device__(int i) mutable
75-
{
76-
int idx = device_DBC(i);
77-
T d=0;
78-
for (int j = 0; j < dim; ++j)
79-
{
80-
d += (device_DBC_limit(i*dim + j) - device_x(idx * dim + j)) * (device_DBC_v(i*dim + j));
81-
}
82-
if(d>0)
83-
{
84-
for (int j = 0; j < dim; ++j)
85-
{
86-
device_DBC_target(i*dim + j) = device_x(idx * dim + j) + h * device_DBC_v(i*dim + j);
87-
}
88-
}
89-
else
90-
{
91-
for (int j = 0; j < dim; ++j)
92-
{
93-
device_DBC_target(i*dim + j) = device_x(idx*dim + j);
94-
}
95-
}
96-
}).wait();
97-
98-
62+
pimpl_->device_DBC_target.copy_from(DBC_target);
63+
}
64+
template <typename T, int dim>
65+
void SpringEnergy<T, dim>::update_k(T new_k)
66+
{
67+
pimpl_->k = new_k;
9968
}
10069

10170
template <typename T, int dim>

simulators/5_mov_dirichlet/src/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
int main()
44
{
5-
float rho = 1000, k = 4e4, initial_stretch = 1, n_seg = 8, h = 0.01, side_len = 1, tol = 0.01, mu = 0.11;
5+
float rho = 1000, k = 4e4, initial_stretch = 1, n_seg = 3, h = 0.01, side_len = 1, tol = 0.01, mu = 0.11;
66
// printf("Running mass-spring simulator with parameters: rho = %f, k = %f, initial_stretch = %f, n_seg = %d, h = %f, side_len = %f, tol = %f\n", rho, k, initial_stretch, n_seg, h, side_len, tol);
77
MovDirichletSimulator<float, 2> simulator(rho, side_len, initial_stretch, k, h, tol, mu, n_seg);
88
simulator.run();

simulators/5_mov_dirichlet/src/simulator.cu

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ template <typename T, int dim>
1414
struct MovDirichletSimulator<T, dim>::Impl
1515
{
1616
int n_seg;
17-
T h, rho, side_len, initial_stretch, m, tol, mu;
17+
T h, rho, side_len, initial_stretch, m, tol, mu, DBC_stiff;
1818
int resolution = 900, scale = 200, offset = resolution / 2, radius = 5;
19-
std::vector<T> x, x_tilde, v, k, l2;
20-
std::vector<int> e;
19+
std::vector<T> x, x_tilde, v, k, l2, DBC_limit, DBC_v, DBC_target;
20+
std::vector<int> e, DBC, DBC_satisfied;
2121
DeviceBuffer<int> device_DBC;
2222
DeviceBuffer<T> device_contact_area;
2323
sf::RenderWindow window;
@@ -32,6 +32,7 @@ struct MovDirichletSimulator<T, dim>::Impl
3232
void update_x_tilde(const DeviceBuffer<T> &new_x_tilde);
3333
void update_v(const DeviceBuffer<T> &new_v);
3434
void update_DBC_target();
35+
void update_DBC_stiff(T new_DBC_stiff);
3536
T IP_val();
3637
void step_forward();
3738
void draw();
@@ -61,10 +62,18 @@ template <typename T, int dim>
6162
MovDirichletSimulator<T, dim>::Impl::Impl(T rho, T side_len, T initial_stretch, T K, T h_, T tol_, T mu_, int n_seg) : tol(tol_), h(h_), mu(mu_), window(sf::VideoMode(resolution, resolution), "MovDirichletSimulator")
6263
{
6364
generate(side_len, n_seg, x, e);
64-
std::vector<int> DBC(x.size() / dim, 0);
65+
DBC.push_back((n_seg + 1) * (n_seg + 1));
66+
DBC_target.resize(DBC.size() * dim);
67+
DBC_limit.push_back(0);
68+
DBC_limit.push_back(-0.6);
69+
DBC_v.push_back(0);
70+
DBC_v.push_back(-0.5);
71+
DBC_stiff = 10;
72+
x.push_back(0);
73+
x.push_back(side_len * 0.6);
6574
std::vector<T> contact_area(x.size() / dim, side_len / n_seg);
6675
std::vector<T> ground_n(dim);
67-
ground_n[0] = 0.1, ground_n[1] = 1;
76+
ground_n[0] = 0, ground_n[1] = 1;
6877
T n_norm = ground_n[0] * ground_n[0] + ground_n[1] * ground_n[1];
6978
n_norm = sqrt(n_norm);
7079
for (int i = 0; i < dim; i++)
@@ -94,7 +103,7 @@ MovDirichletSimulator<T, dim>::Impl::Impl(T rho, T side_len, T initial_stretch,
94103
gravityenergy = GravityEnergy<T, dim>(N, m);
95104
barrierenergy = BarrierEnergy<T, dim>(x, ground_n, ground_o, contact_area);
96105
frictionenergy = FrictionEnergy<T, dim>(v, h, ground_n);
97-
springenergy = SpringEnergy<T, dim>(x, std::vector<T>(N, m), DBC, std::vector<T>(N * dim, 0), std::vector<T>(N * dim, 0), 0, h);
106+
springenergy = SpringEnergy<T, dim>(x, std::vector<T>(N, m), DBC, DBC_stiff);
98107
DeviceBuffer<T> x_device(x);
99108
update_x(x_device);
100109
device_DBC = DeviceBuffer<int>(DBC);
@@ -132,15 +141,22 @@ void MovDirichletSimulator<T, dim>::Impl::step_forward()
132141
DeviceBuffer<T> x_tilde(x.size()); // Predictive position
133142
update_x_tilde(add_vector<T>(x, v, 1, h));
134143
frictionenergy.update_mu_lambda(barrierenergy.compute_mu_lambda(mu));
144+
update_DBC_target();
145+
update_DBC_stiff(10);
135146
DeviceBuffer<T> x_n = x; // Copy current positions to x_n
136147
update_v(add_vector<T>(x, x_n, 1 / h, -1 / h));
137148
int iter = 0;
138149
T E_last = IP_val();
139150
DeviceBuffer<T> p = search_direction();
140151
T residual = max_vector(p) / h;
141152
// std::cout << "Initial residual " << residual << "\n";
142-
while (residual > tol)
153+
while (residual > tol || DBC_satisfied.back() != 1) // use last one for simplisity, should check all
143154
{
155+
if (residual <= tol && DBC_satisfied.back() != 1)
156+
{
157+
update_DBC_stiff(DBC_stiff * 2);
158+
E_last = IP_val();
159+
}
144160
// Line search
145161
T alpha = barrierenergy.init_step_size(p);
146162
DeviceBuffer<T> x0 = x;
@@ -195,7 +211,35 @@ void MovDirichletSimulator<T, dim>::Impl::update_v(const DeviceBuffer<T> &new_v)
195211
template <typename T, int dim>
196212
void MovDirichletSimulator<T, dim>::Impl::update_DBC_target()
197213
{
198-
springenergy.update_DBC_target();
214+
for (int i = 0; i < DBC.size(); i++)
215+
{
216+
T diff = 0;
217+
for (int d = 0; d < dim; d++)
218+
{
219+
diff += (DBC_limit[i * dim + d] - x[DBC[i] * dim + d]) * DBC_v[i * dim + d];
220+
}
221+
if (diff > 0)
222+
{
223+
for (int d = 0; d < dim; d++)
224+
{
225+
DBC_target[i * dim + d] = x[DBC[i] * dim + d] + h * DBC_v[i * dim + d];
226+
}
227+
}
228+
else
229+
{
230+
for (int d = 0; d < dim; d++)
231+
{
232+
DBC_target[i * dim + d] = x[DBC[i] * dim + d];
233+
}
234+
}
235+
}
236+
springenergy.update_DBC_target(DBC_target);
237+
}
238+
template <typename T, int dim>
239+
void MovDirichletSimulator<T, dim>::Impl::update_DBC_stiff(T new_DBC_stiff)
240+
{
241+
DBC_stiff = new_DBC_stiff;
242+
springenergy.update_k(new_DBC_stiff);
199243
}
200244
template <typename T, int dim>
201245
void MovDirichletSimulator<T, dim>::Impl::draw()
@@ -212,28 +256,27 @@ void MovDirichletSimulator<T, dim>::Impl::draw()
212256
}
213257

214258
// Draw masses as circles
215-
for (int i = 0; i < x.size() / dim; ++i)
259+
for (int i = 0; i < (x.size() - 1) / dim; ++i)
216260
{
217261
sf::CircleShape circle(radius); // Set a fixed radius for each mass
218262
circle.setFillColor(sf::Color::Red);
219263
circle.setPosition(screen_projection_x(x[i * dim]) - radius, screen_projection_y(x[i * dim + 1]) - radius); // Center the circle on the mass
220264
window.draw(circle);
221265
}
222-
223266
window.display(); // Display the rendered frame
224267
}
225268

226269
template <typename T, int dim>
227270
T MovDirichletSimulator<T, dim>::Impl::IP_val()
228271
{
229272

230-
return inertialenergy.val() + (massspringenergy.val() + gravityenergy.val() + barrierenergy.val() + frictionenergy.val()) * h * h;
273+
return inertialenergy.val() + (massspringenergy.val() + gravityenergy.val() + barrierenergy.val() + frictionenergy.val()) * h * h + springenergy.val();
231274
}
232275

233276
template <typename T, int dim>
234277
DeviceBuffer<T> MovDirichletSimulator<T, dim>::Impl::IP_grad()
235278
{
236-
return add_vector<T>(add_vector<T>(add_vector<T>(add_vector<T>(inertialenergy.grad(), massspringenergy.grad(), 1.0, h * h), gravityenergy.grad(), 1.0, h * h), barrierenergy.grad(), 1.0, h * h), frictionenergy.grad(), 1.0, h * h);
279+
return add_vector<T>(add_vector<T>(add_vector<T>(add_vector<T>(add_vector<T>(inertialenergy.grad(), massspringenergy.grad(), 1.0, h * h), gravityenergy.grad(), 1.0, h * h), barrierenergy.grad(), 1.0, h * h), frictionenergy.grad(), 1.0, h * h), springenergy.grad(), 1.0, 1.0);
237280
}
238281

239282
template <typename T, int dim>
@@ -246,6 +289,8 @@ DeviceTripletMatrix<T, 1> MovDirichletSimulator<T, dim>::Impl::IP_hess()
246289
hess = add_triplet<T>(hess, barrier_hess, 1.0, h * h);
247290
DeviceTripletMatrix<T, 1> friction_hess = frictionenergy.hess();
248291
hess = add_triplet<T>(hess, friction_hess, 1.0, h * h);
292+
DeviceTripletMatrix<T, 1> spring_hess = springenergy.hess();
293+
hess = add_triplet<T>(hess, spring_hess, 1.0, 1.0);
249294
return hess;
250295
}
251296
template <typename T, int dim>
@@ -255,7 +300,21 @@ DeviceBuffer<T> MovDirichletSimulator<T, dim>::Impl::search_direction()
255300
dir.resize(x.size());
256301
DeviceBuffer<T> grad = IP_grad();
257302
DeviceTripletMatrix<T, 1> hess = IP_hess();
258-
search_dir<T, dim>(grad, hess, dir, device_DBC);
303+
// check whether each DBC is satisfied
304+
DBC_satisfied.resize(x.size() / dim, 0);
305+
for (int i = 0; i < DBC.size(); i++)
306+
{
307+
T diff = 0;
308+
for (int d = 0; d < dim; d++)
309+
{
310+
diff += (x[DBC[i] * dim + d] - DBC_target[i * dim + d]) * (x[DBC[i] * dim + d] - DBC_target[i * dim + d]);
311+
}
312+
if (diff / h < tol)
313+
{
314+
DBC_satisfied[DBC[i]] = 1;
315+
}
316+
}
317+
search_dir<T, dim>(grad, hess, dir, device_DBC, DBC_target, DBC_satisfied);
259318
return dir;
260319
}
261320

simulators/5_mov_dirichlet/src/uti.cu

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ template float min_vector<float>(const DeviceBuffer<float> &a);
117117
template double min_vector<double>(const DeviceBuffer<double> &a);
118118

119119
template <typename T, int dim>
120-
void search_dir(const DeviceBuffer<T> &grad, const DeviceTripletMatrix<T, 1> &hess, DeviceBuffer<T> &dir, const DeviceBuffer<int> &DBC)
120+
void search_dir(const DeviceBuffer<T> &grad, const DeviceTripletMatrix<T, 1> &hess, DeviceBuffer<T> &dir, const DeviceBuffer<int> &DBC, const std::vector<T> &DBC_target, std::vector<int> &DBC_satified)
121121
{
122122
static LinearSystemContext ctx;
123123
auto neg_grad = mult_vector<T>(grad, -1);
@@ -126,7 +126,7 @@ void search_dir(const DeviceBuffer<T> &grad, const DeviceTripletMatrix<T, 1> &he
126126
ctx.convert(hess, A_coo);
127127
DeviceCSRMatrix<T> A_csr;
128128
ctx.convert(A_coo, A_csr);
129-
set_DBC<T, dim>(neg_grad, A_csr, DBC);
129+
set_DBC<T, dim>(neg_grad, A_csr, DBC, DBC_target, DBC_satified);
130130
DeviceDenseVector<T> grad_device;
131131
grad_device.resize(N);
132132
grad_device.buffer_view().copy_from(neg_grad);
@@ -136,10 +136,10 @@ void search_dir(const DeviceBuffer<T> &grad, const DeviceTripletMatrix<T, 1> &he
136136
ctx.sync();
137137
dir.view().copy_from(dir_device.buffer_view());
138138
}
139-
template void search_dir<float, 2>(const DeviceBuffer<float> &grad, const DeviceTripletMatrix<float, 1> &hess, DeviceBuffer<float> &dir, const DeviceBuffer<int> &DBC);
140-
template void search_dir<float, 3>(const DeviceBuffer<float> &grad, const DeviceTripletMatrix<float, 1> &hess, DeviceBuffer<float> &dir, const DeviceBuffer<int> &DBC);
141-
template void search_dir<double, 2>(const DeviceBuffer<double> &grad, const DeviceTripletMatrix<double, 1> &hess, DeviceBuffer<double> &dir, const DeviceBuffer<int> &DBC);
142-
template void search_dir<double, 3>(const DeviceBuffer<double> &grad, const DeviceTripletMatrix<double, 1> &hess, DeviceBuffer<double> &dir, const DeviceBuffer<int> &DBC);
139+
template void search_dir<float, 2>(const DeviceBuffer<float> &grad, const DeviceTripletMatrix<float, 1> &hess, DeviceBuffer<float> &dir, const DeviceBuffer<int> &DBC, const std::vector<float> &DBC_target, std::vector<int> &DBC_satified);
140+
template void search_dir<float, 3>(const DeviceBuffer<float> &grad, const DeviceTripletMatrix<float, 1> &hess, DeviceBuffer<float> &dir, const DeviceBuffer<int> &DBC, const std::vector<float> &DBC_target, std::vector<int> &DBC_satified);
141+
template void search_dir<double, 2>(const DeviceBuffer<double> &grad, const DeviceTripletMatrix<double, 1> &hess, DeviceBuffer<double> &dir, const DeviceBuffer<int> &DBC, const std::vector<double> &DBC_target, std::vector<int> &DBC_satified);
142+
template void search_dir<double, 3>(const DeviceBuffer<double> &grad, const DeviceTripletMatrix<double, 1> &hess, DeviceBuffer<double> &dir, const DeviceBuffer<int> &DBC, const std::vector<double> &DBC_target, std::vector<int> &DBC_satified);
143143

144144
template <typename T>
145145
void display_vec(const DeviceBuffer<T> &vec)
@@ -155,15 +155,17 @@ void display_vec(const DeviceBuffer<T> &vec)
155155
}
156156
template void display_vec<float>(const DeviceBuffer<float> &vec);
157157
template void display_vec<double>(const DeviceBuffer<double> &vec);
158-
159158
template <typename T, int dim>
160-
void set_DBC(DeviceBuffer<T> &grad, DeviceCSRMatrix<T> &hess, const DeviceBuffer<int> &DBC)
159+
void set_DBC(DeviceBuffer<T> &grad, DeviceCSRMatrix<T> &hess, const DeviceBuffer<int> &DBC, const std::vector<T> &DBC_target, std::vector<int> &DBC_satified)
161160
{
162161
int N = hess.non_zeros();
163162
int Nr = hess.rows();
163+
int NDBC = DBC.size();
164+
DeviceBuffer<int> device_DBC_satisfied(NDBC);
165+
device_DBC_satisfied.copy_from(DBC_satified);
164166
ParallelFor(256)
165167
.apply(N,
166-
[hess_row_offsets = hess.row_offsets().cviewer(), hess_col_indices = hess.col_indices().cviewer(), hess_values = hess.values().viewer(), DBC = DBC.cviewer(), Nr] __device__(int i) mutable
168+
[hess_row_offsets = hess.row_offsets().cviewer(), hess_col_indices = hess.col_indices().cviewer(), hess_values = hess.values().viewer(), device_DBC_satisfied = device_DBC_satisfied.cviewer(), DBC = DBC.cviewer(), Nr] __device__(int i) mutable
167169
{
168170
// search for the row index
169171
int right = Nr;
@@ -183,18 +185,17 @@ void set_DBC(DeviceBuffer<T> &grad, DeviceCSRMatrix<T> &hess, const DeviceBuffer
183185
}
184186
int row = left - 1;
185187
int col = hess_col_indices(i);
186-
if (DBC(int(row / dim)) || DBC(int(col / dim)))
188+
if ((DBC(int(row / dim)) && device_DBC_satisfied(int(row / dim))) || (DBC(int(col / dim)) && device_DBC_satisfied(int(col / dim))))
187189
{
188190
hess_values(i) = row == col ? 1 : 0;
189191
}
190192
})
191193
.wait();
192-
int NDBC = DBC.size();
193194
ParallelFor(256)
194195
.apply(NDBC,
195-
[grad = grad.viewer(), DBC = DBC.cviewer()] __device__(int i) mutable
196+
[grad = grad.viewer(), DBC = DBC.cviewer(), device_DBC_satisfied = device_DBC_satisfied.viewer()] __device__(int i) mutable
196197
{
197-
if (DBC(i) == 1)
198+
if (DBC(i) == 1 && device_DBC_satisfied(i))
198199
{
199200
for (int d = 0; d < dim; d++)
200201
{
@@ -205,7 +206,7 @@ void set_DBC(DeviceBuffer<T> &grad, DeviceCSRMatrix<T> &hess, const DeviceBuffer
205206
.wait();
206207
}
207208

208-
template void set_DBC<float, 2>(DeviceBuffer<float> &grad, DeviceCSRMatrix<float> &hess, const DeviceBuffer<int> &DBC);
209-
template void set_DBC<float, 3>(DeviceBuffer<float> &grad, DeviceCSRMatrix<float> &hess, const DeviceBuffer<int> &DBC);
210-
template void set_DBC<double, 2>(DeviceBuffer<double> &grad, DeviceCSRMatrix<double> &hess, const DeviceBuffer<int> &DBC);
211-
template void set_DBC<double, 3>(DeviceBuffer<double> &grad, DeviceCSRMatrix<double> &hess, const DeviceBuffer<int> &DBC);
209+
template void set_DBC<float, 2>(DeviceBuffer<float> &grad, DeviceCSRMatrix<float> &hess, const DeviceBuffer<int> &DBC, const std::vector<float> &DBC_target, std::vector<int> &DBC_satified);
210+
template void set_DBC<float, 3>(DeviceBuffer<float> &grad, DeviceCSRMatrix<float> &hess, const DeviceBuffer<int> &DBC, const std::vector<float> &DBC_target, std::vector<int> &DBC_satified);
211+
template void set_DBC<double, 2>(DeviceBuffer<double> &grad, DeviceCSRMatrix<double> &hess, const DeviceBuffer<int> &DBC, const std::vector<double> &DBC_target, std::vector<int> &DBC_satified);
212+
template void set_DBC<double, 3>(DeviceBuffer<double> &grad, DeviceCSRMatrix<double> &hess, const DeviceBuffer<int> &DBC, const std::vector<double> &DBC_target, std::vector<int> &DBC_satified);

0 commit comments

Comments
 (0)