@@ -12,7 +12,7 @@ struct SpringEnergy<T, dim>::Impl
1212 DeviceBuffer<T> device_x;
1313 DeviceBuffer<T> device_m;
1414 DeviceBuffer<int > device_DBC;
15- DeviceBuffer<Eigen::Matrix<T, dim, 1 >> device_DBC_target;
15+ DeviceBuffer<T> device_DBC_target,device_DBC_v,device_DBC_limit ;
1616 DeviceBuffer<T> device_grad;
1717 DeviceTripletMatrix<T, 1 > device_hess;
1818 T k,h;
@@ -36,14 +36,16 @@ SpringEnergy<T, dim>::SpringEnergy(const SpringEnergy<T, dim> &rhs)
3636 : pimpl_{std::make_unique<Impl>(*rhs.pimpl_ )} {}
3737
3838template <typename T, int dim>
39- SpringEnergy<T, dim>::SpringEnergy(const std::vector<T> &x, const std::vector<T> &m, const std::vector<int > &DBC, const std::vector<T> &DBC_target, T k,T h)
39+ SpringEnergy<T, dim>::SpringEnergy(const std::vector<T> &x, const std::vector<T> &m, const std::vector<int > &DBC, const std::vector<T> &DBC_v, const std::vector<T> &DBC_limit, T k,T h)
4040 : pimpl_{std::make_unique<Impl>()}
4141{
4242 pimpl_->N = x.size () / dim;
4343 pimpl_->device_x .copy_from (x);
4444 pimpl_->device_m .copy_from (m);
4545 pimpl_->device_DBC .copy_from (DBC);
46- pimpl_->device_DBC_target .copy_from (DBC_target);
46+ pimpl_->device_DBC_v .copy_from (DBC_v);
47+ pimpl_->device_DBC_limit .copy_from (DBC_limit);
48+ pimpl_->device_DBC_target .resize (DBC.size () * dim);
4749 pimpl_->k = k;
4850 pimpl_->h = h;
4951 pimpl_->device_grad .resize (pimpl_->N * dim);
@@ -60,15 +62,39 @@ void SpringEnergy<T, dim>::update_x(const DeviceBuffer<T> &x)
// Recomputes pimpl_->device_DBC_target from the current positions (device_x),
// the constrained-node index list (device_DBC), the per-constraint velocity
// (device_DBC_v), the per-constraint limit point (device_DBC_limit) and the
// scalar h.
//
// For every constraint i in [0, device_DBC.size()):
//   idx = device_DBC(i)
//   d   = sum over j < dim of (limit(i*dim + j) - x(idx*dim + j)) * v(i*dim + j)
//   if d > 0 : target(i*dim + j) = x(idx*dim + j) + h * v(i*dim + j)
//   else     : target(i*dim + j) = x(idx*dim + j)
//
// Takes no arguments and returns nothing; the only buffer written here is
// device_DBC_target.
//
// NOTE(review): this text appears to be a diff rendering -- the leading digits
// and the '+' / '-' markers on each line look like old/new line numbers and
// change markers rather than program text. The '-' lines below are the
// pseudo-code comments that the change removes.
6062template <typename T, int dim>
6163void SpringEnergy<T, dim>::update_DBC_target()
6264{
63- // for i in range(0, len(DBC)):
64- // if (DBC_limit[i] - x_n[DBC[i]]).dot(DBC_v[i]) > 0:
65- // DBC_target.append(x_n[DBC[i]] + h * DBC_v[i])
66- // else:
67- // DBC_target.append(x_n[DBC[i]])
6865 auto &device_x = pimpl_->device_x ;
6966 auto &device_DBC = pimpl_->device_DBC ;
7067 auto &device_DBC_target = pimpl_->device_DBC_target ;
7168 auto h = pimpl_->h ;
69+ auto &device_DBC_v = pimpl_->device_DBC_v ;
70+ auto &device_DBC_limit = pimpl_->device_DBC_limit ;
// N is the number of constraints (entries of device_DBC), not the number of
// nodes; it is the iteration count handed to the parallel launch below.
71+ int N = device_DBC.size ();
// Zero the target buffer first. The launch below then writes entries
// i*dim + j for every i < N and j < dim.
// NOTE(review): if the buffer holds exactly N*dim entries this fill is
// overwritten in full -- confirm before removing it.
72+ device_DBC_target.fill (0 );
73+
// One invocation per constraint index i. Read-only views (cviewer) are
// captured for x, DBC, DBC_v and DBC_limit; a writable view (viewer) is
// captured for DBC_target; h is captured by value.
// NOTE(review): 256 is presumably the launch block size -- confirm against the
// ParallelFor API.
74+ ParallelFor (256 ).apply (N, [device_x = device_x.cviewer (), device_DBC = device_DBC.cviewer (), device_DBC_target = device_DBC_target.viewer (), device_DBC_v = device_DBC_v.cviewer (), h, device_DBC_limit = device_DBC_limit.cviewer ()] __device__ (int i) mutable
75+ {
// idx selects this constraint's components inside device_x (read at
// idx*dim + j); DBC_v, DBC_limit and DBC_target are read/written at i*dim + j.
// NOTE(review): no range check on idx is visible here -- assumes every
// device_DBC entry is a valid index into device_x; verify against the caller.
76+ int idx = device_DBC (i);
// d accumulates the dot product (limit - x) . v over the dim components.
77+ T d=0 ;
78+ for (int j = 0 ; j < dim; ++j)
79+ {
80+ d += (device_DBC_limit (i*dim + j) - device_x (idx * dim + j)) * (device_DBC_v (i*dim + j));
81+ }
// d > 0: advance the target from the current position by h * v.
82+ if (d>0 )
83+ {
84+ for (int j = 0 ; j < dim; ++j)
85+ {
86+ device_DBC_target (i*dim + j) = device_x (idx * dim + j) + h * device_DBC_v (i*dim + j);
87+ }
88+ }
// d <= 0: the target is the current position, unchanged.
89+ else
90+ {
91+ for (int j = 0 ; j < dim; ++j)
92+ {
93+ device_DBC_target (i*dim + j) = device_x (idx*dim + j);
94+ }
95+ }
// NOTE(review): wait() presumably blocks until the launch has finished, so
// device_DBC_target is complete on return -- confirm against the ParallelFor API.
96+ }).wait ();
97+
7298
7399}
74100
0 commit comments