matmul.h
1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  * http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #ifndef HELIB_MATMUL_H
13 #define HELIB_MATMUL_H
14 
15 #include <helib/EncryptedArray.h>
16 
17 namespace helib {
18 
19 class MatMulFullExec;
20 
21 // Abstract base class for representing a linear transformation on a full
22 // std::vector.
24 {
25 public:
26  virtual ~MatMulFull() {}
27  virtual const EncryptedArray& getEA() const = 0;
29 };
30 
31 // Concrete derived class that defines the matrix entries.
32 template <typename type>
34 {
35 public:
36  PA_INJECT(type)
37 
38  // Get (i, j) entry of matrix.
39  // Should return true when the entry is a zero.
40  virtual bool get(RX& out, long i, long j) const = 0;
41 };
42 
43 //====================================
44 
46 
47 // Abstract base class for representing a block linear transformation on a full
48 // std::vector.
50 {
51 public:
52  virtual ~BlockMatMulFull() {}
53  virtual const EncryptedArray& getEA() const = 0;
55 };
56 
57 // Concrete derived class that defines the matrix entries.
58 template <typename type>
60 {
61 public:
62  PA_INJECT(type)
63 
64  // Get (i, j) entry of matrix.
65  // Each entry is a d x d matrix over the base ring.
66  // Should return true when the entry is a zero.
67  virtual bool get(mat_R& out, long i, long j) const = 0;
68 };
69 
70 //====================================
71 
72 class MatMul1DExec;
73 
74 // Abstract base class for representing a 1D linear transformation.
75 class MatMul1D
76 {
77 public:
78  virtual ~MatMul1D() {}
79  virtual const EncryptedArray& getEA() const = 0;
80  virtual long getDim() const = 0;
82 };
83 
84 // An intermediate class that is mainly intended for internal use.
85 template <typename type>
86 class MatMul1D_partial : public MatMul1D
87 {
88 public:
89  PA_INJECT(type)
90 
91  // Get the i'th diagonal, encoded as a single constant.
92  // MatMul1D_derived (below) supplies a default implementation,
93  // which can be overridden in special circumstances.
94  virtual void processDiagonal(RX& poly,
95  long i,
96  const EncryptedArrayDerived<type>& ea) const = 0;
97 };
98 
99 // Concrete derived class that defines the matrix entries.
100 template <typename type>
101 class MatMul1D_derived : public MatMul1D_partial<type>
102 {
103 public:
104  PA_INJECT(type)
105 
106  // Should return true if their are multiple (different) transforms
107  // among the various components.
108  virtual bool multipleTransforms() const = 0;
109 
110  // Get coordinate (i, j) of the kth component.
111  // Should return true when the entry is a zero.
112  virtual bool get(RX& out, long i, long j, long k) const = 0;
113 
114  void processDiagonal(RX& poly,
115  long i,
116  const EncryptedArrayDerived<type>& ea) const override;
117 };
118 
119 class MatMul1D_CKKS : public MatMul1D
120 {
121 public:
122  // Get coordinate (i, j)
123  virtual std::complex<double> get(long i, long j) const = 0;
124 
125  void processDiagonal(zzX& poly,
126  double& size,
127  double& factor,
128  long i,
129  const EncryptedArrayCx& ea) const;
130 };
131 
132 //====================================
133 
134 class BlockMatMul1DExec;
135 
136 // Abstract base class for representing a block 1D linear transformation.
138 {
139 public:
140  virtual ~BlockMatMul1D() {}
141  virtual const EncryptedArray& getEA() const = 0;
142  virtual long getDim() const = 0;
144 };
145 
146 // An intermediate class that is mainly intended for internal use.
147 template <typename type>
149 {
150 public:
151  PA_INJECT(type)
152 
153  // Get the i'th diagonal, encoded as a std::vector of d constants,
154  // where d is the order of p.
155  // BlockMatMul1D_derived (below) supplies a default implementation,
156  // which can be overridden in special circumstances.
157  virtual bool processDiagonal(std::vector<RX>& poly,
158  long i,
159  const EncryptedArrayDerived<type>& ea) const = 0;
160 };
161 
162 // Concrete derived class that defines the matrix entries.
163 template <typename type>
165 {
166 public:
167  PA_INJECT(type)
168 
169  // Should return true if their are multiple (different) transforms
170  // among the various components.
171  virtual bool multipleTransforms() const = 0;
172 
173  // Get coordinate (i, j) of the kth component.
174  // Each entry is a d x d matrix over the base ring.
175  // Should return true when the entry is a zero.
176  virtual bool get(mat_R& out, long i, long j, long k) const = 0;
177 
178  bool processDiagonal(std::vector<RX>& poly,
179  long i,
180  const EncryptedArrayDerived<type>& ea) const override;
181 };
182 
183 //====================================
184 
185 struct ConstMultiplier;
186 // Defined in matmul.cpp.
187 // Holds a constant by which a ciphertext can be multiplied.
188 // Internally, it is represented as either zzX or a DoubleCRT.
189 // The former occupies less space, but the latter makes for
190 // much faster multiplication.
191 
193 {
194  std::vector<std::shared_ptr<ConstMultiplier>> multiplier;
195 
196  // Upgrade zzX constants to DoubleCRT constants.
197  void upgrade(const Context& context);
198 };
199 
200 //====================================
201 
202 // Abstract base case for multiplying an encrypted std::vector by a plaintext
203 // matrix.
205 {
206 public:
207  virtual ~MatMulExecBase() {}
208 
209  virtual const EncryptedArray& getEA() const = 0;
210 
211  // Upgrade zzX constants to DoubleCRT constants.
212  virtual void upgrade() = 0;
213 
214  // If ctxt encrypts a row std::vector v, then this replaces ctxt
215  // by an encryption of the row std::vector v*mat, where mat is
216  // a matrix provided to the constructor of one of the
217  // concrete subclasses MatMul1DExec, BlockMatMul1DExec,
218  // MatMulFullExec, BlockMatMulFullExec, defined below.
219  virtual void mul(Ctxt& ctxt) const = 0;
220 };
221 
222 //====================================
223 
224 // Class used to multiply an encrypted row std::vector by a 1D linear
225 // transformation.
227 {
228 public:
230 
231  long dim;
232  long D;
233  bool native;
234  bool minimal;
235  long g;
236 
238  ConstMultiplierCache cache1; // only for non-native dimension
239 
240  // The constructor encodes all the constants for a given
241  // matrix in zzX format.
242  // The mat argument defines the entries of the matrix.
243  // Use the upgrade method (below) to convert to DoubleCRT format.
244  // If the minimal flag is set to true, a strategy that relies
245  // on a minimal number of key switching matrices will be used;
246  // this is intended for use in conjunction with the
247  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
248  // If the minimal flag is false, it is best to use the
249  // addSome{1D,Frb}Matrices routines declared in helib.h.
250  explicit MatMul1DExec(const MatMul1D& mat, bool minimal = false);
251 
252  // Replaces an encryption of row std::vector v by encryption of v*mat
253  void mul(Ctxt& ctxt) const override;
254 
255  // Upgrades encoded constants from zzX to DoubleCRT.
256  void upgrade() override
257  {
258  cache.upgrade(ea.getContext());
259  cache1.upgrade(ea.getContext());
260  }
261 
262  const EncryptedArray& getEA() const override { return ea; }
263 };
264 
265 //====================================
266 
267 // Class used to multiply an encrypted row std::vector by a block 1D linear
268 // transformation.
270 {
271 public:
273 
274  long dim;
275  long D;
276  long d;
277  bool native;
278  long strategy;
279 
281  ConstMultiplierCache cache1; // only for non-native dimension
282 
283  // The constructor encodes all the constants for a given
284  // matrix in zzX format.
285  // The mat argument defines the entries of the matrix.
286  // Use the upgrade method (below) to convert to DoubleCRT format.
287  // If the minimal flag is set to true, a strategy that relies
288  // on a minimal number of key switching matrices will be used;
289  // this is intended for use in conjunction with the
290  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
291  // If the minimal flag is false, it is best to use the
292  // addSome{1D,Frb}Matrices routines declared in helib.h.
293  explicit BlockMatMul1DExec(const BlockMatMul1D& mat, bool minimal = false);
294 
295  // Replaces an encryption of row std::vector v by encryption of v*mat
296  void mul(Ctxt& ctxt) const override;
297 
298  // Upgrades encoded constants from zzX to DoubleCRT.
299  void upgrade() override
300  {
301  cache.upgrade(ea.getContext());
302  cache1.upgrade(ea.getContext());
303  }
304 
305  const EncryptedArray& getEA() const override { return ea; }
306 };
307 
308 //====================================
309 
310 // Class used to multiply an encrypted row std::vector by a full linear
311 // transformation.
313 {
314 public:
316  bool minimal;
317  std::vector<long> dims;
318  std::vector<MatMul1DExec> transforms;
319 
320  // The constructor encodes all the constants for a given
321  // matrix in zzX format.
322  // The mat argument defines the entries of the matrix.
323  // Use the upgrade method (below) to convert to DoubleCRT format.
324  // If the minimal flag is set to true, a strategy that relies
325  // on a minimal number of key switching matrices will be used;
326  // this is intended for use in conjunction with the
327  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
328  // If the minimal flag is false, it is best to use the
329  // addSome{1D,Frb}Matrices routines declared in helib.h.
330  explicit MatMulFullExec(const MatMulFull& mat, bool minimal = false);
331 
332  // Replaces an encryption of row std::vector v by encryption of v*mat
333  void mul(Ctxt& ctxt) const override;
334 
335  // Upgrades encoded constants from zzX to DoubleCRT.
336  void upgrade() override
337  {
338  for (auto& t : transforms)
339  t.upgrade();
340  }
341 
342  const EncryptedArray& getEA() const override { return ea; }
343 
344  // This really should be private.
345  long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
346 };
347 
348 //====================================
349 
350 // Class used to multiply an encrypted row std::vector by a full block linear
351 // transformation.
353 {
354 public:
356  bool minimal;
357  std::vector<long> dims;
358  std::vector<BlockMatMul1DExec> transforms;
359 
360  // The constructor encodes all the constants for a given
361  // matrix in zzX format.
362  // The mat argument defines the entries of the matrix.
363  // Use the upgrade method (below) to convert to DoubleCRT format.
364  // If the minimal flag is set to true, a strategy that relies
365  // on a minimal number of key switching matrices will be used;
366  // this is intended for use in conjunction with the
367  // addMinimal{1D,Frb}Matrices routines declared in helib.h.
368  // If the minimal flag is false, it is best to use the
369  // addSome{1D,Frb}Matrices routines declared in helib.h.
370  explicit BlockMatMulFullExec(const BlockMatMulFull& mat,
371  bool minimal = false);
372 
373  // Replaces an encryption of row std::vector v by encryption of v*mat
374  void mul(Ctxt& ctxt) const override;
375 
376  // Upgrades encoded constants from zzX to DoubleCRT.
377  void upgrade() override
378  {
379  for (auto& t : transforms)
380  t.upgrade();
381  }
382 
383  const EncryptedArray& getEA() const override { return ea; }
384 
385  // This really should be private.
386  long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
387 };
388 
389 //===================================
390 
391 // ctxt = \sum_{i=0}^{d-1} \sigma^i(ctxt),
392 // where d = order of p mod m, and \sigma is the Frobenius map
393 
394 void traceMap(Ctxt& ctxt);
395 
396 //====================================
397 
398 // These routines apply linear transformation to plaintext arrays.
399 // Mainly for testing purposes.
400 void mul(PlaintextArray& pa, const MatMul1D& mat);
401 void mul(PlaintextArray& pa, const BlockMatMul1D& mat);
402 void mul(PlaintextArray& pa, const MatMulFull& mat);
403 void mul(PlaintextArray& pa, const BlockMatMulFull& mat);
404 
405 // These are used mainly for performance evaluation.
406 
407 extern int fhe_test_force_bsgs;
408 // Controls whether or not we use BSGS multiplication.
409 // 1 to force on, -1 to force off, 0 for default behaviour.
410 
411 extern int fhe_test_force_hoist;
412 // Controls whether ot not we use hoisting.
413 // -1 to force off, 0 for default behaviour.
414 
415 } // namespace helib
416 
417 #endif // ifndef HELIB_MATMUL_H
std::vector< long > dims
Definition: matmul.h:357
void mul(const EncryptedArray &ea, PlaintextArray &pa, const PlaintextArray &other)
Definition: EncryptedArray.cpp:1061
ConstMultiplierCache cache1
Definition: matmul.h:281
void upgrade(const Context &context)
Definition: matmul.cpp:400
virtual std::complex< double > get(long i, long j) const =0
virtual ~MatMulExecBase()
Definition: matmul.h:207
NTL::Vec< long > zzX
Definition: zzX.h:24
virtual long getDim() const =0
void upgrade() override
Definition: matmul.h:299
int fhe_test_force_bsgs
Definition: matmul.cpp:23
BlockMatMul1DExec ExecType
Definition: matmul.h:143
long dim
Definition: matmul.h:274
const EncryptedArray & getEA() const override
Definition: matmul.h:383
Definition: matmul.h:193
virtual const EncryptedArray & getEA() const =0
void traceMap(Ctxt &ctxt)
Definition: matmul.cpp:2830
Definition: matmul.h:120
Definition: matmul.h:270
A simple wrapper for a smart pointer to an EncryptedArrayBase. This is the interface that higher-leve...
Definition: EncryptedArray.h:1233
BlockMatMulFullExec ExecType
Definition: matmul.h:54
virtual const EncryptedArray & getEA() const =0
MatMulFullExec ExecType
Definition: matmul.h:28
virtual void upgrade()=0
std::vector< long > dims
Definition: matmul.h:317
virtual ~BlockMatMulFull()
Definition: matmul.h:52
const EncryptedArray & getEA() const override
Definition: matmul.h:305
long strategy
Definition: matmul.h:278
const EncryptedArray & ea
Definition: matmul.h:272
virtual long getDim() const =0
ConstMultiplierCache cache
Definition: matmul.h:280
long D
Definition: matmul.h:275
const EncryptedArray & ea
Definition: matmul.h:229
virtual ~MatMul1D()
Definition: matmul.h:78
ConstMultiplierCache cache1
Definition: matmul.h:238
A different derived class to be used for the approximate-numbers scheme.
Definition: EncryptedArray.h:667
const EncryptedArray & ea
Definition: matmul.h:315
Definition: matmul.h:149
Definition: matmul.h:76
Derived concrete implementation of EncryptedArrayBase.
Definition: EncryptedArray.h:315
Definition: matmul.h:60
virtual const EncryptedArray & getEA() const =0
const EncryptedArray & ea
Definition: matmul.h:355
const Context & getContext() const
Definition: EncryptedArray.h:1301
virtual const EncryptedArray & getEA() const =0
virtual const EncryptedArray & getEA() const =0
bool minimal
Definition: matmul.h:234
void upgrade() override
Definition: matmul.h:336
bool native
Definition: matmul.h:277
std::vector< BlockMatMul1DExec > transforms
Definition: matmul.h:358
long g
Definition: matmul.h:235
bool minimal
Definition: matmul.h:356
virtual void mul(Ctxt &ctxt) const =0
Definition: apiAttributes.h:21
Definition: matmul.h:34
Definition: matmul.h:102
std::vector< std::shared_ptr< ConstMultiplier > > multiplier
Definition: matmul.h:194
Definition: matmul.h:205
bool minimal
Definition: matmul.h:316
const EncryptedArray & getEA() const override
Definition: matmul.h:262
void upgrade() override
Definition: matmul.h:256
MatMul1DExec ExecType
Definition: matmul.h:81
const EncryptedArray & getEA() const override
Definition: matmul.h:342
Maintaining the parameters.
Definition: Context.h:121
Definition: matmul.h:138
void upgrade() override
Definition: matmul.h:377
virtual ~BlockMatMul1D()
Definition: matmul.h:140
Definition: matmul.h:227
Definition: matmul.h:24
bool native
Definition: matmul.h:233
ConstMultiplierCache cache
Definition: matmul.h:237
Definition: matmul.h:165
Definition: matmul.h:353
long d
Definition: matmul.h:276
A Ctxt object holds a single ciphertext.
Definition: Ctxt.h:273
virtual bool get(RX &out, long i, long j) const =0
Definition: matmul.h:87
long D
Definition: matmul.h:232
int fhe_test_force_hoist
Definition: matmul.cpp:24
long dim
Definition: matmul.h:231
virtual ~MatMulFull()
Definition: matmul.h:26
std::vector< MatMul1DExec > transforms
Definition: matmul.h:318
Definition: matmul.h:313
Definition: matmul.h:50
Definition: matmul.cpp:308