Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 34 additions & 88 deletions src/pke/include/cryptocontext.h

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/pke/include/schemebase/base-advancedshe.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class AdvancedSHEBase {
using TugType = typename Element::TugType;

constexpr static std::string_view NOT_IMPLEMENTED_ERROR = "Not implemented for this scheme";

public:
virtual ~AdvancedSHEBase() = default;

Expand Down Expand Up @@ -366,7 +367,7 @@ class AdvancedSHEBase {
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalSumKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) const;
const PrivateKey<Element> privateKey) const;

/**
* Virtual function to generate the automorphism keys for EvalSumRows; works
Expand Down
19 changes: 2 additions & 17 deletions src/pke/include/schemebase/base-leveledshe.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class LeveledSHEBase {

// TODO: should we use just one error message instead of two (see below)
constexpr static std::string_view NOT_IMPLEMENTED_ERROR = "Not implemented for this scheme";
constexpr static std::string_view NOT_SUPPORTED_ERROR = "Not supported for this scheme";
constexpr static std::string_view NOT_SUPPORTED_ERROR = "Not supported for this scheme";

public:
virtual ~LeveledSHEBase() = default;
Expand Down Expand Up @@ -569,20 +569,6 @@ class LeveledSHEBase {
virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PrivateKey<Element> privateKey, const std::vector<uint32_t>& indexList) const;

/**
* Virtual function to generate all isomorphism keys for a given private key
*
* @param publicKey encryption key for the new ciphertext.
* @param origPrivateKey original private key used for decryption.
* @param indexList list of automorphism indices to be computed
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<uint32_t>& indexList) const {
OPENFHE_THROW(NOT_IMPLEMENTED_ERROR);
}

/**
* Virtual function for evaluating automorphism of ciphertext at index i
*
Expand Down Expand Up @@ -637,8 +623,7 @@ class LeveledSHEBase {
* @return returns the evaluation keys
*/
virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const;
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const;

