Skip to content

Commit 2243903

Browse files
committed
add giant-step parameter to MatMulBase
1 parent b0af0df commit 2243903

File tree

3 files changed

+79
-74
lines changed

3 files changed

+79
-74
lines changed

src/Test_matmul1D.cpp

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@
2525
#include "matmul.h"
2626

2727
// Forward declerations
28-
static MatMulBase* buildRandomMatrix(const EncryptedArray& ea, long dim);
29-
static MatMulBase* buildRandomMultiMatrix(const EncryptedArray& ea, long dim);
30-
static MatMulBase* buildRandomBlockMatrix(const EncryptedArray& ea, long dim);
31-
static MatMulBase*
32-
buildRandomMultiBlockMatrix(const EncryptedArray& ea, long dim);
28+
static MatMulBase* buildRandomMatrix(const EncryptedArray& ea,
29+
long dim, long giantStep);
30+
static MatMulBase* buildRandomMultiMatrix(const EncryptedArray& ea,
31+
long dim, long giantStep);
32+
static MatMulBase* buildRandomBlockMatrix(const EncryptedArray& ea,
33+
long dim);
34+
static MatMulBase* buildRandomMultiBlockMatrix(const EncryptedArray& ea,
35+
long dim);
3336

3437

3538
// The callback interface for the matrix-multiplication routines.
@@ -45,8 +48,8 @@ template<class type> class RandomMultiMatrix : public MatMul<type> {
4548

4649
public:
4750
virtual ~RandomMultiMatrix() {}
48-
RandomMultiMatrix(const EncryptedArray& _ea, long _dim)
49-
: MatMul<type>(_ea), dim(_dim)
51+
RandomMultiMatrix(const EncryptedArray& _ea, long _dim, long g)
52+
: MatMul<type>(_ea, g), dim(_dim)
5053
{
5154
RBak bak; bak.save(); _ea.getAlMod().restoreContext();
5255
long n = _ea.size();
@@ -77,14 +80,15 @@ template<class type> class RandomMultiMatrix : public MatMul<type> {
7780
}
7881
};
7982

80-
static MatMulBase* buildRandomMultiMatrix(const EncryptedArray& ea, long dim)
83+
static MatMulBase*
84+
buildRandomMultiMatrix(const EncryptedArray& ea, long dim, long giantStep)
8185
{
8286
switch (ea.getTag()) {
8387
case PA_GF2_tag: {
84-
return new RandomMultiMatrix<PA_GF2>(ea, dim);
88+
return new RandomMultiMatrix<PA_GF2>(ea, dim, giantStep);
8589
}
8690
case PA_zz_p_tag: {
87-
return new RandomMultiMatrix<PA_zz_p>(ea, dim);
91+
return new RandomMultiMatrix<PA_zz_p>(ea, dim, giantStep);
8892
}
8993
default: return 0;
9094
}
@@ -171,8 +175,8 @@ template<class type> class RandomMatrix : public MatMul<type> {
171175

172176
public:
173177
virtual ~RandomMatrix() {}
174-
RandomMatrix(const EncryptedArray& _ea, long _dim):
175-
MatMul<type>(_ea), dim(_dim)
178+
RandomMatrix(const EncryptedArray& _ea, long _dim, long g):
179+
MatMul<type>(_ea,g), dim(_dim)
176180
{
177181
RBak bak; bak.save(); _ea.getAlMod().restoreContext();
178182
long n = _ea.size();
@@ -199,14 +203,15 @@ template<class type> class RandomMatrix : public MatMul<type> {
199203
}
200204
};
201205

202-
static MatMulBase* buildRandomMatrix(const EncryptedArray& ea, long dim)
206+
static MatMulBase*
207+
buildRandomMatrix(const EncryptedArray& ea, long dim, long giantStep)
203208
{
204209
switch (ea.getTag()) {
205210
case PA_GF2_tag: {
206-
return new RandomMatrix<PA_GF2>(ea, dim);
211+
return new RandomMatrix<PA_GF2>(ea, dim, giantStep);
207212
}
208213
case PA_zz_p_tag: {
209-
return new RandomMatrix<PA_zz_p>(ea, dim);
214+
return new RandomMatrix<PA_zz_p>(ea, dim, giantStep);
210215
}
211216
default: return 0;
212217
}
@@ -255,7 +260,8 @@ class RandomBlockMatrix : public BlockMatMul<type> {
255260
}
256261
};
257262

258-
static MatMulBase* buildRandomBlockMatrix(const EncryptedArray& ea, long dim)
263+
static MatMulBase*
264+
buildRandomBlockMatrix(const EncryptedArray& ea, long dim)
259265
{
260266
switch (ea.getTag()) {
261267
case PA_GF2_tag: {
@@ -269,18 +275,11 @@ static MatMulBase* buildRandomBlockMatrix(const EncryptedArray& ea, long dim)
269275
}
270276
//! \endcond
271277

272-
void TestIt(FHEcontext& context, long d, long dim, bool verbose)
278+
void TestIt(FHEcontext& context, long g, long dim, bool verbose)
273279
{
274-
ZZX G;
275-
if (d == 0)
276-
G = context.alMod.getFactorsOverZZ()[0];
277-
else
278-
G = makeIrredPoly(context.zMStar.getP(), d);
279-
280280
if (verbose) {
281281
context.zMStar.printout();
282282
cout << endl;
283-
cout << "G = " << G << "\n";
284283
}
285284

286285
FHESecKey secretKey(context);
@@ -289,12 +288,12 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
289288

290289
addSome1DMatrices(secretKey); // compute key-switching matrices that we need
291290
addFrbMatrices(secretKey); // compute key-switching matrices that we need
292-
EncryptedArray ea(context, G);
291+
EncryptedArray ea(context, context.alMod);
293292

294293
// Test a "normal" matrix over the extension field
295294
{
296295
// choose a random plaintext square matrix
297-
std::unique_ptr< MatMulBase > ptr(buildRandomMatrix(ea, dim));
296+
std::unique_ptr< MatMulBase > ptr(buildRandomMatrix(ea, dim, g));
298297

299298
// choose a random plaintext vector
300299
NewPlaintextArray v(ea);
@@ -307,7 +306,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
307306

308307
cout << " Multiplying 1D with MatMulBase... " << std::flush;
309308
matMul1D(v, *ptr, dim);
310-
matMul1D(ctxt2, *ptr, dim, cachezzX, /*giantStep=*/2);
309+
matMul1D(ctxt2, *ptr, dim, cachezzX);
311310
// multiply ciphertext and build cache
312311
NewPlaintextArray v1(ea);
313312
ea.decrypt(ctxt2, secretKey, v1); // decrypt the ciphertext vector
@@ -319,8 +318,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
319318

320319
cout << " Multiplying 1D with MatMulBase+dcrt cache... " << std::flush;
321320
ctxt2 = ctxt;
322-
matMul1D(ctxt2, *ptr, dim, cacheDCRT, /*giantStep=*/2);
323-
// upgrade cache and use in multiplication
321+
matMul1D(ctxt2, *ptr, dim, cacheDCRT); // upgrade cache and use in multiplication
324322

325323
ea.decrypt(ctxt2, secretKey, v1); // decrypt the ciphertext vector
326324

@@ -331,7 +329,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
331329
}
332330
{
333331
// choose a random plaintext square matrix
334-
std::unique_ptr< MatMulBase > ptr(buildRandomMatrix(ea,dim));
332+
std::unique_ptr< MatMulBase > ptr(buildRandomMatrix(ea,dim,g));
335333

336334
// choose a random plaintext vector
337335
NewPlaintextArray v(ea);
@@ -343,8 +341,8 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
343341
Ctxt ctxt2 = ctxt;
344342

345343
cout << " Multiplying 1D with MatMulBase+zzx cache... " << std::flush;
346-
buildCache4MatMul1D(*ptr, dim, cachezzX, /*giantStep=*/2);// build the cache
347-
matMul1D(ctxt, *ptr, dim, cacheEmpty, /*giantStep=*/2); // then use it
344+
buildCache4MatMul1D(*ptr, dim, cachezzX);// build the cache
345+
matMul1D(ctxt, *ptr, dim); // then use it
348346
matMul1D(v, *ptr, dim); // multiply the plaintext vector
349347

350348
NewPlaintextArray v1(ea);
@@ -359,7 +357,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
359357
// Test a "multi" matrix over the extension field
360358
{
361359
// choose a random plaintext square matrix
362-
std::unique_ptr< MatMulBase > ptr(buildRandomMultiMatrix(ea,dim));
360+
std::unique_ptr< MatMulBase > ptr(buildRandomMultiMatrix(ea,dim,g));
363361

364362
// choose a random plaintext vector
365363
NewPlaintextArray v(ea);
@@ -372,8 +370,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
372370

373371
cout << "\n Multiplying multi 1D with MatMulBase... " << std::flush;
374372
matMulti1D(v, *ptr, dim);
375-
matMulti1D(ctxt2, *ptr, dim, cachezzX, /*giantStep=*/2);
376-
// multiply ciphertext and build cache
373+
matMulti1D(ctxt2, *ptr, dim, cachezzX); // multiply ciphertext and build cache
377374
NewPlaintextArray v1(ea);
378375
ea.decrypt(ctxt2, secretKey, v1); // decrypt the ciphertext vector
379376

@@ -384,8 +381,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
384381

385382
cout <<" Multiplying multi 1D with MatMulBase+dcrt cache... "<< std::flush;
386383
ctxt2 = ctxt;
387-
matMulti1D(ctxt2, *ptr, dim, cacheDCRT, /*giantStep=*/2);
388-
// upgrade cache and use in multiplication
384+
matMulti1D(ctxt2, *ptr, dim, cacheDCRT); // upgrade cache and use in multiplication
389385

390386
ea.decrypt(ctxt2, secretKey, v1); // decrypt the ciphertext vector
391387

@@ -396,7 +392,7 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
396392
}
397393
{
398394
// choose a random plaintext square matrix
399-
std::unique_ptr< MatMulBase > ptr(buildRandomMultiMatrix(ea,dim));
395+
std::unique_ptr< MatMulBase > ptr(buildRandomMultiMatrix(ea,dim,g));
400396

401397
// choose a random plaintext vector
402398
NewPlaintextArray v(ea);
@@ -408,8 +404,8 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
408404
Ctxt ctxt2 = ctxt;
409405

410406
cout << " Multiplying multi 1D with MatMulBase+zzx cache... " << std::flush;
411-
buildCache4MatMulti1D(*ptr, dim, cachezzX, /*giantStep=*/2);// build the cache
412-
matMulti1D(ctxt, *ptr, dim, cacheEmpty, /*giantStep=*/2); // then use it
407+
buildCache4MatMulti1D(*ptr, dim, cachezzX);// build the cache
408+
matMulti1D(ctxt, *ptr, dim); // then use it
413409
matMulti1D(v, *ptr, dim); // multiply the plaintext vector
414410

415411
NewPlaintextArray v1(ea);
@@ -517,13 +513,12 @@ void TestIt(FHEcontext& context, long d, long dim, bool verbose)
517513
/* Testing the functionality of multiplying an encrypted vector by a
518514
* plaintext matrix, either over the extension- or the base-field/ring.
519515
*
520-
* Usage: Test_matmul1D [m p r d L dim verbose]
516+
* Usage: Test_matmul1D [optional params]
521517
*
522518
* m defines the cyclotomic polynomial Phi_m(X)
523519
* p is the plaintext base [default=2]
524520
* r is the lifting [default=1]
525-
* d is the degree of the field extension [default==1]
526-
* (d == 0 => factors[0] defined the extension)
521+
* g is the giant-step parameter [defauls=2]
527522
* L is the # of primes in the modulus chain [default=4]
528523
* dim is the dimension alng which we multiply [default=0]
529524
* verbose print timing info [default=0]
@@ -538,9 +533,8 @@ int main(int argc, char *argv[])
538533
amap.arg("p", p, "plaintext base");
539534
long r=1;
540535
amap.arg("r", r, "lifting");
541-
long d=0;
542-
amap.arg("d", d, "degree of the field extension");
543-
amap.note("d == 0 => factors[0] defines extension");
536+
long g=2;
537+
amap.arg("g", g, "giant-step parameter");
544538
long L=4;
545539
amap.arg("L", L, "# of levels in the modulus chain");
546540
long dim=0;
@@ -560,8 +554,8 @@ int main(int argc, char *argv[])
560554
cout << "*** matmul1D: m=" << m
561555
<< ", p=" << p
562556
<< ", r=" << r
563-
<< ", d=" << d
564557
<< ", L=" << L
558+
<< ", g=" << g
565559
<< ", dim=" << dim
566560
// << ", gens=" << gens
567561
// << ", ords=" << ords
@@ -576,7 +570,7 @@ int main(int argc, char *argv[])
576570
FHEcontext context(m, p, r, gens1, ords1);
577571
buildModChain(context, L, /*c=*/3);
578572

579-
TestIt(context, d, dim, verbose);
573+
TestIt(context, g, dim, verbose);
580574
cout << endl;
581575
if (verbose) {
582576
printAllTimers();

src/matmul.h

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,11 @@ class MatMulBase {
4545
std::unique_ptr<CachedzzxMatrix> zzxCache;
4646
std::unique_ptr<CachedDCRTMatrix> dcrtCache;
4747
std::mutex cachelock;
48+
49+
long gStep; // the giant-step parameter (if used)
50+
4851
public:
49-
MatMulBase(const EncryptedArray& _ea): ea(_ea) {}
52+
MatMulBase(const EncryptedArray& _ea, long g=1): ea(_ea), gStep(g) {}
5053
virtual ~MatMulBase() {}
5154

5255
const EncryptedArray& getEA() const { return ea; }
@@ -60,11 +63,21 @@ class MatMulBase {
6063
void releaseCache() { cachelock.unlock(); }
6164

6265
void upgradeCache(); // build DCRT cache from zzx cache
63-
6466
void installzzxcache(std::unique_ptr<CachedzzxMatrix>& zc)
6567
{ zzxCache.swap(zc); }
6668
void installDCRTcache(std::unique_ptr<CachedDCRTMatrix>& dc)
6769
{ dcrtCache.swap(dc); }
70+
71+
// setGstep is *not* thread safe and should never be called if
72+
// there are threads using the current cache.
73+
void setGstep(long g) {
74+
if (g != gStep && g>0) {
75+
zzxCache.reset();
76+
dcrtCache.reset();
77+
gStep = g;
78+
}
79+
}
80+
long getGstep() const { return gStep; }
6881
};
6982

7083
//! @class MatMul
@@ -79,7 +92,7 @@ template<class type>
7992
class MatMul : public MatMulBase { // type is PA_GF2 or PA_zz_p
8093
public:
8194
PA_INJECT(type)
82-
MatMul(const EncryptedArray& _ea): MatMulBase(_ea) {}
95+
MatMul(const EncryptedArray& _ea, long g=1): MatMulBase(_ea,g) {}
8396

8497
// Should return true when the entry is a zero. An application must
8598
// implement (at least) one of these get functions, calling the base
@@ -136,16 +149,14 @@ void buildCache4MatMul_sparse(MatMulBase& mat, MatrixCacheType buildCache);
136149
//! cache exists).
137150

138151
void matMul1D(Ctxt& ctxt, MatMulBase& mat, long dim,
139-
MatrixCacheType buildCache=cacheEmpty, long giantStep=1);
152+
MatrixCacheType buildCache=cacheEmpty);
140153
//! Build a cache without performing multiplication
141-
void buildCache4MatMul1D(MatMulBase& mat, long dim,
142-
MatrixCacheType buildCache, long giantStep=1);
154+
void buildCache4MatMul1D(MatMulBase& mat,long dim,MatrixCacheType buildCache);
143155

144156
void matMulti1D(Ctxt& ctxt, MatMulBase& mat, long dim,
145-
MatrixCacheType buildCache=cacheEmpty, long giantStep=1);
157+
MatrixCacheType buildCache=cacheEmpty);
146158
//! Build a cache without performing multiplication
147-
void buildCache4MatMulti1D(MatMulBase& mat,long dim,
148-
MatrixCacheType buildCache, long giantStep=1);
159+
void buildCache4MatMulti1D(MatMulBase& mat,long dim,MatrixCacheType buildCache);
149160

150161
// Versions for plaintext rather than cipehrtext, useful for debugging
151162
void matMul(NewPlaintextArray& pa, MatMulBase& mat);

0 commit comments

Comments
 (0)