@@ -94,12 +94,24 @@ Matrix<Ctxt> calculateMasks(const EncryptedArray& ea,
9494 mask = mask.columns (columns);
9595 mask.inPlaceTranspose ();
9696
97- (mask -= database)
97+ // FIXME: Avoid deep copy
98+ // Ptxt Query
99+ if constexpr (std::is_same_v<TXT, Ptxt<BGV>>) {
100+ auto tmp = database.deepCopy ();
101+ (tmp -= mask)
98102 .apply ([&](auto & entry) { mapTo01 (ea, entry); })
99103 .apply ([](auto & entry) { entry.negate (); })
100104 .apply ([](auto & entry) { entry.addConstant (NTL::ZZX (1l )); });
101105
102- return mask;
106+ return tmp;
107+ } else { // Ctxt Query
108+ (mask -= database)
109+ .apply ([&](auto & entry) { mapTo01 (ea, entry); })
110+ .apply ([](auto & entry) { entry.negate (); })
111+ .apply ([](auto & entry) { entry.addConstant (NTL::ZZX (1l )); });
112+
113+ return mask;
114+ }
103115}
104116
105117/* *
@@ -569,8 +581,8 @@ class Database
569581 * match or no match respectively.
570582 **/
571583 template <typename TXT2>
572- Matrix<TXT2> contains (const Query_t& lookup_query,
573- const Matrix<TXT2>& query_data) const ;
584+ auto contains (const Query_t& lookup_query,
585+ const Matrix<TXT2>& query_data) const ;
574586
575587 // FIXME: Combination of TXT = ctxt and TXT2 = ptxt does not work
576588 /* *
@@ -583,8 +595,8 @@ class Database
583595 * @return A `Matrix<TXT2>` containing a score on weighted matches.
584596 **/
585597 template <typename TXT2>
586- Matrix<TXT2> getScore (const Query_t& weighted_query,
587- const Matrix<TXT2>& query_data) const ;
598+ auto getScore (const Query_t& weighted_query,
599+ const Matrix<TXT2>& query_data) const ;
588600
589601 // TODO - correct name?
590602 /* *
@@ -593,14 +605,16 @@ class Database
593605 **/
594606 long columns () { return data.dims (1 ); }
595607
608+ Matrix<TXT>& getData ();
609+
596610private:
597611 Matrix<TXT> data;
598612 std::shared_ptr<const Context> context;
599613};
600614
601615template <typename TXT>
602616template <typename TXT2>
603- inline Matrix<TXT2> Database<TXT>::contains(
617+ inline auto Database<TXT>::contains(
604618 const Query_t& lookup_query,
605619 const Matrix<TXT2>& query_data) const
606620{
@@ -619,7 +633,7 @@ inline Matrix<TXT2> Database<TXT>::contains(
619633
620634template <typename TXT>
621635template <typename TXT2>
622- inline Matrix<TXT2> Database<TXT>::getScore(
636+ inline auto Database<TXT>::getScore(
623637 const Query_t& weighted_query,
624638 const Matrix<TXT2>& query_data) const
625639{
@@ -633,6 +647,12 @@ inline Matrix<TXT2> Database<TXT>::getScore(
633647 return result;
634648}
635649
650+ template <typename TXT>
651+ inline Matrix<TXT>& Database<TXT>::getData()
652+ {
653+ return data;
654+ }
655+
636656} // namespace helib
637657
638658#endif
0 commit comments