12 #ifndef HELIB_MATRIX_H
13 #define HELIB_MATRIX_H
23 #include <type_traits>
24 #include <initializer_list>
26 #include <NTL/BasicThreadPool.h>
28 #include "assertions.h"
30 #include "zeroValue.h"
41 template <std::
size_t N>
50 template <
typename Iter1,
typename Iter2>
52 const Iter1& lastLength,
53 const Iter2& firstStride,
54 const Iter2& lastStride,
55 const std::vector<long>& st) :
58 std::copy(firstLength, lastLength, this->lengths.begin());
59 std::copy(firstStride, lastStride, this->strides.begin());
60 this->size = (std::accumulate(
lengths.begin(),
63 std::multiplies<std::size_t>()));
67 template <
typename Iter1,
typename Iter2>
69 const Iter1& lastLength,
70 const Iter2& firstStride,
71 const Iter2& lastStride,
75 std::copy(firstLength, lastLength, this->lengths.begin());
76 std::copy(firstStride, lastStride, this->strides.begin());
77 this->size = (std::accumulate(
lengths.begin(),
80 std::multiplies<std::size_t>()));
84 template <
typename... Dims>
90 std::multiplies<std::size_t>()))
93 this->strides.back() = 1;
99 template <
typename... Dims>
102 static_assert(
sizeof...(Dims) == N,
"Wrong number of indices given.");
104 std::array<std::size_t, N> args{std::size_t(dims)...};
106 for (
long i = 0; i < long(N); ++i) {
107 if (args[i] >= this->lengths[i]) {
109 "Index given: " + std::to_string(args[i]) +
110 ". Max value is: " + std::to_string(this->lengths[i]));
114 if (this->start.size() == 1) {
115 return std::inner_product(args.begin(),
117 this->strides.begin(),
121 return std::inner_product(args.begin(),
123 this->strides.begin(),
124 this->start.at(args.at(1)) *
strides.back());
128 std::size_t
order()
const {
return N; }
134 else if (this->size == rhs.
size && this->start == rhs.
start &&
147 template <
typename T, std::
size_t N>
153 std::shared_ptr<std::vector<T>> elements_ptr;
154 bool full_view =
true;
162 template <
typename U = T,
164 typename std::enable_if_t<
165 !std::is_convertible<U, std::size_t>::value>* =
nullptr>
168 elements_ptr(std::make_shared<std::vector<T>>(subscripts.
size, obj))
171 template <
typename... Dims>
173 subscripts{std::size_t(
dims)...},
174 elements_ptr(std::make_shared<std::vector<T>>(subscripts.
size))
178 Tensor(std::initializer_list<std::vector<T>> lst) :
179 subscripts{lst.
size(), lst.begin()->size()},
181 std::make_shared<std::vector<T>>(lst.size() * lst.begin()->size()))
183 int column_length = lst.begin()->size();
185 for (
const auto& v : lst) {
186 if (column_length !=
long(v.size()))
188 "Column dimensions do not match on initializer list.");
192 this->elements_ptr->begin() + (column_length * cnt++));
197 const std::shared_ptr<std::vector<T>>& elems) :
198 subscripts(ts), elements_ptr(elems), full_view(false)
209 std::size_t
order()
const {
return N; }
211 template <
typename... Args>
214 return this->elements_ptr->at(subscripts(args...));
217 template <
typename... Args>
220 return this->elements_ptr->at(subscripts(args...));
223 std::size_t
size()
const {
return this->subscripts.
size; }
225 std::size_t
dims(
int i)
const {
return this->subscripts.
lengths.at(i); }
233 }
else if (this->subscripts != rhs.subscripts) {
236 *elements_ptr == *(rhs.elements_ptr)) {
239 for (
size_t i = 0; i <
dims(0); ++i)
240 for (
size_t j = 0; j <
dims(1); ++j)
241 if (this->
operator()(i, j) != rhs(i, j)) {
253 this->subscripts.lengths.end(),
254 this->subscripts.strides.begin() + 1,
255 this->subscripts.strides.end(),
256 i * this->subscripts.strides.at(0));
257 return Tensor<T, N - 1>(ts, this->elements_ptr);
264 this->subscripts.lengths.end() - 1,
265 this->subscripts.strides.begin(),
266 this->subscripts.strides.end() - 1,
268 return Tensor<T, N - 1>(ts, this->elements_ptr);
275 for (
const auto& j : js)
276 assertInRange<LogicError>(
279 static_cast<long>(this->
dims(1)),
280 "Index for column does not exist. Given index " + std::to_string(j) +
281 ". Expected index in " +
"range [0, " +
282 std::to_string(this->
dims(1)) +
").");
285 std::vector<std::size_t> lengths = {this->
dims(0), js.size()};
287 std::vector<long> offsets(js);
288 for (std::size_t i = 0; i < offsets.size(); ++i) {
294 this->subscripts.strides.begin(),
295 this->subscripts.strides.end(),
301 template <
typename T2>
303 std::function<T&(T&,
const T2&)> operation)
306 std::array<std::size_t, N> rhs_subscripts;
307 for (std::size_t i = 0; i < N; ++i) {
308 rhs_subscripts[i] = rhs.
dims(i);
312 if (!std::equal(this->subscripts.
lengths.begin(),
313 this->subscripts.lengths.end(),
314 rhs_subscripts.begin())) {
319 if (
static_cast<const void*
>(&this->
data()) ==
320 static_cast<const void*
>(&rhs.
data())) {
325 if (this->full_view && rhs.
fullView()) {
326 const std::vector<T2>& rhs_v = rhs.
data();
327 for (std::size_t i = 0; i < this->elements_ptr->size(); ++i) {
328 operation((*this->elements_ptr)[i], rhs_v[i]);
332 for (std::size_t j = 0; j < this->
dims(1); ++j) {
333 for (std::size_t i = 0; i < this->
dims(0); ++i) {
334 operation(this->
operator()(i, j), rhs(i, j));
342 template <
typename T2>
345 return entrywiseOperation<T2>(
347 [](
auto& lhs,
const auto& rhs) -> decltype(
auto) {
352 template <
typename T2>
355 return entrywiseOperation<T2>(
357 [](
auto& lhs,
const auto& rhs) -> decltype(
auto) {
362 template <
typename T2>
365 return entrywiseOperation<T2>(
367 [](
auto& lhs,
const auto& rhs) -> decltype(
auto) {
375 if (this->full_view) {
376 NTL_EXEC_RANGE(
long(this->elements_ptr->size()), first, last)
377 for (
long i = first; i < last; ++i)
378 fn((*elements_ptr)[i]);
383 NTL_EXEC_RANGE(this->
dims(1), first, last)
384 for (
long j = first; j < last; ++j)
385 for (std::size_t i = 0; i < this->
dims(0); ++i)
386 fn(this->
operator()(i, j));
398 std::vector<int> permutation(
size());
399 std::iota(permutation.begin(), permutation.end(), 0);
400 for (
int& num : permutation)
404 std::vector<std::vector<int>> cycles;
405 std::vector<bool> seen(
size(),
false);
406 int num_processed = 0;
408 while (num_processed <
long(
size())) {
410 std::vector<int> cycle = {current_pos};
411 seen[current_pos] =
true;
412 while (permutation.at(cycle.back()) != cycle.front()) {
413 seen[permutation.at(cycle.back())] =
true;
414 cycle.push_back(permutation.at(cycle.back()));
416 num_processed += cycle.size();
417 cycles.push_back(std::move(cycle));
419 while (current_pos <
long(
size()) && seen[current_pos])
423 std::vector<std::pair<int, int>> swaps;
424 for (
const auto& cycle : cycles)
425 if (cycle.size() >= 2)
426 for (
int i = cycle.size() - 1; i > 0; --i)
427 swaps.emplace_back(cycle[i], cycle[i - 1]);
429 for (
const auto&
swap : swaps)
430 std::swap(elements_ptr->at(
swap.first), elements_ptr->at(
swap.second));
433 std::make_shared<std::vector<T>>(
size(),
data().front());
446 j = (j + 1) % subscripts.
lengths[1];
448 i = (i + 1) % subscripts.
lengths[0];
452 new_i = (new_i + 1) % subscripts.
lengths[1];
454 new_j = (new_j + 1) % subscripts.
lengths[0];
456 elements_ptr = new_elements;
459 std::reverse(subscripts.
lengths.begin(), subscripts.
lengths.end());
461 subscripts.
start = {0};
465 const std::vector<T>&
data()
const {
return *this->elements_ptr; }
469 template <
typename T,
471 typename std::enable_if_t<
472 std::is_convertible<T, std::size_t>::value>* =
nullptr>
475 HELIB_NTIMER_START(MatrixMultiplicationConv);
486 NTL_EXEC_RANGE(M1.
dims(0), first, last)
487 for (
long i = first; i < last; ++i)
489 for (std::size_t j = 0; j < M2.
dims(1); ++j)
490 for (std::size_t k = 0; k < M2.
dims(0); ++k) {
497 HELIB_NTIMER_STOP(MatrixMultiplicationConv);
502 template <
typename T,
504 typename std::enable_if_t<
505 !std::is_convertible<T, std::size_t>::value>* =
nullptr>
506 inline Tensor<T, 2>
operator*(
const Tensor<T, 2>& M1,
const Tensor<T2, 2>& M2)
508 HELIB_NTIMER_START(MatrixMultiplicationNotConv);
513 if (M1.dims(1) != M2.dims(0)) {
515 "The number of columns in left matrix (" + std::to_string(M1.dims(1)) +
516 ") do not match the number of rows of the right matrix (" +
517 std::to_string(M2.dims(0)) +
").");
521 NTL_EXEC_RANGE(M1.dims(0), first, last)
522 for (
long i = first; i < last; ++i)
525 for (std::size_t j = 0; j < M2.dims(1); ++j)
526 for (std::size_t k = 0; k < M2.dims(0); ++k) {
533 HELIB_NTIMER_STOP(MatrixMultiplicationNotConv);
538 template <
typename T>
541 template <
typename T>
545 template <
typename T>
550 std::vector<Matrix<T>> columns;
556 template <
typename T>
559 for (std::size_t i = 0; i < M.
dims(0); ++i) {
560 for (std::size_t j = 0; j < M.
dims(1); ++j)
561 out << M(i, j) <<
" ";
568 #endif // ifndef HELIB_MATRIX_H