@@ -94,6 +94,7 @@ class ScratchCell {
9494 * consuming as few levels as possible).
9595 **/
9696class AddDAG {
97+ std::mutex scratch_mtx; // controls access to scratch vector
9798 std::vector<ScratchCell> scratch; // scratch space for ciphertexts
9899 std::map<NodeIdx,DAGnode> p; // p[i,j]= prod_{t=j}^i (a[t]+b[t])
99100 std::map<NodeIdx,DAGnode> q; // q[i,j]= a[j]b[j]*prod_{t=j+1}^i (a[t]+b[t])
@@ -278,11 +279,12 @@ void AddDAG::apply(CtPtrs& sum,
278279 throw std::logic_error (" DAG applied to wrong vectors" );
279280
280281 if (sizeLimit==0 ) sizeLimit = bSize+1 ;
281- sum.resize (sizeLimit, &b); // allocate space for the output
282+ if (lsize (sum)!=sizeLimit)
283+ sum.resize (sizeLimit, &b); // allocate space for the output
282284
283285 // Allow multi-threading in this loop
284286 NTL_EXEC_RANGE (sizeLimit, first, last)
285- for (long i=first; i<last; i++) {
287+ for (long i=first; i<last; i++) { // for (long i=0; i<sizeLimit; i++) {
286288 if (i<bSize)
287289 addCtxtFromNode (*(sum[i]), this ->findP (i,i), a, b);
288290 for (long j=std::min (i-1 , aSize-1 ); j>=0 ; --j) {
@@ -369,7 +371,7 @@ const Ctxt& AddDAG::getCtxt(DAGnode* node,
369371Ctxt* AddDAG::allocateCtxtLike (const Ctxt& c)
370372{
371373 // look for an unused cell in the scratch array
372- for (long i=0 ; i<( long ) scratch. size ( ); i++)
374+ for (long i=0 ; i<lsize ( scratch); i++)
373375 if (scratch[i].used == false ) { // found a free one, try to use it
374376 bool used = scratch[i].used .exchange (true ); // mark it as used
375377 if (used==false ) // make sure no other thread got there first
@@ -379,8 +381,9 @@ Ctxt* AddDAG::allocateCtxtLike(const Ctxt& c)
379381 // If not found, allocate a new cell
380382 ScratchCell sc (c); // cell points to new ctxt, with used=true
381383 Ctxt* pt = sc.ct .get (); // remember the raw pointer
382- scratch.push_back (std::move (sc)); // scratch now owns the pointer
383- return pt; // return the raw pointer
384+ std::unique_lock<std::mutex> lck (scratch_mtx); // protect scratch vector
385+ scratch.emplace_back (std::move (sc)); // scratch now owns the pointer
386+ return pt; // return the raw pointer
384387}
385388
386389// Mark a scratch ciphertext as unused. We assume that no two nodes
@@ -602,6 +605,17 @@ static void three4Two(CtPtrs& lsb, CtPtrs& msb,
602605 vecCopy (msb, tmpMsb);
603606}
604607
608+ //! @brief An implementation of PtrMatrix using vector< PtrVector<T>* >. Holds only a reference to the caller's vector, so that vector must outlive this wrapper.
609+ template <typename T>
610+ struct PtrMatrix_PtPtrVector : PtrMatrix<T> {
611+ std::vector< PtrVector<T>* >& rows; // reference to the caller's vector of row pointers (not a copy)
612+ PtrMatrix_PtPtrVector (std::vector< PtrVector<T>* >& mat): rows(mat) {} // binds rows to mat; nothing is copied
613+ PtrVector<T>& operator [](long i) override // returns row i by dereferencing rows[i]; no bounds or null check
614+ { return *rows[i]; }
615+ const PtrVector<T>& operator [](long i) const override // const access to row i, same lookup as above
616+ { return *rows[i]; }
617+ long size () const override { return lsize (rows); } // number of rows, i.e. lsize of the pointer vector
618+ };
605619
606620// Calculates the sum of many numbers using the 3-for-2 method
607621void addManyNumbers (CtPtrs& sum, CtPtrMat& numbers, long sizeLimit,
@@ -619,37 +633,46 @@ void addManyNumbers(CtPtrs& sum, CtPtrMat& numbers, long sizeLimit,
619633 }
620634 if (lsize (numbers)==1 ) { vecCopy (sum, numbers[0 ]); return ; }
621635
622- // if just 2 numbers to add then use normal binary addition
623- // else enter loop below. We view numbers as a FIFO queue, each
624- // time removing the first three entries at the head and adding
625- // two new ones at the tail.
626- long head=0 , tail=0 ;
627- for (long leftInQ=lsize (numbers); leftInQ>2 ; leftInQ--) {
628- long h2 = (head+1 ) % lsize (numbers);
629- long h3 = (head+2 ) % lsize (numbers);
630- long t2 = (tail+1 ) % lsize (numbers);
631- const CtPtrs& h1ct = numbers[head];
632- const CtPtrs& h2ct = numbers[h2];
633- const CtPtrs& h3ct = numbers[h3];
634-
635- // If any of head,h1,h2 are too low level, then bootstrap everything
636- if (findMinLevel ({&h1ct, &h2ct, &h3ct}) < 3 ) {
637- assert (unpackSlotEncoding!=nullptr
638- && ct_ptr->getPubKey ().isBootstrappable ());
639- packedRecrypt (numbers, *unpackSlotEncoding,
640- *(ct_ptr->getContext ().ea ), /* belowLvl=*/ 10 );
641- }
636+ bool bootstrappable = ct_ptr->getPubKey ().isBootstrappable ();
637+ const EncryptedArray& ea = *(ct_ptr->getContext ().ea );
642638
643- // three4Two can work in-place
644- three4Two (numbers[tail], numbers[t2], h1ct, h2ct, h3ct, sizeLimit);
639+ long leftInQ = lsize (numbers);
640+ std::vector<CtPtrs*> numPtrs (leftInQ);
641+ for (long i=0 ; i<leftInQ; i++) numPtrs[i] = &(numbers[i]);
645642
646- head = (head+3 ) % lsize (numbers);
647- tail = (tail+2 ) % lsize (numbers);
643+ // use 3-for-2 repeatedly until only two numbers are left to add
644+ while (leftInQ>2 ) {
645+ // If any number is too low level, then bootstrap everything
646+ PtrMatrix_PtPtrVector<Ctxt> wrapper (numPtrs);
647+ if (findMinLevel (wrapper)<3 ) {
648+ assert (bootstrappable && unpackSlotEncoding!=nullptr );
649+ packedRecrypt (wrapper, *unpackSlotEncoding, ea, /* belowLvl=*/ 10 );
650+ }
651+ // Prepare a vector for pointers to the output of this iteration
652+ long nTriples = leftInQ/3 ;
653+ long leftOver = leftInQ - (3 *nTriples);
654+ std::vector<CtPtrs*> numPtrs2 (2 *nTriples +leftOver);
655+
656+ if (leftOver>0 ) { // copy the leftover pointers
657+ numPtrs2[0 ] = numPtrs[3 *nTriples];
658+ if (leftOver>1 ) numPtrs2[1 ] = numPtrs[3 *nTriples +1 ];
659+ }
660+ // Allow multi-threading in this loop
661+ NTL_EXEC_RANGE (nTriples, first, last)
662+ for (long i=first; i<last; i++) { // call the three-for-two procedure
663+ three4Two (*numPtrs[3 *i], *numPtrs[3 *i+1 ], // three4Two works in-place
664+ *numPtrs[3 *i], *numPtrs[3 *i+1 ], *numPtrs[3 *i+2 ], sizeLimit);
665+
666+ numPtrs2[leftOver +2 *i] = numPtrs[3 *i]; // copy the output pointers
667+ numPtrs2[leftOver +2 *i +1 ] = numPtrs[3 *i +1 ];
668+ }
669+ NTL_EXEC_RANGE_END
670+ numPtrs.swap (numPtrs2); // swap input/output vectors
671+ leftInQ = lsize (numPtrs); // update the size
648672 }
649673 // final addition
650- long h2 = (head+1 ) % lsize (numbers);
651- addTwoNumbers (sum, numbers[head], numbers[h2], sizeLimit, unpackSlotEncoding);
652- } // NOTE: It'd be a little challenging to parallelize this
674+ addTwoNumbers (sum, *numPtrs[0 ], *numPtrs[1 ], sizeLimit, unpackSlotEncoding);
675+ }
653676
654677
655678// Multiply a positive a by a potentially negative b, we need to sign-extend b
0 commit comments