1+ #include < muda/muda.h>
2+ #include < muda/container.h>
3+ #include " SparseMatrix.h"
4+
5+ using namespace muda ;
6+ template <typename T>
7+ SparseMatrix<T>::SparseMatrix(int size) : size(size)
8+ {
9+ row_idx = std::vector<int >(size);
10+ col_idx = std::vector<int >(size);
11+ val = std::vector<T>(size);
12+ }
13+ template <typename T>
14+ SparseMatrix<T>::SparseMatrix() = default ;
15+
16+ template <typename T>
17+ SparseMatrix<T>::~SparseMatrix ()
18+ {
19+ val.clear ();
20+ row_idx.clear ();
21+ col_idx.clear ();
22+ }
23+
24+ template <typename T>
25+ SparseMatrix<T>::SparseMatrix(SparseMatrix<T> &&rhs)
26+ {
27+ size = rhs.size ;
28+ row_idx = std::move (rhs.row_idx );
29+ col_idx = std::move (rhs.col_idx );
30+ val = std::move (rhs.val );
31+ }
32+
33+ template <typename T>
34+ SparseMatrix<T> &SparseMatrix<T>::operator =(SparseMatrix<T> &&rhs)
35+ {
36+ size = rhs.size ;
37+ row_idx = std::move (rhs.row_idx );
38+ col_idx = std::move (rhs.col_idx );
39+ val = std::move (rhs.val );
40+ return *this ;
41+ }
42+ template <typename T>
43+ SparseMatrix<T> &SparseMatrix<T>::operator =(SparseMatrix<T> &rhs)
44+ {
45+ size = rhs.size ;
46+ row_idx = rhs.row_idx ;
47+ col_idx = rhs.col_idx ;
48+ val = rhs.val ;
49+ return *this ;
50+ }
51+ template <typename T>
52+ SparseMatrix<T>::SparseMatrix(const SparseMatrix<T> &rhs)
53+ {
54+ size = rhs.size ;
55+ row_idx = rhs.row_idx ;
56+ col_idx = rhs.col_idx ;
57+ val = rhs.val ;
58+ }
59+
60+ template <typename T>
61+ void SparseMatrix<T>::set_value(int row, int col, T value, int loc)
62+ {
63+ assert (loc < size);
64+ row_idx[loc] = row;
65+ col_idx[loc] = col;
66+ val[loc] = value;
67+ }
68+ template <typename T>
69+ void SparseMatrix<T>::set_diagonal(T value)
70+ {
71+ for (int i = 0 ; i < size; i++)
72+ {
73+ set_value (i, i, value, i);
74+ }
75+ }
76+ template <typename T>
77+ SparseMatrix<T> &SparseMatrix<T>::combine(const SparseMatrix<T> &other)
78+ {
79+ int old_size = size;
80+ size += other.size ;
81+ row_idx.resize (size);
82+ col_idx.resize (size);
83+ val.resize (size);
84+ // copy memory
85+ for (int i = 0 ; i < other.size ; i++)
86+ {
87+ set_value (other.row_idx [i], other.col_idx [i], other.val [i], i + old_size);
88+ }
89+ return *this ;
90+ }
91+ template <typename T>
92+ SparseMatrix<T> &SparseMatrix<T>::operator *(const T &a)
93+ {
94+ DeviceBuffer<T> val_device (val);
95+ int N = val.size ();
96+ DeviceBuffer<T> c_device (N);
97+ ParallelFor (256 )
98+ .apply (N,
99+ [c_device = c_device.viewer (), val_device = val_device.cviewer (), a] __device__ (int i) mutable
100+ {
101+ c_device (i) = val_device (i) * a;
102+ })
103+ .wait ();
104+ c_device.copy_to (val);
105+ return *this ;
106+ }
107+ template <typename T>
108+ const std::vector<int > &SparseMatrix<T>::get_row_buffer() const
109+ {
110+ return row_idx;
111+ }
112+ template <typename T>
113+ const std::vector<int > &SparseMatrix<T>::get_col_buffer() const
114+ {
115+ return col_idx;
116+ }
117+ template <typename T>
118+ const std::vector<T> &SparseMatrix<T>::get_val_buffer() const
119+ {
120+ return val;
121+ }
122+
123+ template <typename T>
124+ std::vector<int > &SparseMatrix<T>::set_row_buffer()
125+ {
126+ return row_idx;
127+ }
128+ template <typename T>
129+ std::vector<int > &SparseMatrix<T>::set_col_buffer()
130+ {
131+ return col_idx;
132+ }
133+ template <typename T>
134+ std::vector<T> &SparseMatrix<T>::set_val_buffer()
135+ {
136+ return val;
137+ }
138+ template <typename T>
139+ int SparseMatrix<T>::get_size() const
140+ {
141+ return size;
142+ }
143+ template class SparseMatrix <float >;
144+ template class SparseMatrix <double >;
0 commit comments