Skip to content

Commit d9e6784

Browse files
binfhe_bindings cleanup
1 parent d70f780 commit d9e6784

File tree

3 files changed

+52
-110
lines changed

3 files changed

+52
-110
lines changed

src/include/binfhe/binfhecontext_wrapper.h

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,34 +40,7 @@
4040

4141
namespace py = pybind11;
4242
using namespace lbcrypto;
43-
LWECiphertext binfhe_EncryptWrapper(BinFHEContext &self,
44-
ConstLWEPrivateKey sk,
45-
const LWEPlaintext &m,
46-
BINFHE_OUTPUT output,
47-
LWEPlaintextModulus p,
48-
uint64_t mod);
49-
LWEPlaintext binfhe_DecryptWrapper(BinFHEContext &self,
50-
ConstLWEPrivateKey sk,
51-
ConstLWECiphertext ct,
52-
LWEPlaintextModulus p);
53-
54-
uint32_t GetnWrapper(BinFHEContext &self);
55-
56-
const uint64_t GetqWrapper(BinFHEContext &self) ;
57-
58-
const uint64_t GetMaxPlaintextSpaceWrapper(BinFHEContext &self);
59-
60-
const uint64_t GetBetaWrapper(BinFHEContext &self);
61-
62-
const uint64_t GetLWECiphertextModulusWrapper(LWECiphertext &self);
6343

6444
std::vector<uint64_t> GenerateLUTviaFunctionWrapper(BinFHEContext &self, py::function f, uint64_t p);
6545

66-
NativeInteger StaticFunction(NativeInteger m, NativeInteger p);
67-
68-
// Define static variables to hold the state
69-
// extern py::function static_f;
70-
71-
LWECiphertext EvalFuncWrapper(BinFHEContext &self, ConstLWECiphertext &ct, const std::vector<uint64_t> &LUT);
72-
7346
#endif // __BINFHECONTEXT_WRAPPER_H__

src/lib/binfhe/binfhecontext_wrapper.cpp

Lines changed: 7 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -36,89 +36,28 @@
3636
using namespace lbcrypto;
3737
namespace py = pybind11;
3838

39-
LWECiphertext binfhe_EncryptWrapper(BinFHEContext &self, ConstLWEPrivateKey sk, const LWEPlaintext &m, BINFHE_OUTPUT output,
40-
LWEPlaintextModulus p, uint64_t mod)
41-
{
42-
NativeInteger mod_native_int = NativeInteger(mod);
43-
return self.Encrypt(sk, m, output, p, mod_native_int);
44-
}
45-
46-
LWEPlaintext binfhe_DecryptWrapper(BinFHEContext &self,
47-
ConstLWEPrivateKey sk,
48-
ConstLWECiphertext ct,
49-
LWEPlaintextModulus p)
50-
{
51-
52-
LWEPlaintext result;
53-
self.Decrypt(sk, ct, &result, p);
54-
return result;
55-
}
56-
57-
uint32_t GetnWrapper(BinFHEContext &self)
58-
{
59-
return self.GetParams()->GetLWEParams()->Getn();
60-
}
61-
62-
const uint64_t GetqWrapper(BinFHEContext &self)
63-
{
64-
return self.GetParams()->GetLWEParams()->Getq().ConvertToInt<uint64_t>();
65-
}
66-
67-
const uint64_t GetMaxPlaintextSpaceWrapper(BinFHEContext &self)
68-
{
69-
return self.GetMaxPlaintextSpace().ConvertToInt<uint64_t>();
70-
}
71-
72-
const uint64_t GetBetaWrapper(BinFHEContext &self)
73-
{
74-
return self.GetBeta().ConvertToInt<uint64_t>();
75-
}
76-
77-
const uint64_t GetLWECiphertextModulusWrapper(LWECiphertext &self)
78-
{
79-
return self->GetModulus().ConvertToInt<uint64_t>();
80-
}
8139

8240
// Define static variables to hold the state
8341
py::function* static_f = nullptr;
8442

