3232#include " CModulus.h"
3333#include " timing.h"
3434
35+ #include " FHEContext.h"
36+ // needed to get ALT_CRT
37+
3538
3639// It is assumed that m,q,context, and root are already set. If root is set
3740// to zero, it will be computed by the compRoots() method. Then rInv is
@@ -69,6 +72,42 @@ Cmodulus::Cmodulus(const PAlgebra &zms, long qq, long rt)
6972
7073 zz_pBak bak;
7174
75+ if (!ALT_CRT && zms.getPow2 ()) {
76+ // special case when m is a power of 2
77+
78+ assert ( explicitModulus );
79+ bak.save ();
80+ context = zz_pContext (INIT_USER_FFT, q);
81+ context.restore ();
82+
83+ powers.set_ptr (new zz_pX);
84+ ipowers.set_ptr (new zz_pX);
85+
86+
87+ long k = zms.getPow2 ();
88+ long phim = 1L << (k-1 );
89+ long w0 = zz_pInfo->p_info ->RootTable [0 ][k];
90+ long w1 = zz_pInfo->p_info ->RootTable [1 ][k];
91+
92+ powers->rep .SetLength (phim);
93+ powers_aux.SetLength (phim);
94+ for (long i = 0 , w = 1 ; i < phim; i++) {
95+ powers->rep [i] = w;
96+ powers_aux[i] = PrepMulModPrecon (w, q);
97+ w = MulMod (w, w0, q);
98+ }
99+
100+ ipowers->rep .SetLength (phim);
101+ ipowers_aux.SetLength (phim);
102+ for (long i = 0 , w = 1 ; i < phim; i++) {
103+ ipowers->rep [i] = w;
104+ ipowers_aux[i] = PrepMulModPrecon (w, q);
105+ w = MulMod (w, w1, q);
106+ }
107+
108+ return ;
109+ }
110+
72111 if (explicitModulus) {
73112 bak.save (); // backup the current modulus
74113 context = BuildContext (q, NextPowerOfTwo (zms.getM ()) + 1 );
@@ -146,10 +185,38 @@ void Cmodulus::FFT(vec_long &y, const ZZX& x) const
146185 FHE_TIMER_START;
147186 zz_pBak bak; bak.save ();
148187 context.restore ();
149- zz_p rt;
150- zz_pX& tmp = Cmodulus::getScratch_zz_pX ();
151188
189+ zz_pX& tmp = Cmodulus::getScratch_zz_pX ();
152190 conv (tmp,x); // convert input to zpx format
191+
192+ if (!ALT_CRT && zMStar->getPow2 ()) {
193+ // special case when m is a power of 2
194+
195+ long k = zMStar->getPow2 ();
196+ long phim = (1L << (k-1 ));
197+ long dx = deg (tmp);
198+ long p = zz_p::modulus ();
199+
200+ const zz_p *powers_p = (*powers).rep .elts ();
201+ const mulmod_precon_t *powers_aux_p = powers_aux.elts ();
202+
203+ y.SetLength (phim);
204+ long *yp = y.elts ();
205+
206+ zz_p *tmp_p = tmp.rep .elts ();
207+
208+ for (long i = 0 ; i <= dx; i++)
209+ yp[i] = MulModPrecon (rep (tmp_p[i]), rep (powers_p[i]), p, powers_aux_p[i]);
210+ for (long i = dx+1 ; i < phim; i++)
211+ yp[i] = 0 ;
212+
213+ FFTFwd (yp, yp, k-1 , *zz_pInfo->p_info );
214+
215+ return ;
216+ }
217+
218+
219+ zz_p rt;
153220 conv (rt, root); // convert root to zp format
154221
155222 BluesteinFFT (tmp, getM (), rt, *powers, powers_aux, *Rb); // call the FFT routine
@@ -169,8 +236,40 @@ void Cmodulus::iFFT(zz_pX &x, const vec_long& y)const
169236 FHE_TIMER_START;
170237 zz_pBak bak; bak.save ();
171238 context.restore ();
172- zz_p rt;
173239
240+ if (!ALT_CRT && zMStar->getPow2 ()) {
241+ // special case when m is a power of 2
242+
243+ long k = zMStar->getPow2 ();
244+ long phim = (1L << (k-1 ));
245+ long p = zz_p::modulus ();
246+
247+ const zz_p *ipowers_p = (*ipowers).rep .elts ();
248+ const mulmod_precon_t *ipowers_aux_p = ipowers_aux.elts ();
249+
250+ const long *yp = y.elts ();
251+
252+ vec_long& tmp = Cmodulus::getScratch_vec_long ();
253+ tmp.SetLength (phim);
254+ long *tmp_p = tmp.elts ();
255+
256+ FFTRev1 (tmp_p, yp, k-1 , *zz_pInfo->p_info );
257+
258+ x.rep .SetLength (phim);
259+ zz_p *xp = x.rep .elts ();
260+
261+ for (long i = 0 ; i < phim; i++)
262+ xp[i].LoopHole () = MulModPrecon (tmp_p[i], rep (ipowers_p[i]), p, ipowers_aux_p[i]);
263+
264+
265+ x.normalize ();
266+
267+ return ;
268+ }
269+
270+
271+
272+ zz_p rt;
174273 long m = getM ();
175274
176275 // convert input to zpx format, initializing only the coeffs i s.t. (i,m)=1
@@ -201,6 +300,12 @@ zz_pX& Cmodulus::getScratch_zz_pX()
201300 return scratch;
202301}
203302
303+ Vec<long >& Cmodulus::getScratch_vec_long ()
304+ {
305+ NTL_THREAD_LOCAL static Vec<long > scratch;
306+ return scratch;
307+ }
308+
204309
205310fftRep& Cmodulus::getScratch_fftRep (long k)
206311{
0 commit comments