partialMatch.h
1 /* Copyright (C) 2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  * http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
13 #ifndef HELIB_PARTIALMATCH_H
14 #define HELIB_PARTIALMATCH_H
15 
16 #include <sstream>
17 #include <stack>
18 
19 #include <helib/Matrix.h>
20 #include <helib/PolyMod.h>
21 
22 // This code is in flux and should be considered highly experimental (alpha).
23 // Not recommended for public use.
24 
25 namespace helib {
26 
// NOTE(review): the Doxygen listing this was taken from omits the
// declarator line (source line 39). Per the Doxygen index it reads:
//   Matrix<TXT> calculateMasks(const EncryptedArray& ea,
/**
 * @brief Given a query set and a database, calculates a mask of {0,1} where
 * 1 signifies a matching element (plaintext Ptxt<BGV> database overload).
 *
 * @tparam TXT Entry type of the query matrix (and of the returned mask).
 * @param ea Encrypted array forwarded to mapTo01.
 * @param query A 1 x C row vector. Taken by value and reused as the output
 *        buffer, so the caller's matrix is not modified.
 * @param database An R x C plaintext database.
 * @return An R x C matrix whose (i, j) entry is
 *         1 - mapTo01(query(0, j) - database(i, j)). Per the Doxygen summary
 *         this is 1 where the elements match and 0 otherwise.
 * @throws InvalidArgument if query is not a single row, or if query and
 *         database have different numbers of columns.
 */
38 template <typename TXT>
40  Matrix<TXT> query,
41  const Matrix<Ptxt<BGV>>& database)
42 {
43  if (query.dims(0) != 1)
44  throw InvalidArgument("Query must be a row vector");
45  if (query.dims(1) != database.dims(1))
46  throw InvalidArgument(
47  "Database and query must have same number of columns");
48  // TODO: case where query.dims(0) != database.dims(0)
49 
50  // Replicate the query once per row of the database
51  // TODO: Some such replication will be needed once blocks/bands exist
// Selecting column 0 of the transposed (C x 1) query R times and
// transposing back yields an R x C matrix whose rows are all the query.
52  std::vector<long> columns(database.dims(0), 0l);
// `mask` aliases the by-value `query`, so no further copy is made.
53  Matrix<TXT>& mask = query;
54  mask.transpose();
55  mask = mask.columns(columns);
56  mask.transpose();
57 
// Entrywise: diff = query - database; diff = mapTo01(diff); mask = 1 - diff
// (negate, then add the constant 1).
58  (mask -= database)
59  .apply([&](auto& entry) { mapTo01(ea, entry); })
60  .apply([](auto& entry) { entry.negate(); })
61  .apply([](auto& entry) { entry.addConstant(NTL::ZZX(1l)); });
62 
63  return mask;
64 }
65 
// NOTE(review): the Doxygen listing this was taken from omits the
// declarator line (source line 78). It presumably mirrors the Ptxt<BGV>
// overload, i.e. `Matrix<TXT> calculateMasks(const EncryptedArray& ea,`
// -- TODO confirm.
// NOTE(review): the body is identical to the Ptxt<BGV> overload; the two
// could share one implementation templated on the database entry type.
/**
 * @brief Given a query set and a database, calculates a mask of {0,1} where
 * 1 signifies a matching element (encrypted Ctxt database overload).
 *
 * @tparam TXT Entry type of the query matrix (and of the returned mask).
 * @param ea Encrypted array forwarded to mapTo01.
 * @param query A 1 x C row vector. Taken by value and reused as the output
 *        buffer, so the caller's matrix is not modified.
 * @param database An R x C encrypted database.
 * @return An R x C matrix whose (i, j) entry is
 *         1 - mapTo01(query(0, j) - database(i, j)).
 * @throws InvalidArgument if query is not a single row, or if query and
 *         database have different numbers of columns.
 */
