Skip to content

Commit 2713a10

Browse files
committed
.
1 parent bf19ffc commit 2713a10

File tree

4 files changed

+124
-3
lines changed

4 files changed

+124
-3
lines changed

src/CModulus.cpp

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
#include "CModulus.h"
3333
#include "timing.h"
3434

35+
#include "FHEContext.h"
36+
// needed to get ALT_CRT
37+
3538

3639
// It is assumed that m,q,context, and root are already set. If root is set
3740
// to zero, it will be computed by the compRoots() method. Then rInv is
@@ -69,6 +72,42 @@ Cmodulus::Cmodulus(const PAlgebra &zms, long qq, long rt)
6972

7073
zz_pBak bak;
7174

75+
if (!ALT_CRT && zms.getPow2()) {
76+
// special case when m is a power of 2
77+
78+
assert( explicitModulus );
79+
bak.save();
80+
context = zz_pContext(INIT_USER_FFT, q);
81+
context.restore();
82+
83+
powers.set_ptr(new zz_pX);
84+
ipowers.set_ptr(new zz_pX);
85+
86+
87+
long k = zms.getPow2();
88+
long phim = 1L << (k-1);
89+
long w0 = zz_pInfo->p_info->RootTable[0][k];
90+
long w1 = zz_pInfo->p_info->RootTable[1][k];
91+
92+
powers->rep.SetLength(phim);
93+
powers_aux.SetLength(phim);
94+
for (long i = 0, w = 1; i < phim; i++) {
95+
powers->rep[i] = w;
96+
powers_aux[i] = PrepMulModPrecon(w, q);
97+
w = MulMod(w, w0, q);
98+
}
99+
100+
ipowers->rep.SetLength(phim);
101+
ipowers_aux.SetLength(phim);
102+
for (long i = 0, w = 1; i < phim; i++) {
103+
ipowers->rep[i] = w;
104+
ipowers_aux[i] = PrepMulModPrecon(w, q);
105+
w = MulMod(w, w1, q);
106+
}
107+
108+
return;
109+
}
110+
72111
if (explicitModulus) {
73112
bak.save(); // backup the current modulus
74113
context = BuildContext(q, NextPowerOfTwo(zms.getM()) + 1);
@@ -146,10 +185,38 @@ void Cmodulus::FFT(vec_long &y, const ZZX& x) const
146185
FHE_TIMER_START;
147186
zz_pBak bak; bak.save();
148187
context.restore();
149-
zz_p rt;
150-
zz_pX& tmp = Cmodulus::getScratch_zz_pX();
151188

189+
zz_pX& tmp = Cmodulus::getScratch_zz_pX();
152190
conv(tmp,x); // convert input to zpx format
191+
192+
if (!ALT_CRT && zMStar->getPow2()) {
193+
// special case when m is a power of 2
194+
195+
long k = zMStar->getPow2();
196+
long phim = (1L << (k-1));
197+
long dx = deg(tmp);
198+
long p = zz_p::modulus();
199+
200+
const zz_p *powers_p = (*powers).rep.elts();
201+
const mulmod_precon_t *powers_aux_p = powers_aux.elts();
202+
203+
y.SetLength(phim);
204+
long *yp = y.elts();
205+
206+
zz_p *tmp_p = tmp.rep.elts();
207+
208+
for (long i = 0; i <= dx; i++)
209+
yp[i] = MulModPrecon(rep(tmp_p[i]), rep(powers_p[i]), p, powers_aux_p[i]);
210+
for (long i = dx+1; i < phim; i++)
211+
yp[i] = 0;
212+
213+
FFTFwd(yp, yp, k-1, *zz_pInfo->p_info);
214+
215+
return;
216+
}
217+
218+
219+
zz_p rt;
153220
conv(rt, root); // convert root to zp format
154221

155222
BluesteinFFT(tmp, getM(), rt, *powers, powers_aux, *Rb); // call the FFT routine
@@ -169,8 +236,40 @@ void Cmodulus::iFFT(zz_pX &x, const vec_long& y)const
169236
FHE_TIMER_START;
170237
zz_pBak bak; bak.save();
171238
context.restore();
172-
zz_p rt;
173239

240+
if (!ALT_CRT && zMStar->getPow2()) {
241+
// special case when m is a power of 2
242+
243+
long k = zMStar->getPow2();
244+
long phim = (1L << (k-1));
245+
long p = zz_p::modulus();
246+
247+
const zz_p *ipowers_p = (*ipowers).rep.elts();
248+
const mulmod_precon_t *ipowers_aux_p = ipowers_aux.elts();
249+
250+
const long *yp = y.elts();
251+
252+
vec_long& tmp = Cmodulus::getScratch_vec_long();
253+
tmp.SetLength(phim);
254+
long *tmp_p = tmp.elts();
255+
256+
FFTRev1(tmp_p, yp, k-1, *zz_pInfo->p_info);
257+
258+
x.rep.SetLength(phim);
259+
zz_p *xp = x.rep.elts();
260+
261+
for (long i = 0; i < phim; i++)
262+
xp[i].LoopHole() = MulModPrecon(tmp_p[i], rep(ipowers_p[i]), p, ipowers_aux_p[i]);
263+
264+
265+
x.normalize();
266+
267+
return;
268+
}
269+
270+
271+
272+
zz_p rt;
174273
long m = getM();
175274

176275
// convert input to zpx format, initializing only the coeffs i s.t. (i,m)=1
@@ -201,6 +300,12 @@ zz_pX& Cmodulus::getScratch_zz_pX()
201300
return scratch;
202301
}
203302