/**
* Moves i-th slot to slot 0
Expand Down
9 changes: 2 additions & 7 deletions src/pke/include/schemebase/base-scheme.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,6 @@ class SchemeBase {
virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PrivateKey<Element> privateKey, const std::vector<uint32_t>& indexList) const;

virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<uint32_t>& indexList) const;

virtual Ciphertext<Element> EvalAutomorphism(ConstCiphertext<Element>& ciphertext, uint32_t i,
const std::map<uint32_t, EvalKey<Element>>& evalKeyMap,
CALLER_INFO_ARGS_HDR) const {
Expand Down Expand Up @@ -717,8 +713,7 @@ class SchemeBase {
}

virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const;
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const;

virtual Ciphertext<Element> EvalAtIndex(ConstCiphertext<Element>& ciphertext, uint32_t i,
const std::map<uint32_t, EvalKey<Element>>& evalKeyMap) const {
Expand Down Expand Up @@ -947,7 +942,7 @@ class SchemeBase {
/////////////////////////////////////

virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalSumKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) const;
const PrivateKey<Element> privateKey) const;

virtual std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> EvalSumRowsKeyGen(
const PrivateKey<Element> privateKey, uint32_t rowSize, uint32_t subringDim,
Expand Down
36 changes: 13 additions & 23 deletions src/pke/lib/cryptocontext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,38 +144,33 @@ void CryptoContextImpl<Element>::InsertEvalMultKey(const std::vector<EvalKey<Ele
/////////////////////////////////////////

template <typename Element>
void CryptoContextImpl<Element>::EvalSumKeyGen(const PrivateKey<Element> privateKey,
const PublicKey<Element> publicKey) {
void CryptoContextImpl<Element>::EvalSumKeyGen(const PrivateKey<Element> privateKey) {
ValidateKey(privateKey);
if (publicKey != nullptr && privateKey->GetKeyTag() != publicKey->GetKeyTag()) {
OPENFHE_THROW("Public key passed to EvalSumKeyGen does not match private key");
}

auto&& evalKeys = GetScheme()->EvalSumKeyGen(privateKey, publicKey);
auto&& evalKeys = GetScheme()->EvalSumKeyGen(privateKey);
CryptoContextImpl<Element>::InsertEvalAutomorphismKey(evalKeys, privateKey->GetKeyTag());
}

template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::EvalSumRowsKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey, uint32_t rowSize, uint32_t subringDim) {
const PrivateKey<Element> privateKey, uint32_t rowSize, uint32_t subringDim) {
ValidateKey(privateKey);
if (publicKey != nullptr && privateKey->GetKeyTag() != publicKey->GetKeyTag())
OPENFHE_THROW("Public key passed to EvalSumKeyGen does not match private key");

std::vector<uint32_t> indices;
auto&& evalKeys = GetScheme()->EvalSumRowsKeyGen(privateKey, rowSize, subringDim, indices);
CryptoContextImpl<Element>::InsertEvalAutomorphismKey(evalKeys, privateKey->GetKeyTag());

return CryptoContextImpl<Element>::GetPartialEvalAutomorphismKeyMapPtr(privateKey->GetKeyTag(), indices);
}

// TODO: this is here for backwards compatibility; should remove in v2.0
template <typename Element>
inline std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::EvalSumRowsKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey, uint32_t rowSize, uint32_t subringDim) {
return CryptoContextImpl<Element>::EvalSumRowsKeyGen(privateKey, rowSize, subringDim);
}

template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> CryptoContextImpl<Element>::EvalSumColsKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) {
const PrivateKey<Element> privateKey) {
ValidateKey(privateKey);
if (publicKey != nullptr && privateKey->GetKeyTag() != publicKey->GetKeyTag())
OPENFHE_THROW("Public key passed to EvalSumKeyGen does not match private key");

std::vector<uint32_t> indices;
auto&& evalKeys = GetScheme()->EvalSumColsKeyGen(privateKey, indices);
CryptoContextImpl<Element>::InsertEvalAutomorphismKey(evalKeys, privateKey->GetKeyTag());
Expand Down Expand Up @@ -274,14 +269,9 @@ void CryptoContextImpl<Element>::ClearEvalSumKeys(const CryptoContext<Element> c

template <typename Element>
void CryptoContextImpl<Element>::EvalAtIndexKeyGen(const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList,
const PublicKey<Element> publicKey) {
const std::vector<int32_t>& indexList) {
ValidateKey(privateKey);
if (publicKey != nullptr && privateKey->GetKeyTag() != publicKey->GetKeyTag()) {
OPENFHE_THROW("Public key passed to EvalAtIndexKeyGen does not match private key");
}

auto&& evalKeys = GetScheme()->EvalAtIndexKeyGen(publicKey, privateKey, indexList);
auto&& evalKeys = GetScheme()->EvalAtIndexKeyGen(privateKey, indexList);
CryptoContextImpl<Element>::InsertEvalAutomorphismKey(evalKeys, privateKey->GetKeyTag());
}

Expand Down
2 changes: 1 addition & 1 deletion src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ std::shared_ptr<std::map<uint32_t, EvalKey<DCRTPoly>>> FHECKKSRNS::EvalBootstrap
slots = (slots == 0) ? M / 4 : slots;

// computing all indices for baby-step giant-step procedure
auto evalKeys = algo->EvalAtIndexKeyGen(nullptr, privateKey, FindBootstrapRotationIndices(slots, M));
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, FindBootstrapRotationIndices(slots, M));

(*evalKeys)[M - 1] = ConjugateKeyGen(privateKey);

Expand Down
6 changes: 3 additions & 3 deletions src/pke/lib/scheme/ckksrns/ckksrns-schemeswitching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalCKKStoFHE
// Compute multiplication key
algo->EvalMultKeyGen(privateKey);

auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationS2C);
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);

// Compute conjugation key
auto conjKey = ConjugateKeyGen(privateKey);
Expand Down Expand Up @@ -1539,7 +1539,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalFHEWtoCKK
indexRotationHomDec.end());

auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationHomDec);
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationHomDec);

// Compute multiplication key
ccCKKS->EvalMultKeyGen(privateKey);
Expand Down Expand Up @@ -1855,7 +1855,7 @@ std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>> SWITCHCKKSRNS::EvalSchemeSwi
indexRotationS2C.erase(std::remove(indexRotationS2C.begin(), indexRotationS2C.end(), 0), indexRotationS2C.end());

auto algo = ccCKKS->GetScheme();
auto evalKeys = algo->EvalAtIndexKeyGen(publicKey, privateKey, indexRotationS2C);
auto evalKeys = algo->EvalAtIndexKeyGen(privateKey, indexRotationS2C);

// Compute conjugation key
auto conjKey = ConjugateKeyGen(privateKey);
Expand Down
6 changes: 1 addition & 5 deletions src/pke/lib/schemebase/base-advancedshe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,9 @@ Ciphertext<Element> AdvancedSHEBase<Element>::AddRandomNoise(ConstCiphertext<Ele

template <class Element>
std::shared_ptr<std::map<usint, EvalKey<Element>>> AdvancedSHEBase<Element>::EvalSumKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) const {
const PrivateKey<Element> privateKey) const {
if (!privateKey)
OPENFHE_THROW("Input private key is nullptr");
/*
* we don't validate publicKey as it is needed by NTRU-based scheme only
* NTRU-based scheme only and it is checked for null later.
*/