77 template <typename TXT>
79  Matrix<TXT> query,
80  const Matrix<Ctxt>& database)
81 {
82  if (query.dims(0) != 1)
83  throw InvalidArgument("Query must be a row vector");
84  if (query.dims(1) != database.dims(1))
85  throw InvalidArgument(
86  "Database and query must have same number of columns");
87  // TODO: case where query.dims(0) != database.dims(0)
88 
89  // Replicate the query once per row of the database
90  // TODO: Some such replication will be needed once blocks/bands exist
// Selecting column 0 of the transposed (C x 1) query R times and
// transposing back yields an R x C matrix whose rows are all the query.
91  std::vector<long> columns(database.dims(0), 0l);
// `mask` aliases the by-value `query`, so no further copy is made.
92  Matrix<TXT>& mask = query;
93  mask.transpose();
94  mask = mask.columns(columns);
95  mask.transpose();
96 
// Entrywise: diff = query - database; diff = mapTo01(diff); mask = 1 - diff
// (negate, then add the constant 1).
97  (mask -= database)
98  .apply([&](auto& entry) { mapTo01(ea, entry); })
99  .apply([](auto& entry) { entry.negate(); })
100  .apply([](auto& entry) { entry.addConstant(NTL::ZZX(1l)); });
101 
102  return mask;
103 }
104 
// NOTE(review): the Doxygen listing this was taken from omits the
// declarator line (source line 119). Per the Doxygen index it reads:
//   Matrix<TXT> calculateScores(
/**
 * @brief Given a mask and information about the query to be performed,
 * calculates a score for each row of the mask.
 *
 * With R = mask.dims(0), the result is an R x 1 column vector whose r-th
 * entry is the product over i of
 *   offsets[i] + sum_k weights[i](k, 0) * mask(r, index_sets[i][k])
 * (assuming Matrix's operator* is the usual matrix product).
 *
 * @param index_sets One list of column indices per factor.
 *        NOTE(review): taken by (const) value, which copies the whole
 *        vector on every call; `const std::vector<std::vector<long>>&`
 *        would avoid that.
 * @param offsets One constant offset per index set.
 * @param weights One column vector per index set, with one weight per
 *        index in that set.
 * @param mask An R x C matrix, e.g. as produced by calculateMasks().
 * @return The R x 1 matrix of scores.
 * @throws InvalidArgument if index_sets, offsets and weights differ in
 *         length, or if a weight set is not a column vector whose height
 *         equals its index set's size.
 * NOTE(review): mask(0, 0) is read unconditionally, so an empty mask is not
 * supported.
 */
118 template <typename TXT>
120  const std::vector<std::vector<long>> index_sets,
121  const std::vector<long>& offsets,
122  const std::vector<Matrix<long>>& weights,
123  const Matrix<TXT>& mask)
124 {
125  assertEq<InvalidArgument>(index_sets.size(),
126  offsets.size(),
127  "index_sets and offsets must have matching size");
128  assertEq<InvalidArgument>(index_sets.size(),
129  weights.size(),
130  "index_sets and weights must have matching size");
// Build the constant 1 as a TXT set up like the mask entries: copy an
// entry, clear it, then add 1.
131  auto ones(mask(0, 0));
132  ones.clear();
133  ones.addConstant(NTL::ZZX(1L));
// Running product of the factors; an R x 1 matrix presumably filled with
// copies of `ones`.
134  Matrix<TXT> result(ones, mask.dims(0), 1l);
135  for (std::size_t i = 0; i < index_sets.size(); ++i) {
136  const auto& index_set = index_sets.at(i);
137  const auto& weight_set = weights.at(i);
138  long offset = offsets.at(i);
139 
140  assertEq<InvalidArgument>(
141  weight_set.dims(0),
142  index_set.size(),
143  "found mismatch between index set size and weight set size");
144 
145  assertEq<InvalidArgument>(weight_set.dims(1),
146  1lu,
147  "all weight sets must be column vectors");
148 
149  Matrix<TXT> submatrix = mask.columns(index_set);
150  Matrix<TXT> factor(submatrix * weight_set);
151  // factor is an R x 1 column vector (R = mask.dims(0)): submatrix is R x |index_set| and weight_set is |index_set| x 1, so it is 1 x 1 only when R == 1
152  factor.apply([&](auto& entry) { entry.addConstant(NTL::ZZX(offset)); });
// Multiply each row's running score by this factor, in place.
153  result.template entrywiseOperation<TXT>(
154  factor,
155  [](auto& lhs, const auto& rhs) -> decltype(auto) {
156  lhs.multiplyBy(rhs);
157  return lhs;
158  });
159  }
160  return result;
161 }
162 
171 inline PolyMod partialMatchEncode(uint32_t input, const Context& context)
172 {
173  const long p = context.zMStar.getP();
174  std::vector<long> coeffs(context.zMStar.getOrdP());
175  // TODO - shouldn't keep checking input.
176  for (long i = 0; i < long(coeffs.size()) && input != 0; ++i) {
177  coeffs[i] = input % p;
178  input /= p;
179  }
180  return PolyMod(coeffs, context.slotRing);
181 }
182 
183 struct Expr;
184 class ColNumber;
185 
189 using QueryExpr = std::shared_ptr<Expr>;
190 
197 inline std::shared_ptr<ColNumber> makeQueryExpr(int cl)
198 {
199  return std::make_shared<ColNumber>(cl);
200 }
201 
/**
 * @brief Base structure for logical expressions.
 *
 * Derived nodes render themselves with eval() as a string of
 * space-separated tokens in reverse Polish notation.
 */