303+
Vec<long>& Cmodulus::getScratch_vec_long()
304+
{
305+
NTL_THREAD_LOCAL static Vec<long> scratch;
306+
return scratch;
307+
}
308+
204309

205310
fftRep& Cmodulus::getScratch_fftRep(long k)
206311
{

src/CModulus.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class Cmodulus {
109109
// which is not officially sanctioned by NTL, but should be OK.
110110
static zz_pX& getScratch_zz_pX();
111111

112+
static Vec<long>& getScratch_vec_long();
113+
112114
// returns thread-local scratch space
113115
// DIRT: this use a couple of internal, undocumented
114116
// NTL interfaces

src/PAlgebra.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,20 @@ PAlgebra::PAlgebra(unsigned long mm, unsigned long pp,
156156
assert( ProbPrime(pp) );
157157
assert( (mm % pp) != 0 );
158158
assert( mm < NTL_SP_BOUND );
159+
assert( mm > 1 );
159160

160161
cM = 1.0; // default value for the ring constant
161162
m = mm;
162163
p = pp;
163164

165+
long k = NextPowerOfTwo(m);
166+
if (mm == (1L << k))
167+
pow2 = k;
168+
else
169+
pow2 = 0;
170+
171+
172+
164173
// For dry-run, use a tiny m value for the PAlgebra tables
165174
if (isDryRun()) mm = (p==3)? 4 : 3;
166175

src/PAlgebra.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class PAlgebra {
6161
unsigned long ordP; // the order of p in (Z/mZ)^*
6262
unsigned long nSlots; // phi(m)/ordP = # of plaintext slots
6363

64+
long pow2; // if m = 2^k, then pow2 == k; otherwise, pow2 == 0
65+
6466
vector<long> gens; // Our generators for (Z/mZ)^* (other than p)
6567
vector<long> ords; // ords[i] is the order of gens[i] in quotient group kept
6668
// with a negative sign if different than order in (Z/mZ)*
@@ -121,6 +123,9 @@ class PAlgebra {
121123
//! The number of plaintext slots = phi(m)/ord(p)
122124
unsigned long getNSlots() const { return nSlots; }
123125

126+
//! if m = 2^k, then pow2 == k; otherwise, pow2 == 0
127+
long getPow2() const { return pow2; }
128+
124129
//! The cyclotomix polynomial Phi_m(X)
125130
const ZZX& getPhimX() const { return PhimX; }
126131

0 commit comments

Comments
 (0)