Skip to content

Commit fd75f06

Browse files
author
Shai Halevi
committed
multithreading support/bugfixes
1 parent 48bb963 commit fd75f06

File tree

2 files changed

+60
-32
lines changed

2 files changed

+60
-32
lines changed

src/Test_binaryArith.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,15 @@ int main(int argc, char *argv[])
6060
amap.arg("bootstrap", bootstrap, "test multiplication with bootstrapping");
6161
long seed=0;
6262
amap.arg("seed", seed, "PRG seed");
63+
long nthreads=1;
64+
amap.arg("nthreads", nthreads, "number of threads");
6365
amap.arg("verbose", verbose, "print more information");
6466

6567
amap.parse(argc, argv);
6668
assert(prm >= 0 && prm < 4);
6769
if (seed) NTL::SetSeed(ZZ(seed));
70+
if (nthreads>1) NTL::SetNumThreads(nthreads);
71+
6872
if (bitSize<=0) bitSize=5;
6973
else if (bitSize>32) bitSize=32;
7074

@@ -105,6 +109,7 @@ int main(int argc, char *argv[])
105109
if (verbose) {
106110
cout <<"input bitSize="<<bitSize<<", output size bound="<<outSize
107111
<<", running "<<nTests<<" tests for each function\n";
112+
if (nthreads>1) cout << " using "<<nthreads<<" threads\n";
108113
cout << "computing key-independent tables..." << std::flush;
109114
}
110115
FHEcontext context(m, p, /*r=*/1, gens, ords);

