@@ -195,9 +195,9 @@ long FHEcontext::AddFFTPrime(bool special)
195195 return p;
196196}
197197
198- // Adds several primes to the chain. If byNumber=true then totalSize
199- // specifies the number of primes to add. If byNumber=false then
200- // totalSize specifies the target natural log all the added primes.
198+ // Adds several primes to the chain. If byNumber=true then totalSize specifies
199+ // the number of primes to add. If byNumber=false then totalSize specifies the
200+ // target natural log all the added primes.
201201// Returns natural log of the product of all added primes.
202202double AddManyPrimes (FHEcontext& context, double totalSize,
203203 bool byNumber, bool special)
@@ -206,10 +206,6 @@ double AddManyPrimes(FHEcontext& context, double totalSize,
206206 Error (" AddManyPrimes: m undefined or larger than 2^20" );
207207 // NOTE: Below we are ensured that 16m*log(m) << NTL_SP_BOUND
208208
209- // cout << "AddManyPrimes(..., totalSize="<<((int)totalSize)
210- // << ",byNumber="<<byNumber<<",special="<<special<<")\n";
211- // cout << " context.bitsPerLevel="<<context.bitsPerLevel<<endl;
212-
213209 double sizeLogSoFar = 0.0 ; // log of added primes so far
214210 double addedSoFar = 0.0 ; // Either size or number, depending on 'byNumber'
215211
@@ -219,9 +215,9 @@ double AddManyPrimes(FHEcontext& context, double totalSize,
219215 long sizeBits = 2 *context.bitsPerLevel ;
220216#endif
221217 if (special) { // try to use similar size for all the special primes
222- // how many special primes would we need
223- long numPrimes = ceil (totalSize/( NTL_SP_NBITS* log ( 2.0 )));
224- sizeBits = 1 +ceil (totalSize/( log ( 2.0 )* numPrimes)); // bitsize of each prime
218+ double totalBits = totalSize/ log ( 2.0 );
219+ long numPrimes = ceil (totalBits/ NTL_SP_NBITS); // how many special primes
220+ sizeBits = 1 +ceil (totalBits/ numPrimes); // what's the size of each
225221 // Added one so we don't undershoot our target
226222 }
227223 if (sizeBits>NTL_SP_NBITS) sizeBits = NTL_SP_NBITS;
@@ -235,7 +231,9 @@ double AddManyPrimes(FHEcontext& context, double totalSize,
235231 }
236232
237233 // make p-1 divisible by m*2^k for as large k as possible
238- if (context.zMStar .getPow2 ()!=0 ) // if m is not a power of two
234+ // (not needed when m itself a power of two)
235+
236+ if (context.zMStar .getM () & 1 ) // m is odd, so not power of two
239237 while (twoM < sizeBound/(sizeBits*2 )) twoM *= 2 ;
240238
241239 long bigP = sizeBound - (sizeBound%twoM) +1 ; // 1 mod 2m
@@ -256,10 +254,8 @@ double AddManyPrimes(FHEcontext& context, double totalSize,
256254 return sizeLogSoFar;
257255}
258256
259- void buildModChain (FHEcontext &context, long nLevels, long nDgts)
257+ void buildModChain (FHEcontext &context, long nLevels, long nDgts, long extraBits )
260258{
261- // cout << "buildModChain called with "<<nLevels
262- // <<" levels and "<<nDgts<<" digits\n";
263259#ifdef NO_HALF_SIZE_PRIME
264260 long nPrimes = nLevels;
265261#else
@@ -290,7 +286,7 @@ void buildModChain(FHEcontext &context, long nLevels, long nDgts)
290286 context.digits .resize (nDgts); // allocate space
291287
292288 IndexSet s1;
293- double sizeLogSoFar = 0.0 ;
289+ double sizeSoFar = 0.0 ;
294290 double maxDigitSize = 0.0 ;
295291 if (nDgts>1 ) { // we break ciphetext into a few digits when key-switching
296292 double dsize = context.logOfProduct (context.ctxtPrimes )/nDgts; // estimate
@@ -301,10 +297,9 @@ void buildModChain(FHEcontext &context, long nLevels, long nDgts)
301297 long idx = context.ctxtPrimes .first ();
302298 for (long i=0 ; i<nDgts-1 ; i++) { // set all digits but the last
303299 IndexSet s;
304- while (idx <= context.ctxtPrimes .last ()
305- && (empty (s)||sizeLogSoFar<target)) {
300+ while (idx <= context.ctxtPrimes .last () && (empty (s)||sizeSoFar<target)) {
306301 s.insert (idx);
307- sizeLogSoFar += log ((double )context.ithPrime (idx));
302+ sizeSoFar += log ((double )context.ithPrime (idx));
308303 idx = context.ctxtPrimes .next (idx);
309304 }
310305 assert (!empty (s));
@@ -332,11 +327,10 @@ void buildModChain(FHEcontext &context, long nLevels, long nDgts)
332327 }
333328
334329 // Add special primes to the chain for the P factor of key-switching
335- long p2r = (context.rcData .alMod )? context.rcData .alMod ->getPPowR ()
336- : context.alMod .getPPowR ();
330+ long p2r = context.alMod .getPPowR ();
337331 double sizeOfSpecialPrimes
338332 = maxDigitSize + log (nDgts) + log (context.stdev *2 )
339- + log ((double )p2r);
333+ + log ((double )p2r) + (extraBits* log ( 2.0 )) ;
340334
341335 AddPrimesBySize (context, sizeOfSpecialPrimes, true );
342336}
0 commit comments