set.h
1 /* Copyright (C) 2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  * http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
13 #ifndef HELIB_SET_H
14 #define HELIB_SET_H
15 
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include <helib/SumRegister.h>
#include <NTL/BasicThreadPool.h>
18 
19 namespace helib {
20 
21 // Binary tree Summation of Ctxts.
22 // Destructive but more efficient algorithm.
23 // TODO: Can write a generic binSum to work with pointers
24 // and make this a vector of unique pointers
31 template <typename TXT>
32 inline void binSumReduction(std::vector<TXT>& ctxtArray)
33 {
34  int cnt = 0;
35  int end = ctxtArray.size();
36 
37  while (end > 1) {
38  ++cnt;
39  int odd = end & 1;
40  int comps = end >> 1;
41 
42  NTL_EXEC_RANGE(comps, first, last)
43  for (unsigned i = first; i < last; i++) {
44  ctxtArray.at(i) += ctxtArray.at(comps + i + odd);
45  }
46  NTL_EXEC_RANGE_END
47 
48  // Free the end of the vector.
49  ctxtArray.erase(ctxtArray.begin() + comps + odd, ctxtArray.end());
50  ctxtArray.shrink_to_fit();
51 
52  // Update end.
53  end = (end + odd) >> 1;
54  }
55 }
56 
66 template <typename TXT>
67 inline TXT calculateSetIntersection(const TXT& query,
68  const std::vector<NTL::ZZX>& server_set)
69 {
70  long availableThreads =
71  std::min(NTL::AvailableThreads(), long(server_set.size()));
72  std::vector<TXT> interResult(availableThreads, query);
73 
74  NTL::PartitionInfo pinfo(server_set.size());
75 
76  NTL_EXEC_INDEX(availableThreads, index)
77  long first, last;
78  pinfo.interval(first, last, index);
79  SumRegister<TXT> sumRegister(last - first);
80 
81  for (long i = first; i < last; ++i) {
82  auto lquery = std::make_unique<TXT>(query);
83  Ptxt<BGV> entry(query.getContext(), server_set[i]);
84  *lquery -= entry;
85  mapTo01(*query.getContext().ea, *lquery);
86  lquery->negate();
87  lquery->addConstant(NTL::ZZX(1L));
88  sumRegister.add(lquery);
89  }
90 
91  sumRegister.flush();
92 
93  assertTrue(sumRegister.hasResult(), "Sum Register did not have a result.");
94 
95  interResult.at(index) = *(sumRegister.getResult());
96  NTL_EXEC_INDEX_END
97 
98  // Final binary sum to add the results of the sum registers
99  binSumReduction<TXT>(interResult);
100  return interResult.at(0) *= query;
101 }
102 
103 } // namespace helib
104 
105 #endif
An object that mimics the functionality of the Ctxt object, and acts as a convenient entry point for ...
Definition: Ptxt.h:280
void binSumReduction(std::vector< TXT > &ctxtArray)
Performs a binary summation of a vector of elements.
Definition: set.h:32
void assertTrue(const T &value, const std::string &message)
Definition: assertions.h:61
std::unique_ptr< T > getResult()
Get the result of the summation.
Definition: SumRegister.h:107
TXT calculateSetIntersection(const TXT &query, const std::vector< NTL::ZZX > &server_set)
Given two sets, calculates and returns the set intersection.
Definition: set.h:67
bool hasResult() const
Check result exists.
Definition: SumRegister.h:116
void flush()
Flush the binary tree to force producing a result on current tree.
Definition: SumRegister.h:127
Definition: apiAttributes.h:21
Class that performs a binary-tree summation incrementally as results arrive, keeping memory usage to a minimum.
Definition: SumRegister.h:31
void mapTo01(const EncryptedArray &ea, Ctxt &ctxt)
Definition: eqtesting.cpp:35
void add(std::unique_ptr< T > &t)
Add to the sum another object of type T.
Definition: SumRegister.h:67