Skip to content

Commit c7bd67c

Browse files
committed
Merge branch 'matmul'
2 parents 06d88a2 + 2243903 commit c7bd67c

File tree

10 files changed

+938
-167
lines changed

10 files changed

+938
-167
lines changed

src/CModulus.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ Cmodulus::Cmodulus(const PAlgebra &zms, long qq, long rt)
8787

8888
long k = zms.getPow2();
8989
long phim = 1L << (k-1);
90+
91+
assert(k <= zz_pInfo->MaxRoot);
92+
// rootTables get initialized 0..zz_pInfo->Maxroot
93+
94+
#ifdef FHE_OPENCL
95+
altFFTInfo = MakeSmart<AltFFTPrimeInfo>();
96+
InitAltFFTPrimeInfo(*altFFTInfo, *zz_pInfo->p_info, k-1);
97+
#endif
98+
9099
long w0 = zz_pInfo->p_info->RootTable[0][k];
91100
long w1 = zz_pInfo->p_info->RootTable[1][k];
92101

@@ -105,6 +114,7 @@ Cmodulus::Cmodulus(const PAlgebra &zms, long qq, long rt)
105114
ipowers_aux[i] = PrepMulModPrecon(w, q);
106115
w = MulMod(w, w1, q);
107116
}
117+
108118

109119
return;
110120
}
@@ -176,6 +186,10 @@ Cmodulus& Cmodulus::operator=(const Cmodulus &other)
176186
iRb = other.iRb;
177187
phimx = other.phimx;
178188

189+
#ifdef FHE_OPENCL
190+
altFFTInfo = other.altFFTInfo;
191+
#endif
192+
179193

180194

181195
return *this;
@@ -206,7 +220,11 @@ void Cmodulus::FFT_aux(vec_long &y, zz_pX& tmp) const
206220
for (long i = dx+1; i < phim; i++)
207221
yp[i] = 0;
208222

223+
#ifdef FHE_OPENCL
224+
AltFFTFwd(yp, yp, k-1, *altFFTInfo);
225+
#else
209226
FFTFwd(yp, yp, k-1, *zz_pInfo->p_info);
227+
#endif
210228

211229
return;
212230
}
@@ -278,7 +296,11 @@ void Cmodulus::iFFT(zz_pX &x, const vec_long& y)const
278296
tmp.SetLength(phim);
279297
long *tmp_p = tmp.elts();
280298

299+
#ifdef FHE_OPENCL
300+
AltFFTRev1(tmp_p, yp, k-1, *altFFTInfo);
301+
#else
281302
FFTRev1(tmp_p, yp, k-1, *zz_pInfo->p_info);
303+
#endif
282304

283305
x.rep.SetLength(phim);
284306
zz_p *xp = x.rep.elts();

src/CModulus.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
#include "cloned_ptr.h"
2828

2929

30+
#ifdef FHE_OPENCL
31+
#include "FFT.h"
32+
#endif
33+
3034

3135
/**
3236
* @class Cmodulus
@@ -70,6 +74,12 @@ class Cmodulus {
7074

7175
public:
7276

77+
#ifdef FHE_OPENCL
78+
SmartPtr<AltFFTPrimeInfo> altFFTInfo;
79+
// We need to allow copying...the underlying object
80+
// is immutable
81+
#endif
82+
7383
// Destructor and constructors
7484

7585
// Default constructor

src/EncryptedArray.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ template<class type> class EncryptedArrayDerived : public EncryptedArrayBase {
329329
virtual void rotate(Ctxt& ctxt, long k) const;
330330
virtual void shift(Ctxt& ctxt, long k) const;
331331
virtual void rotate1D(Ctxt& ctxt, long i, long k, bool dc=false) const;
332+
template<class U> void // avoid this being "hidden" by other rotate1D's
333+
rotate1D(vector<U>& out, const vector<U>& in, long i, long offset) const
334+
{ EncryptedArrayBase::rotate1D(out, in, i, offset); }
332335
virtual void shift1D(Ctxt& ctxt, long i, long k) const;
333336

334337
virtual void encode(ZZX& ptxt, const vector< long >& array) const

0 commit comments

Comments
 (0)