8543
// Define a static function that uses the static variables
8644
NativeInteger StaticFunction(NativeInteger m, NativeInteger p) {
87-
// Convert the arguments to int
88-
uint64_t m_int = m.ConvertToInt<uint64_t>();
89-
uint64_t p_int = p.ConvertToInt<uint64_t>();
9045
// Call the Python function
91-
py::object result_py = (*static_f)(m_int, p_int);
46+
py::object result_py = (*static_f)(m.ConvertToInt<uint64_t>(), p.ConvertToInt<uint64_t>());
9247
// Convert the result to a NativeInteger
9348
return NativeInteger(py::cast<uint64_t>(result_py));
9449
}
9550

96-
std::vector<uint64_t> GenerateLUTviaFunctionWrapper(BinFHEContext &self, py::function f, uint64_t p)
97-
{
98-
NativeInteger p_native_int = NativeInteger(p);
51+
std::vector<uint64_t> GenerateLUTviaFunctionWrapper(BinFHEContext &self, py::function f, uint64_t p) {
9952
static_f = &f;
100-
std::vector<NativeInteger> result = self.GenerateLUTviaFunction(StaticFunction, p_native_int);
53+
std::vector<NativeInteger> result = self.GenerateLUTviaFunction(StaticFunction, NativeInteger(p));
10154
static_f = nullptr;
55+
10256
std::vector<uint64_t> result_uint64_t;
103-
// int size_int = static_cast<int>(result.size());
57+
result.reserve(result.size());
10458
for (const auto& value : result)
105-
{
106-
result_uint64_t.push_back(value.ConvertToInt<uint64_t>());
107-
}
108-
return result_uint64_t;
109-
}
59+
result_uint64_t.emplace_back(value.ConvertToInt<uint64_t>());
11060

111-
// LWECiphertext EvalFunc(ConstLWECiphertext &ct, const std::vector<NativeInteger> &LUT) const
112-
LWECiphertext EvalFuncWrapper(BinFHEContext &self, ConstLWECiphertext &ct, const std::vector<uint64_t> &LUT)
113-
{
114-
std::vector<NativeInteger> LUT_native_int;
115-
LUT_native_int.reserve(LUT.size()); // Reserve space for the elements
116-
for (const auto& value : LUT)
117-
{
118-
LUT_native_int.push_back(NativeInteger(value));
119-
}
120-
return self.EvalFunc(ct, LUT_native_int);
61+
return result_uint64_t;
12162
}
12263

123-
124-

src/lib/binfhe_bindings.cpp

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,10 @@ void bind_binfhe_ciphertext(py::module &m) {
147147
py::class_<LWECiphertextImpl, std::shared_ptr<LWECiphertextImpl>>(m, "LWECiphertext")
148148
.def(py::init<>())
149149
.def("GetLength", &LWECiphertextImpl::GetLength)
150-
.def("GetModulus", &GetLWECiphertextModulusWrapper)
150+
.def("GetModulus",
151+
[](LWECiphertext& self) {
152+
return self->GetModulus().ConvertToInt<uint64_t>();
153+
})
151154
.def(py::self == py::self)
152155
.def(py::self != py::self);
153156
}
@@ -179,18 +182,26 @@ void bind_binfhe_context(py::module &m) {
179182
binfhe_BTKeyGen_docs,
180183
py::arg("sk"),
181184
py::arg("keygenMode") = SYM_ENCRYPT)
182-
.def("Encrypt", &binfhe_EncryptWrapper,
183-
binfhe_Encrypt_docs,
185+
.def("Encrypt",
186+
[](BinFHEContext& self, ConstLWEPrivateKey sk, const LWEPlaintext& m, BINFHE_OUTPUT output, LWEPlaintextModulus p, uint64_t mod) {
187+
return self.Encrypt(sk, m, output, p, NativeInteger(mod));
188+
},
184189
py::arg("sk"),
185190
py::arg("m"),
186191
py::arg("output") = BOOTSTRAPPED,
187192
py::arg("p") = 4,
188-
py::arg("mod") = 0)
189-
.def("Decrypt", &binfhe_DecryptWrapper,
190-
binfhe_Decrypt_docs,
193+
py::arg("mod") = 0,
194+
py::doc(binfhe_Encrypt_docs))
195+
.def("Decrypt",
196+
[](BinFHEContext& self, ConstLWEPrivateKey sk, ConstLWECiphertext ct, LWEPlaintextModulus p) {
197+
LWEPlaintext result;
198+
self.Decrypt(sk, ct, &result, p);
199+
return result;
200+
},
191201
py::arg("sk"),
192202
py::arg("ct"),
193-
py::arg("p") = 4)
203+
py::arg("p") = 4,
204+
py::doc(binfhe_Decrypt_docs))
194205
.def("EvalBinGate",
195206
py::overload_cast<BINGATE, ConstLWECiphertext&, ConstLWECiphertext&, bool>(&BinFHEContext::EvalBinGate, py::const_),
196207
binfhe_EvalBinGate_docs,
@@ -206,10 +217,22 @@ void bind_binfhe_context(py::module &m) {
206217
.def("EvalNOT", &BinFHEContext::EvalNOT,
207218
binfhe_EvalNOT_docs,
208219
py::arg("ct"))
209-
.def("Getn", &GetnWrapper)
210-
.def("Getq", &GetqWrapper)
211-
.def("GetMaxPlaintextSpace", &GetMaxPlaintextSpaceWrapper)
212-
.def("GetBeta", &GetBetaWrapper)
220+
.def("Getn",
221+
[](BinFHEContext& self) {
222+
return self.GetParams()->GetLWEParams()->Getn();
223+
})
224+
.def("Getq",
225+
[](BinFHEContext& self) {
226+
return self.GetParams()->GetLWEParams()->Getq().ConvertToInt<uint64_t>();
227+
})
228+
.def("GetMaxPlaintextSpace",
229+
[](BinFHEContext& self) {
230+
return self.GetMaxPlaintextSpace().ConvertToInt<uint64_t>();
231+
})
232+
.def("GetBeta",
233+
[](BinFHEContext& self) {
234+
return self.GetBeta().ConvertToInt<uint64_t>();
235+
})
213236
.def("EvalDecomp", &BinFHEContext::EvalDecomp,
214237
binfhe_EvalDecomp_docs,
215238
py::arg("ct"))
@@ -221,15 +244,22 @@ void bind_binfhe_context(py::module &m) {
221244
binfhe_GenerateLUTviaFunction_docs,
222245
py::arg("f"),
223246
py::arg("p"))
224-
.def("EvalFunc", &EvalFuncWrapper,
225-
binfhe_EvalFunc_docs,
247+
.def("EvalFunc",
248+
[](BinFHEContext& self, ConstLWECiphertext& ct, const std::vector<uint64_t>& LUT) {
249+
std::vector<NativeInteger> nativeLUT;
250+
nativeLUT.reserve(LUT.size());
251+
for (auto value : LUT) {
252+
nativeLUT.emplace_back(value);
253+
}
254+
return self.EvalFunc(ct, nativeLUT);
255+
},
226256
py::arg("ct"),
227-
py::arg("LUT"))
257+
py::arg("LUT"),
258+
py::doc(binfhe_EvalFunc_docs))
228259
.def("EvalSign", &BinFHEContext::EvalSign,
229260
binfhe_EvalSign_docs,
230261
py::arg("ct"),
231262
py::arg("schemeSwitch") = false)
232-
.def("EvalNOT", &BinFHEContext::EvalNOT)
233263
.def("EvalConstant", &BinFHEContext::EvalConstant)
234264
.def("ClearBTKeys", &BinFHEContext::ClearBTKeys)
235265
.def("Bootstrap", &BinFHEContext::Bootstrap, py::arg("ct"), py::arg("extended") = false)

0 commit comments

Comments
 (0)