-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver_side.py
More file actions
100 lines (89 loc) · 3.65 KB
/
server_side.py
File metadata and controls
100 lines (89 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import base64
import os
import tempfile

import piheaan as heaan
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from Server.predict import (
    init_context,
    encrypted_predict_with_encrypted_weights,
)
# ----------------------------- #
# ⚙️ Initialize FastAPI application
# ----------------------------- #
app = FastAPI()

# CORS config: accept cross-origin requests from any origin, using any
# method and any header, and allow credentials.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive (browsers reject a literal "*" with credentials, so the
# middleware may reflect the caller's Origin instead). If the frontend origin
# is known, consider listing it explicitly — confirm with the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ---------------------------------------------------- #
# 🔑 Initialize encryption context and evaluator
# - These are required for homomorphic computations.
# - Created once, at import time; the module-level `context` and `evaluator`
#   are then reused by every request this process serves.
# ---------------------------------------------------- #
context, evaluator = init_context()
# ------------------------------------------------------------------ #
# 🔁 Utility: Convert list of base64 strings back to Ciphertext objects
# ------------------------------------------------------------------ #
def recover_ciphertexts(base64_list: list[str], context: heaan.Context) -> list[heaan.Ciphertext]:
    """Deserialize base64-encoded ciphertexts into ``heaan.Ciphertext`` objects.

    ``heaan.Ciphertext.load`` takes a file path, so each decoded payload is
    written to a temporary file and then loaded from that path.

    Args:
        base64_list: Base64 strings, each holding one serialized ciphertext.
        context: HEaaN context used to construct each ``heaan.Ciphertext``.

    Returns:
        The recovered ciphertexts, in the same order as ``base64_list``.

    Raises:
        ValueError: If an entry cannot be base64-decoded (``binascii.Error``
            is a subclass of ``ValueError``).
    """
    recovered_ciphertexts: list[heaan.Ciphertext] = []
    # Use a private temporary *directory* rather than NamedTemporaryFile: each
    # file is fully written and closed before HEaaN reopens it by path.
    # Reopening a still-open NamedTemporaryFile by name fails on Windows.
    # The directory and its contents are removed when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for index, b64 in enumerate(base64_list):
            # Decode first so malformed input fails before any HEaaN allocation.
            binary_data = base64.b64decode(b64)
            ct_path = os.path.join(tmp_dir, f"ciphertext_{index}")
            with open(ct_path, "wb") as ct_file:
                ct_file.write(binary_data)
            ct = heaan.Ciphertext(context)
            ct.load(ct_path)
            recovered_ciphertexts.append(ct)
    return recovered_ciphertexts
# ------------------------------------------------------------- #
# 📦 Request model for /encrypted-predict endpoint
# - All inputs are base64-encoded serialized Ciphertexts
# ------------------------------------------------------------- #
class EncryptedPredictRequest(BaseModel):
    """Request body for ``POST /encrypted-predict``.

    Every field holds base64-encoded serialized HEaaN ciphertexts, which the
    endpoint deserializes before running the encrypted prediction.
    """

    # NOTE(review): nothing here checks that encrypted_vector and
    # encrypted_coef have matching lengths; presumably the predictor expects
    # them to — confirm against Server.predict.
    encrypted_vector: list[str]  # Base64-encoded ciphertexts for the input vector
    encrypted_coef: list[str]  # Base64-encoded ciphertexts for the model coefficients
    encrypted_intercept: str  # One base64-encoded ciphertext for the intercept
# -------------------------------------------------------------------------- #
# 🔮 Prediction endpoint: /encrypted-predict
# - Performs encrypted inference using encrypted weights and inputs
# - Returns the encrypted result encoded in base64
# -------------------------------------------------------------------------- #
@app.post("/encrypted-predict")
def encrypted_predict_api(request: EncryptedPredictRequest):
    """Run encrypted inference with encrypted inputs and encrypted weights.

    All request fields are base64-encoded serialized ciphertexts. They are
    deserialized with the module-level HEaaN ``context`` and passed, together
    with the shared ``evaluator``, to ``encrypted_predict_with_encrypted_weights``.

    Returns:
        ``{"encrypted_result": <base64-encoded serialized result ciphertext>}``.

    Raises:
        HTTPException: 400 if a request field cannot be decoded, so that
            malformed client input is reported as a client error, not a 500.
    """
    # 1️⃣–3️⃣ Recover the input vector, coefficients and intercept from base64.
    try:
        encrypted_vec = recover_ciphertexts(
            base64_list=request.encrypted_vector,
            context=context,
        )
        encrypted_coef = recover_ciphertexts(
            base64_list=request.encrypted_coef,
            context=context,
        )
        # The intercept is a single ciphertext: wrap it in a list, then unwrap.
        encrypted_intercept = recover_ciphertexts(
            base64_list=[request.encrypted_intercept],
            context=context,
        )[0]
    except ValueError as err:
        # base64.b64decode raises ValueError (binascii.Error is a subclass)
        # on malformed input; that is the caller's fault, not the server's.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid base64-encoded ciphertext: {err}",
        ) from err

    # 4️⃣ Run encrypted prediction using homomorphic operations.
    encrypted_result = encrypted_predict_with_encrypted_weights(
        encrypted_vec,
        encrypted_coef,
        encrypted_intercept,
        context,
        evaluator,
    )

    # 5️⃣ Serialize the result for transmission. `save` writes to a path, so
    # let it create the file inside a private temporary directory. Then read
    # the file through a handle opened *after* save() returns. This works on
    # every platform and whether save() rewrites or replaces the file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        result_path = os.path.join(tmp_dir, "encrypted_result")
        encrypted_result.save(result_path)
        with open(result_path, "rb") as result_file:
            result_binary = result_file.read()
    result_b64 = base64.b64encode(result_binary).decode("utf-8")

    # 6️⃣ Return base64-encoded result.
    return {"encrypted_result": result_b64}