struct Expr
{
  /// Render this (sub)expression as a reverse-Polish token string.
  virtual std::string eval() const = 0;

  /// Virtual so that nodes can be destroyed through an Expr pointer.
  virtual ~Expr() = default;
};
212 
218 class ColNumber : public Expr
219 {
220 public:
225  std::string eval() const override { return std::to_string(column); }
226 
231  ColNumber(int c) : column(c) {}
232 
233 private:
234  int column;
235 };
236 
242 class And : public Expr
243 {
244 public:
252  std::string eval() const override
253  {
254  return lhs->eval() + " " + rhs->eval() + " &&";
255  }
256 
262  And(const QueryExpr& l, const QueryExpr& r) : lhs(l), rhs(r) {}
263 
264 private:
265  QueryExpr lhs;
266  QueryExpr rhs;
267 };
268 
274 class Or : public Expr
275 {
276 public:
284  std::string eval() const override
285  {
286  return lhs->eval() + " " + rhs->eval() + " ||";
287  }
288 
294  Or(const QueryExpr& l, const QueryExpr& r) : lhs(l), rhs(r) {}
295 
296 private:
297  QueryExpr lhs;
298  QueryExpr rhs;
299 };
300 
308 inline std::shared_ptr<And> operator&&(const QueryExpr& lhs,
309  const QueryExpr& rhs)
310 {
311  return std::make_shared<And>(lhs, rhs);
312 }
313 
321 inline std::shared_ptr<Or> operator||(const QueryExpr& lhs,
322  const QueryExpr& rhs)
323 {
324  return std::make_shared<Or>(lhs, rhs);
325 }
326 
331 struct Query_t
332 {
337  std::vector<std::vector<long>> Fs;
338 
343  std::vector<long> mus;
344 
350  std::vector<Matrix<long>> taus;
351 
356  bool containsOR = false;
357 
366  Query_t(const std::vector<std::vector<long>>& index_sets,
367  const std::vector<long>& offsets,
368  const std::vector<Matrix<long>>& weights,
369  const bool isThereAnOR) :
370  Fs(index_sets), mus(offsets), taus(weights), containsOR(isThereAnOR)
371  {}
372 
381  Query_t(std::vector<std::vector<long>>&& index_sets,
382  std::vector<long>&& offsets,
383  std::vector<Matrix<long>>&& weights,
384  bool isThereAnOR) :
385  Fs(index_sets), mus(offsets), taus(weights), containsOR(isThereAnOR)
386  {}
387 };
388 
395 {
396 
397  // Outer vector: the AND-ed clauses; each inner vector: the column indices OR-ed together within that clause.
398  using vecvec = std::vector<std::vector<long>>;
399 
400 public:
406  QueryBuilder(const QueryExpr& expr) : query_str(expr->eval()) {}
407 
414  Query_t build(long columns) const
415  {
416 
417  // Convert the query to "type 1" by expanding out necessary ORs
418  vecvec expr = expandOr(query_str);
419  bool containsOR = false;
420 
421  vecvec Fs(expr.size());
422  {
423  std::vector<long> v(columns);
424  std::iota(v.begin(), v.end(), 0);
425  std::fill(Fs.begin(), Fs.end(), v);
426  }
427  std::vector<long> mus(expr.size(), 1);
428  std::vector<Matrix<long>> taus;
429  taus.reserve(expr.size());
430 
431  // Create the taus
432  for (long i = 0; i < long(expr.size()); ++i) { // Each tau
433  mus[i] = 0; // Set mu to zero.
434  Matrix<long> M(columns, 1); // Create temp tau matrix
435  containsOR = (expr[i].size() > 1) ? true : false;
436  for (long j = 0; j < long(expr[i].size()); ++j) // Each column index
437  M(expr[i][j], 0) = 1; // Mark those columns as 1
438  taus.push_back(std::move(M));
439  }
440 
441  return Query_t(std::move(Fs), std::move(mus), std::move(taus), containsOR);
442  }
443 
444 private:
445  std::string query_str;
446 
447  void printStack(std::stack<vecvec> stack)
448  {
449  while (!stack.empty()) {
450  printVecVec(stack.top());
451  stack.pop();
452  }
453  }
454 
455  void printVecVec(const vecvec& vv)
456  {
457  for (const auto& v : vv) {
458  std::cout << "[ ";
459  for (const auto& e : v) {
460  std::cout << e << " ";
461  }
462  std::cout << "]";
463  }
464  std::cout << "\n";
465  }
466 
467  bool isNumber(const std::string& s) const
468  {
469  // Positive only
470  return std::all_of(s.begin(), s.end(), ::isdigit);
471  }
472 
473  vecvec expandOr(const std::string& s) const
474  {
475  std::stack<vecvec> convertStack;
476 
477  std::istringstream input{s};
478  std::ostringstream output{};
479 
480  std::string symbol;
481 
482  while (input >> symbol) {
483  if (!symbol.compare("&&")) {
484  // Squash the top into penultimate.
485  auto op = convertStack.top();
486  convertStack.pop();
487  auto& top = convertStack.top();
488  top.insert(top.end(), op.begin(), op.end());
489  } else if (!symbol.compare("||")) {
490  // Cartesian-esque product
491  auto op1 = convertStack.top();
492  convertStack.pop();
493  auto op2 = convertStack.top();
494  convertStack.pop();
495 
496  vecvec prod;
497  prod.reserve(op1.size() * op2.size());
498  for (const auto& i : op1)
499  for (const auto& j : op2) {
500  auto x = i;
501  x.insert(x.end(), j.begin(), j.end());
502  prod.push_back(std::move(x));
503  }
504 
505  convertStack.push(std::move(prod));
506  } else {
507  // Assume it is a number. But sanity check anyway.
508  assertTrue(isNumber(symbol),
509  "String is not a number: '" + symbol + "'");
510  convertStack.emplace(vecvec(1, {std::stol(symbol)}));
511  }
512  }
513 
514  // Now read answer off stack (should be size == 1).
515  assertEq<LogicError>(1UL,
516  convertStack.size(),
517  "Size of stack after expandOr should be 1");
518 
519  return std::move(convertStack.top());
520  }
521 };
522 
529 template <typename TXT>
530 class Database
531 {
532 public:
533  // FIXME: Generally, should Database own the Matrix uniquely?
534  // Should we force good practice and ask that Context always be shared_ptr?
535 
536  // FIXME: Should probably move Matrix or make it unique_ptr or both?
542  Database(const Matrix<TXT>& M, std::shared_ptr<const Context> c) :
543  data(M), context(c)
544  {}
545 
546  // FIXME: Should this option really exist?
555  Database(const Matrix<TXT>& M, const Context& c) :
556  data(M),
557  context(std::shared_ptr<const helib::Context>(&c, [](auto UNUSED p) {}))
558  {}
559 
560  // FIXME: Combination of TXT = ctxt and TXT2 = ptxt does not work
571  template <typename TXT2>
572  Matrix<TXT2> contains(const Query_t& lookup_query,
573  const Matrix<TXT2>& query_data) const;
574 
575  // FIXME: Combination of TXT = ctxt and TXT2 = ptxt does not work
585  template <typename TXT2>
586  Matrix<TXT2> getScore(const Query_t& weighted_query,
587  const Matrix<TXT2>& query_data) const;
588 
589  // TODO - correct name?
594  long columns() { return data.dims(1); }
595 
596 private:
597  Matrix<TXT> data;
598  std::shared_ptr<const Context> context;
599 };
600 
// NOTE(review): the Doxygen listing this was taken from omits the
// declarator line (source line 603). Per the Doxygen index the member is
//   Matrix<TXT2> contains(const Query_t& lookup_query,
//                         const Matrix<TXT2>& query_data) const
// so, as an out-of-line member template, the missing line presumably reads
//   Matrix<TXT2> Database<TXT>::contains(   -- TODO confirm.
/**
 * @brief Function for performing a database lookup given a query expression
 * and query data.
 *
 * Computes getScore(lookup_query, query_data). If lookup_query.containsOR is
 * set, each score is then raised to the power
 * context->alMod.getPPowR() - 1.
 * @param lookup_query The query (index sets, offsets, weights, OR flag).
 * @param query_data Query row vector, passed on to getScore().
 * @return The (possibly exponentiated) scores, one per database row.
 */