// get automorphism indices and convert them to a vector
std::set<uint32_t> indx_set{GenerateIndexListForEvalSum(privateKey)};
Expand Down
4 changes: 1 addition & 3 deletions src/pke/lib/schemebase/base-leveledshe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,6 @@ void LeveledSHEBase<Element>::RelinearizeInPlace(Ciphertext<Element>& ciphertext
template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> LeveledSHEBase<Element>::EvalAutomorphismKeyGen(
const PrivateKey<Element> privateKey, const std::vector<uint32_t>& indexList) const {

// Do not generate duplicate keys that have been already generated and added to the static storage (map)
std::set<uint32_t> allIndices(indexList.begin(), indexList.end());
std::set<uint32_t> indicesToGenerate{
Expand Down Expand Up @@ -477,8 +476,7 @@ Ciphertext<Element> LeveledSHEBase<Element>::EvalFastRotation(

template <class Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> LeveledSHEBase<Element>::EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const {
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const {
uint32_t M = privateKey->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder();
std::vector<uint32_t> autoIndices(indexList.size());
for (size_t i = 0; i < indexList.size(); i++)
Expand Down
20 changes: 4 additions & 16 deletions src/pke/lib/schemebase/base-scheme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,9 @@ std::vector<EvalKey<Element>> SchemeBase<Element>::EvalMultKeysGen(const Private

template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> SchemeBase<Element>::EvalAtIndexKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<int32_t>& indexList) const {
const PrivateKey<Element> privateKey, const std::vector<int32_t>& indexList) const {
VerifyLeveledSHEEnabled(__func__);
auto evalKeyMap = m_LeveledSHE->EvalAtIndexKeyGen(publicKey, privateKey, indexList);
auto evalKeyMap = m_LeveledSHE->EvalAtIndexKeyGen(privateKey, indexList);
for (auto& key : *evalKeyMap)
key.second->SetKeyTag(privateKey->GetKeyTag());
return evalKeyMap;
Expand All @@ -102,9 +101,9 @@ Ciphertext<Element> SchemeBase<Element>::ModReduce(ConstCiphertext<Element>& cip

template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> SchemeBase<Element>::EvalSumKeyGen(
const PrivateKey<Element> privateKey, const PublicKey<Element> publicKey) const {
const PrivateKey<Element> privateKey) const {
VerifyAdvancedSHEEnabled(__func__);
auto evalKeyMap = m_AdvancedSHE->EvalSumKeyGen(privateKey, publicKey);
auto evalKeyMap = m_AdvancedSHE->EvalSumKeyGen(privateKey);
for (auto& key : *evalKeyMap)
key.second->SetKeyTag(privateKey->GetKeyTag());
return evalKeyMap;
Expand Down Expand Up @@ -323,17 +322,6 @@ std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> SchemeBase<Element>::EvalA
return evalKeyMap;
}

template <typename Element>
std::shared_ptr<std::map<uint32_t, EvalKey<Element>>> SchemeBase<Element>::EvalAutomorphismKeyGen(
const PublicKey<Element> publicKey, const PrivateKey<Element> privateKey,
const std::vector<uint32_t>& indexList) const {
VerifyLeveledSHEEnabled(__func__);
auto evalKeyMap = m_LeveledSHE->EvalAutomorphismKeyGen(publicKey, privateKey, indexList);
for (auto& key : *evalKeyMap)
key.second->SetKeyTag(privateKey->GetKeyTag());
return evalKeyMap;
}

template class SchemeBase<DCRTPoly>;

} // namespace lbcrypto
2 changes: 1 addition & 1 deletion src/pke/unittest/utckksrns/UnitTestCKKSrnsAutomorphism.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class UTCKKSRNS_AUTOMORPHISM : public ::testing::TestWithParam<TEST_CASE_UTCKKSR
// Encrypt the encoded vectors
auto ctMat = cc->Encrypt(kp.publicKey, ptxtMat);

auto evalSumRowKeys = cc->EvalSumRowsKeyGen(kp.secretKey, nullptr, rowSize);
auto evalSumRowKeys = cc->EvalSumRowsKeyGen(kp.secretKey, rowSize);

// Evaluation
auto ctRowsSum = cc->EvalSumRows(ctMat, rowSize, *evalSumRowKeys);
Expand Down
2 changes: 1 addition & 1 deletion src/pke/unittest/utils/UnitTestSer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void UnitTestContextWithSertype(CryptoContext<Element> cc, const ST& sertype,
try {
KeyPair<Element> kp = cc->KeyGen();
cc->EvalMultKeyGen(kp.secretKey);
cc->EvalSumKeyGen(kp.secretKey, kp.publicKey);
cc->EvalSumKeyGen(kp.secretKey);

std::stringstream s;
Serial::Serialize(cc, s, sertype);
Expand Down