SumRegister.h
1 /* Copyright (C) 2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  * http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
13 #ifndef HELIB_SUMREGISTER_H
14 #define HELIB_SUMREGISTER_H
15 
16 #include <iostream>
17 #include <vector>
18 #include <cmath>
19 #include <memory>
20 
21 namespace helib {
22 
/**
 * @brief Class to do a binary tree summation as results appear, to keep
 * memory usage to a minimum.
 *
 * Inputs are combined pairwise as they arrive, much like the bits of a binary
 * counter. Before any flush, a non-null slot `i` of the intermediate results
 * holds the partial sum of `2^i` inputs. At most `depth + 1` partial sums are
 * therefore kept at any time.
 *
 * @tparam T The type being summed. It must support `operator+=`, and
 * `operator<<` on `std::ostream` if `print()` is used.
 **/
template <typename T>
class SumRegister
{
private:
  // Slot i holds a partial sum for tree level i, or nullptr if empty.
  std::vector<std::unique_ptr<T>> intermediateResults;

  unsigned int maxNumOfInputs;
  unsigned int remainingInputs;
  // True when the tree must be flushed after the last input because
  // maxNumOfInputs is zero or not a power of 2.
  bool flushRequiredFlag = false;
  bool resultFlag = false;
  size_t depth = 0;

public:
  /**
   * @brief Constructor.
   * @param _maxNumOfInputs Number of inputs that will be summed. Inputs
   * beyond this count are ignored by `add()`.
   **/
  SumRegister(unsigned int _maxNumOfInputs) :
      maxNumOfInputs(_maxNumOfInputs), remainingInputs(_maxNumOfInputs)
  {
    // depth = ceil(log2(_maxNumOfInputs)), computed exactly with integers
    // (no floating-point rounding). It is 0 for 0 or 1 inputs.
    while ((1ULL << this->depth) < _maxNumOfInputs)
      ++this->depth;

    // Set if NOT power of 2
    if ((_maxNumOfInputs & (_maxNumOfInputs - 1)) || (_maxNumOfInputs == 0)) {
      this->flushRequiredFlag = true;
    }

    // One slot per tree level; the default size is 1.
    this->intermediateResults.resize(depth + 1);
  }

  /**
   * @brief Add to the sum another object of type T.
   * @param t The object to add; it must be non-null.
   * @note Ownership depends on the register's state:
   * - If the lowest slot is empty, `t` is moved into the register and left
   *   null.
   * - Otherwise `*t` is added into the register and `t` is left unchanged.
   *
   * Calls made after `maxNumOfInputs` inputs have been added do nothing.
   * When the last input arrives and `maxNumOfInputs` is not a power of 2,
   * the tree is flushed automatically.
   **/
  void add(std::unique_ptr<T>& t)
  {
    if (this->remainingInputs == 0) {
      return;
    }
    this->remainingInputs--;

    if (this->intermediateResults.at(0) == nullptr) {
      // Lowest level is empty: park this input until its partner arrives.
      intermediateResults.at(0) = std::move(t);
    } else {
      // Pair with the pending input, then carry upwards like a binary
      // counter: merge into each occupied level, stop at the first empty one.
      *this->intermediateResults.at(0) += *t;
      for (size_t i = 1; i < this->intermediateResults.size(); i++) {
        std::unique_ptr<T>& lower = this->intermediateResults[i - 1];
        std::unique_ptr<T>& upper = this->intermediateResults[i];
        if (upper != nullptr) {
          *upper += *lower;
          lower.reset();
        } else {
          upper = std::move(lower);
          break;
        }
      }
    }

    // Checked after both branches so that a single-input register (depth 0),
    // whose only input lands directly in the top slot, reports its result.
    if (intermediateResults.at(depth) != nullptr)
      resultFlag = true;

    if ((this->remainingInputs == 0) && this->flushRequiredFlag) {
      this->flush();
    }
  }

  /**
   * @brief Get the result of the summation.
   * @return The top-level sum, moved out of the register. It is nullptr if no
   * result is available or it was already retrieved.
   * @note `hasResult()` is not reset by this call.
   **/
  std::unique_ptr<T> getResult()
  {
    return std::move(intermediateResults.at(depth));
  }

  /**
   * @brief Check whether a result exists.
   * @return true once the top slot of the tree has been filled.
   **/
  bool hasResult() const { return resultFlag; }

  /**
   * @brief Get the depth of the summation binary tree.
   * @return ceil(log2(maxNumOfInputs)), or 0 for 0 or 1 inputs.
   **/
  size_t getDepth() const { return depth; }

  /**
   * @brief Flush the binary tree to force it to produce a result from its
   * current contents.
   *
   * Does nothing if a result already exists.
   **/
  void flush()
  {
    // Flushing with a result should do nothing.
    if (this->hasResult()) {
      return;
    }

    // Push every partial sum one level up, merging with whatever already
    // occupies that level, until everything has collected in the top slot.
    for (size_t i = 0; i < depth; i++) {
      std::unique_ptr<T>& lower = this->intermediateResults[i];
      if (lower == nullptr)
        continue;

      std::unique_ptr<T>& upper = this->intermediateResults[i + 1];
      if (upper == nullptr) {
        upper = std::move(lower);
      } else {
        *upper += *lower;
        lower.reset();
      }
    }

    if (intermediateResults.at(depth) != nullptr)
      resultFlag = true;
  }

  /**
   * @brief Print the information in the binary tree.
   *
   * For each slot, prints its value (0 if empty) followed by its pointer.
   **/
  void print() const
  {
    std::cout << "Current values\n";
    for (size_t i = 0; i < this->intermediateResults.size(); i++) {
      const std::unique_ptr<T>& slot = this->intermediateResults[i];
      std::cout << "[" << i << "]: ";
      // Separate branches avoid forcing a common type between T and int.
      if (slot != nullptr)
        std::cout << *slot;
      else
        std::cout << 0;
      // Streaming a unique_ptr directly needs C++20; print the raw address.
      std::cout << " (" << static_cast<const void*>(slot.get()) << ")" << '\n';
    }
  }

  /**
   * @brief Get the intermediate results.
   * @return A mutable reference to the per-level partial sums.
   **/
  std::vector<std::unique_ptr<T>>& getIntermediates()
  {
    return intermediateResults;
  }
};
174 
175 } // namespace helib
176 #endif // HELIB_SUMREGISTER_H
SumRegister(unsigned int _maxNumOfInputs)
Constructor.
Definition: SumRegister.h:46
void print()
Print the information in the binary tree.
Definition: SumRegister.h:153
size_t getDepth() const
Get the depth of the summation binary tree.
Definition: SumRegister.h:122
std::vector< std::unique_ptr< T > > & getIntermediates()
Get the intermediate results.
Definition: SumRegister.h:169
std::unique_ptr< T > getResult()
Get the result of the summation.
Definition: SumRegister.h:107
bool hasResult() const
Check whether a result exists.
Definition: SumRegister.h:116
void flush()
Flush the binary tree to force it to produce a result from its current contents.
Definition: SumRegister.h:127
Definition: apiAttributes.h:21
Class that performs a binary-tree summation incrementally as inputs arrive, keeping memory usage to a minimum.
Definition: SumRegister.h:31
void add(std::unique_ptr< T > &t)
Add another object of type T to the sum.
Definition: SumRegister.h:67