601 template <typename TXT>
602 template <typename TXT2>
604  const Query_t& lookup_query,
605  const Matrix<TXT2>& query_data) const
606 {
607  auto result = getScore<TXT2>(lookup_query, query_data);
608 
// A clause that ORs several columns contributes a sum of {0,1} mask
// entries, which can exceed 1; the exponentiation below is meant to map
// such scores back to {0,1}.
// NOTE(review): getPPowR() is presumably p^r; x^(p^r - 1) is Fermat's
// little theorem when r == 1 -- confirm the intended behaviour for r > 1.
609  if (lookup_query.containsOR) {
610  // FLT on the scores
611  result.apply([&](auto& txt) {
// NOTE(review): apply() takes a std::function<void(T&)> (see Matrix.h), so
// the value returned below is discarded. Because the generic lambda returns
// by value, it still copies every entry (a whole ciphertext when TXT2 is
// Ctxt); dropping `return txt;` would avoid that copy.
612  txt.power(context->alMod.getPPowR() - 1);
613  return txt;
614  });
615  }
616 
617  return result;
618 }
619 
// NOTE(review): the Doxygen listing this was taken from omits the
// declarator line (source line 622). Per the Doxygen index the member is
//   Matrix<TXT2> getScore(const Query_t& weighted_query,
//                         const Matrix<TXT2>& query_data) const
// so, as an out-of-line member template, the missing line presumably reads
//   Matrix<TXT2> Database<TXT>::getScore(   -- TODO confirm.
/**
 * @brief Function for performing a weighted partial match given a query
 * expression and query data.
 *
 * First builds the {0,1} match mask of query_data against the stored data
 * with calculateMasks(), then combines it into per-row scores with
 * calculateScores() using the query's index sets (Fs), offsets (mus) and
 * weights (taus).
 * @param weighted_query The weighted query to perform.
 * @param query_data The query row vector.
 * @return The R x 1 matrix of scores, one per database row.
 */