src/binaryArith.cpp

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class ScratchCell {
9494
* consuming as few levels as possible).
9595
**/
9696
class AddDAG {
97+
std::mutex scratch_mtx; // controls access to scratch vector
9798
std::vector<ScratchCell> scratch; // scratch space for ciphertexts
9899
std::map<NodeIdx,DAGnode> p; // p[i,j]= prod_{t=j}^i (a[t]+b[t])
99100
std::map<NodeIdx,DAGnode> q; // q[i,j]= a[j]b[j]*prod_{t=j+1}^i (a[t]+b[t])
@@ -278,11 +279,12 @@ void AddDAG::apply(CtPtrs& sum,
278279
throw std::logic_error("DAG applied to wrong vectors");
279280

280281
if (sizeLimit==0) sizeLimit = bSize+1;
281-
sum.resize(sizeLimit, &b); // allocate space for the output
282+
if (lsize(sum)!=sizeLimit)
283+
sum.resize(sizeLimit, &b); // allocate space for the output
282284

283285
// Allow multi-threading in this loop
284286
NTL_EXEC_RANGE(sizeLimit, first, last)
285-
for (long i=first; i<last; i++) {
287+
for (long i=first; i<last; i++) { // for (long i=0; i<sizeLimit; i++) {
286288
if (i<bSize)
287289
addCtxtFromNode(*(sum[i]), this->findP(i,i), a, b);
288290
for (long j=std::min(i-1, aSize-1); j>=0; --j) {
@@ -369,7 +371,7 @@ const Ctxt& AddDAG::getCtxt(DAGnode* node,
369371
Ctxt* AddDAG::allocateCtxtLike(const Ctxt& c)
370372
{
371373
// look for an unused cell in the scratch array
372-
for (long i=0; i<(long)scratch.size(); i++)
374+
for (long i=0; i<lsize(scratch); i++)
373375
if (scratch[i].used == false) { // found a free one, try to use it
374376
bool used = scratch[i].used.exchange(true); // mark it as used
375377
if (used==false) // make sure no other thread got there first
@@ -379,8 +381,9 @@ Ctxt* AddDAG::allocateCtxtLike(const Ctxt& c)
379381
// If not found, allocate a new cell
380382
ScratchCell sc(c); // cell points to new ctxt, with used=true
381383
Ctxt* pt = sc.ct.get(); // remember the raw pointer
382-
scratch.push_back(std::move(sc)); // scratch now owns the pointer
383-
return pt; // return the raw pointer
384+
std::unique_lock<std::mutex> lck(scratch_mtx); // protect scratch vector
385+
scratch.emplace_back(std::move(sc)); // scratch now owns the pointer
386+
return pt; // return the raw pointer
384387
}
385388

386389
// Mark a scratch ciphertext as unused. We assume that no two nodes
@@ -602,6 +605,17 @@ static void three4Two(CtPtrs& lsb, CtPtrs& msb,
602605
vecCopy(msb, tmpMsb);
603606
}
604607

608+
//! @brief An implementation of PtrMatrix using vector< PtrVector<T>* >
//!
//! This is a non-owning view: the object stores only a reference to the
//! caller's vector of row pointers, so that vector must stay alive for as
//! long as the wrapper is used. Row access dereferences rows[i] directly,
//! with no bounds check and no null check on the stored pointer.
609+
template<typename T>
610+
struct PtrMatrix_PtPtrVector : PtrMatrix<T> {
611+
std::vector< PtrVector<T>* >& rows; // borrowed from the caller, not owned
612+
PtrMatrix_PtPtrVector(std::vector< PtrVector<T>* >& mat): rows(mat) {} // binds to mat, copies nothing
613+
PtrVector<T>& operator[](long i) override // returns a row
614+
{ return *rows[i]; }
615+
const PtrVector<T>& operator[](long i) const override // returns a row
616+
{ return *rows[i]; }
617+
long size() const override { return lsize(rows); } // How many rows
618+
};
605619

606620
// Calculates the sum of many numbers using the 3-for-2 method
607621
void addManyNumbers(CtPtrs& sum, CtPtrMat& numbers, long sizeLimit,
@@ -619,37 +633,46 @@ void addManyNumbers(CtPtrs& sum, CtPtrMat& numbers, long sizeLimit,
619633
}
620634
if (lsize(numbers)==1) { vecCopy(sum, numbers[0]); return; }
621635

622-
// if just 2 numbers to add then use normal binary addition
623-
// else enter loop below. We view numbers as a FIFO queue, each
624-
// time removing the first three entries at the head and adding
625-
// two new ones at the tail.
626-
long head=0, tail=0;
627-
for (long leftInQ=lsize(numbers); leftInQ>2; leftInQ--) {
628-
long h2 = (head+1) % lsize(numbers);
629-
long h3 = (head+2) % lsize(numbers);
630-
long t2 = (tail+1) % lsize(numbers);
631-
const CtPtrs& h1ct = numbers[head];
632-
const CtPtrs& h2ct = numbers[h2];
633-
const CtPtrs& h3ct = numbers[h3];
634-
635-
// If any of head,h1,h2 are too low level, then bootstrap everything
636-
if (findMinLevel({&h1ct, &h2ct, &h3ct}) < 3) {
637-
assert(unpackSlotEncoding!=nullptr
638-
&& ct_ptr->getPubKey().isBootstrappable());
639-
packedRecrypt(numbers, *unpackSlotEncoding,
640-
*(ct_ptr->getContext().ea), /*belowLvl=*/10);
641-
}
636+
bool bootstrappable = ct_ptr->getPubKey().isBootstrappable();
637+
const EncryptedArray& ea = *(ct_ptr->getContext().ea);
642638

643-
// three4Two can work in-place
644-
three4Two(numbers[tail], numbers[t2], h1ct, h2ct, h3ct, sizeLimit);
639+
long leftInQ = lsize(numbers);
640+
std::vector<CtPtrs*> numPtrs(leftInQ);
641+
for (long i=0; i<leftInQ; i++) numPtrs[i] = &(numbers[i]);
645642

646-
head = (head+3) % lsize(numbers);
647-
tail = (tail+2) % lsize(numbers);
643+
// use 3-for-2 repeatedly until only two numbers are left to add
644+
while (leftInQ>2) {
645+
// If any number is too low level, then bootstrap everything
646+
PtrMatrix_PtPtrVector<Ctxt> wrapper(numPtrs);
647+
if (findMinLevel(wrapper)<3) {
648+
assert(bootstrappable && unpackSlotEncoding!=nullptr);
649+
packedRecrypt(wrapper, *unpackSlotEncoding, ea, /*belowLvl=*/10);
650+
}
651+
// Prepare a vector for pointers to the output of this iteration
652+
long nTriples = leftInQ/3;
653+
long leftOver = leftInQ - (3*nTriples);
654+
std::vector<CtPtrs*> numPtrs2(2*nTriples +leftOver);
655+
656+
if (leftOver>0) { // copy the leftover pointers
657+
numPtrs2[0] = numPtrs[3*nTriples];
658+
if (leftOver>1) numPtrs2[1] = numPtrs[3*nTriples +1];
659+
}
660+
// Allow multi-threading in this loop
661+
NTL_EXEC_RANGE(nTriples, first, last)
662+
for (long i=first; i<last; i++) { // call the three-for-two procedure
663+
three4Two(*numPtrs[3*i], *numPtrs[3*i+1], // three4Two works in-place
664+
*numPtrs[3*i], *numPtrs[3*i+1], *numPtrs[3*i+2], sizeLimit);
665+
666+
numPtrs2[leftOver +2*i] = numPtrs[3*i]; // copy the output pointers
667+
numPtrs2[leftOver +2*i +1] = numPtrs[3*i +1];
668+
}
669+
NTL_EXEC_RANGE_END
670+
numPtrs.swap(numPtrs2); // swap input/output vectors
671+
leftInQ = lsize(numPtrs); // update the size
648672
}
649673
// final addition
650-
long h2 = (head+1) % lsize(numbers);
651-
addTwoNumbers(sum, numbers[head], numbers[h2], sizeLimit, unpackSlotEncoding);
652-
} // NOTE: It'd be a little challenging to parallelize this
674+
addTwoNumbers(sum, *numPtrs[0], *numPtrs[1], sizeLimit, unpackSlotEncoding);
675+
}
653676

654677

655678
// Multiply a positive a by a potentially negative b, we need to sign-extend b

0 commit comments

Comments
 (0)