620 template <typename TXT>
621 template <typename TXT2>
623  const Query_t& weighted_query,
624  const Matrix<TXT2>& query_data) const
625 {
// calculateMasks takes the query by value, so query_data is copied once
// here and left unmodified.
626  auto mask = calculateMasks(*(context->ea), query_data, this->data);
627 
628  auto result = calculateScores(weighted_query.Fs,
629  weighted_query.mus,
630  weighted_query.taus,
631  mask);
632 
633  return result;
634 }
635 
636 } // namespace helib
637 
638 #endif
An object that mimics the functionality of the Ctxt object, and acts as a convenient entry point for ...
Definition: Ptxt.h:280
An object used to construct a Query_t object from a logical expression.
Definition: partialMatch.h:395
Database(const Matrix< TXT > &M, std::shared_ptr< const Context > c)
Constructor.
Definition: partialMatch.h:542
Matrix< TXT > calculateMasks(const EncryptedArray &ea, Matrix< TXT > query, const Matrix< Ptxt< BGV >> &database)
Given a query set and a database, calculates a mask of {0,1} where 1 signifies a matching element and...
Definition: partialMatch.h:39
Matrix< TXT > calculateScores(const std::vector< std::vector< long >> index_sets, const std::vector< long > &offsets, const std::vector< Matrix< long >> &weights, const Matrix< TXT > &mask)
Given a mask and information about the query to be performed, calculates a score for each matching el...
Definition: partialMatch.h:119
void assertTrue(const T &value, const std::string &message)
Definition: assertions.h:61
QueryBuilder(const QueryExpr &expr)
Constructor.
Definition: partialMatch.h:406
Matrix< TXT2 > getScore(const Query_t &weighted_query, const Matrix< TXT2 > &query_data) const
Function for performing a weighted partial match given a query expression and query data.
Definition: partialMatch.h:622
Query_t(std::vector< std::vector< long >> &&index_sets, std::vector< long > &&offsets, std::vector< Matrix< long >> &&weights, bool isThereAnOR)
Constructor.
Definition: partialMatch.h:381
std::string eval() const override
Function for returning the logical AND expression in reverse polish notation where the AND operation ...
Definition: partialMatch.h:252
long getOrdP() const
The order of p in (Z/mZ)^*.
Definition: PAlgebra.h:171
Base structure for logical expressions.
Definition: partialMatch.h:208
An object representing a column of a database as an expression which inherits from Expr.
Definition: partialMatch.h:219
std::shared_ptr< And > operator&&(const QueryExpr &lhs, const QueryExpr &rhs)
Overloaded operator for creating a shared pointer to an AND expression.
Definition: partialMatch.h:308
And(const QueryExpr &l, const QueryExpr &r)
Constructor.
Definition: partialMatch.h:262
std::size_t dims(int i) const
Definition: Matrix.h:225
A simple wrapper for a smart pointer to an EncryptedArrayBase. This is the interface that higher-leve...
Definition: EncryptedArray.h:1233
long getP() const
Returns p.
Definition: PAlgebra.h:165
std::vector< long > mus
std::vector of offsets. Each offset is a constant value. There should be a single offset for each ind...
Definition: partialMatch.h:343
Structure containing all information required for an HE query.
Definition: partialMatch.h:332
std::vector< Matrix< long > > taus
std::vector of a set of weights. Each weight set corresponds to a single index set where each individ...
Definition: partialMatch.h:350
ColNumber(int c)
Constructor.
Definition: partialMatch.h:231
std::string eval() const override
Function for returning the logical OR expression in reverse polish notation where the OR operation is...
Definition: partialMatch.h:284
Query_t build(long columns) const
Function for building the Query_t object from the expression.
Definition: partialMatch.h:414
std::string eval() const override
Function for returning the column number of the object.
Definition: partialMatch.h:225
An object representing the logical OR expression which inherits from Expr.
Definition: partialMatch.h:275
std::shared_ptr< ColNumber > makeQueryExpr(int cl)
Utility function for creating a shared pointer to a specified column in a query.
Definition: partialMatch.h:197
long columns()
Returns number of columns in the database.
Definition: partialMatch.h:594
Inherits from Exception and std::invalid_argument.
Definition: exceptions.h:140
std::vector< std::vector< long > > Fs
std::vector of index sets. These index sets specify the indexes of the columns in each column subset.
Definition: partialMatch.h:337
An object representing the logical AND expression which inherits from Expr.
Definition: partialMatch.h:243
Definition: apiAttributes.h:21
const std::vector< T > & data() const
Definition: Matrix.h:465
Or(const QueryExpr &l, const QueryExpr &r)
Constructor.
Definition: partialMatch.h:294
Query_t(const std::vector< std::vector< long >> &index_sets, const std::vector< long > &offsets, const std::vector< Matrix< long >> &weights, const bool isThereAnOR)
Constructor.
Definition: partialMatch.h:366
PolyMod partialMatchEncode(uint32_t input, const Context &context)
Given a value, encode the value across the coefficients of a polynomial.
Definition: partialMatch.h:171
Tensor< T, 2 > & transpose()
Definition: Matrix.h:393
Definition: Matrix.h:149
Database(const Matrix< TXT > &M, const Context &c)
Constructor.
Definition: partialMatch.h:555
bool containsOR
Flag indicating if the query contains a logical OR operation. This is used for optimization purposes.
Definition: partialMatch.h:356
Maintaining the parameters.
Definition: Context.h:121
std::shared_ptr< Or > operator||(const QueryExpr &lhs, const QueryExpr &rhs)
Overloaded operator for creating a shared pointer to an OR expression.
Definition: partialMatch.h:321
Matrix< TXT2 > contains(const Query_t &lookup_query, const Matrix< TXT2 > &query_data) const
Function for performing a database lookup given a query expression and query data.
Definition: partialMatch.h:603
virtual std::string eval() const =0
An object that contains an NTL::ZZX polynomial along with a coefficient modulus p2r and a polynomial ...
Definition: PolyMod.h:47
void mapTo01(const EncryptedArray &ea, Ctxt &ctxt)
Definition: eqtesting.cpp:35
std::shared_ptr< PolyModRing > slotRing
The structure of a single slot of the plaintext space.
Definition: Context.h:145
def prod(iterable)
Definition: encode.py:67
std::shared_ptr< Expr > QueryExpr
An alias for a shared pointer to an Expr object.
Definition: partialMatch.h:189
Tensor< T, N > columns(const std::vector< long > &js) const
Definition: Matrix.h:272
PAlgebra zMStar
The structure of Zm*.
Definition: Context.h:131
virtual ~Expr()=default
Tensor< T, N > & apply(std::function< void(T &x)> fn)
Definition: Matrix.h:372
An object representing a database which is a HElib::Matrix<TXT>.
Definition: partialMatch.h:531