-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
377 lines (330 loc) · 12.8 KB
/
database.py
File metadata and controls
377 lines (330 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
# database.py
import csv
import datetime
import io
import ipaddress
import sqlite3

import bcrypt
import geoip2.database

from config import Config
def init_db():
    """Create the schema if needed, apply column migrations, and seed an admin.

    Safe to call on every start-up: each table is created with
    ``CREATE TABLE IF NOT EXISTS`` and each migration inspects
    ``PRAGMA table_info`` first, so only columns that are actually missing get
    added. If ``admin_users`` is empty, a default admin is inserted from
    ``Config.DEFAULT_ADMIN_USER`` / ``Config.DEFAULT_ADMIN_PASS`` with the
    password stored as a bcrypt hash.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        c = conn.cursor()

        # Logs Table (rows inserted by log_event()).
        c.execute('''CREATE TABLE IF NOT EXISTS logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TEXT,
            ip_address TEXT,
            method TEXT,
            url TEXT,
            headers TEXT,
            payload TEXT,
            attack_type TEXT,
            risk_score INTEGER,
            action TEXT,
            country TEXT
        )''')

        # IP Bans Table (one row per IP; ban_ip() uses INSERT OR REPLACE).
        c.execute('''CREATE TABLE IF NOT EXISTS bans (
            ip_address TEXT PRIMARY KEY,
            banned_at TEXT,
            expires_at TEXT,
            reason TEXT
        )''')

        # Admin Users Table (bcrypt password hashes + Telegram pairing state).
        c.execute('''CREATE TABLE IF NOT EXISTS admin_users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE,
            password_hash TEXT,
            role TEXT,
            telegram_chat_id TEXT,
            telegram_sync_token TEXT
        )''')

        # IP Reputation Cache Table (written by cache_reputation()).
        c.execute('''CREATE TABLE IF NOT EXISTS ip_reputation (
            ip_address TEXT PRIMARY KEY,
            score INTEGER,
            last_checked TEXT
        )''')

        # Adaptive Defense Rules Table (learned patterns; UNIQUE per pattern).
        c.execute('''CREATE TABLE IF NOT EXISTS adaptive_rules (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            pattern TEXT UNIQUE,
            attack_type TEXT,
            confidence INTEGER,
            status TEXT DEFAULT 'pending', -- 'pending', 'approved', 'rejected'
            created_at TEXT
        )''')

        # --- Migrations for databases created by older versions ---
        c.execute("PRAGMA table_info(logs)")
        log_columns = {info[1] for info in c.fetchall()}
        if 'country' not in log_columns:
            print("⚠️ Migrating database: Adding 'country' column to logs...")
            c.execute("ALTER TABLE logs ADD COLUMN country TEXT DEFAULT 'Unknown'")

        c.execute("PRAGMA table_info(admin_users)")
        admin_columns = {info[1] for info in c.fetchall()}
        # Check each Telegram column separately so a schema that already has
        # one of them still receives the other.
        missing_telegram = [
            column
            for column in ('telegram_chat_id', 'telegram_sync_token')
            if column not in admin_columns
        ]
        if missing_telegram:
            print("⚠️ Migrating database: Adding Telegram columns to admin_users...")
            for column in missing_telegram:
                # Column names come from the fixed tuple above, never from input.
                c.execute(f"ALTER TABLE admin_users ADD COLUMN {column} TEXT")

        # --- Seed Default Admin (only when no admin exists yet) ---
        c.execute("SELECT COUNT(*) FROM admin_users")
        if c.fetchone()[0] == 0:
            print(f"⚠️ Seeding default admin user: {Config.DEFAULT_ADMIN_USER}")
            hashed = bcrypt.hashpw(Config.DEFAULT_ADMIN_PASS.encode('utf-8'), bcrypt.gensalt())
            c.execute(
                "INSERT INTO admin_users (username, password_hash, role) VALUES (?, ?, ?)",
                (Config.DEFAULT_ADMIN_USER, hashed.decode('utf-8'), 'admin'),
            )
        conn.commit()
    finally:
        # Always release the connection, even if a statement above raised.
        conn.close()
# --- ADMIN AUTHENTICATION ---
def verify_admin(username, password):
    """Check admin credentials against the stored bcrypt hash.

    Args:
        username: Login name looked up in ``admin_users``.
        password: Plain-text password (str) supplied by the client.

    Returns:
        ``{"id": ..., "username": ..., "role": ...}`` when the password
        matches; otherwise ``None``. Unknown users, a ``None`` password, and
        missing or malformed stored hashes all yield ``None``.
    """
    if password is None:
        return None
    candidate = password.encode('utf-8')

    conn = sqlite3.connect(Config.DB_NAME)
    try:
        c = conn.cursor()
        c.execute("SELECT id, username, password_hash, role FROM admin_users WHERE username = ?", (username,))
        user = c.fetchone()
    finally:
        conn.close()

    try:
        if not user or not user[2]:
            # Unknown user (or no stored hash): still spend one bcrypt hash at
            # the default gensalt() cost, so response time does not reveal
            # whether the username exists.
            bcrypt.hashpw(candidate, bcrypt.gensalt())
            return None
        if bcrypt.checkpw(candidate, user[2].encode('utf-8')):
            return {"id": user[0], "username": user[1], "role": user[3]}
    except ValueError:
        # bcrypt raises ValueError for a malformed stored hash. Depending on
        # the bcrypt version it may also raise it for over-long passwords.
        # Treat either case as a failed login rather than a server error.
        return None
    return None
# --- TELEGRAM PAIRING LOGIC ---
def set_telegram_sync_token(username, token):
    """Store a Telegram pairing token on the admin account ``username``.

    link_telegram_account() later redeems this token. Passing ``None`` clears
    the column (stored as NULL). Does nothing if the username does not exist.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        conn.execute(
            "UPDATE admin_users SET telegram_sync_token = ? WHERE username = ?",
            (token, username),
        )
        conn.commit()
    finally:
        # Closing without a commit (on error) discards the pending UPDATE.
        conn.close()
def link_telegram_account(token, chat_id):
    """Bind a Telegram chat to the admin whose pairing token equals ``token``.

    The same UPDATE clears the token, so each token can be redeemed only once.

    Args:
        token: Pairing token previously stored by set_telegram_sync_token().
        chat_id: Telegram chat identifier; stored as ``str(chat_id)``.

    Returns:
        True if at least one admin row matched ``token``, otherwise False.
    """
    if not token:
        # None/empty can never be a valid pairing code. Bail out so an empty
        # string cannot match rows whose token column happens to be ''.
        return False
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        cur = conn.execute(
            "UPDATE admin_users SET telegram_chat_id = ?, telegram_sync_token = NULL WHERE telegram_sync_token = ?",
            (str(chat_id), token),
        )
        conn.commit()
        return cur.rowcount > 0
    finally:
        conn.close()
def get_telegram_status(username):
    """Describe the Telegram pairing state of one admin account.

    Returns:
        ``{"status": "linked"}`` when a chat id is stored;
        ``{"status": "pending", "token": <token>}`` when only a pairing token
        is stored; otherwise ``{"status": "unlinked"}``, including when the
        username does not exist.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    record = conn.execute(
        "SELECT telegram_chat_id, telegram_sync_token FROM admin_users WHERE username = ?",
        (username,),
    ).fetchone()
    conn.close()

    chat_id, sync_token = record if record else (None, None)
    if chat_id:
        return {"status": "linked"}
    if sync_token:
        return {"status": "pending", "token": sync_token}
    return {"status": "unlinked"}
def get_all_telegram_chat_ids():
    """Return the Telegram chat ids of every admin that has one stored.

    Values are whatever is in ``telegram_chat_id``. link_telegram_account()
    writes them as strings.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    linked_rows = conn.execute(
        "SELECT telegram_chat_id FROM admin_users WHERE telegram_chat_id IS NOT NULL"
    ).fetchall()
    conn.close()
    return [chat_id for (chat_id,) in linked_rows]
# --- TELEGRAM UNLINKING (required by callers elsewhere in the app) ---
def disconnect_telegram_account(username):
    """Wipe the Telegram pairing data (chat id and pending token) for an admin.

    Returns:
        True if an admin row named ``username`` exists, even if it was already
        unlinked. False if no such user exists.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        cur = conn.execute(
            "UPDATE admin_users SET telegram_chat_id = NULL, telegram_sync_token = NULL WHERE username = ?",
            (username,),
        )
        conn.commit()
        return cur.rowcount > 0
    finally:
        conn.close()
# --- DETECT LOCAL IPs ---
def get_country_from_ip(ip):
    """Resolve an IP address to an ISO country code using the GeoIP database.

    Returns:
        'Local' for the literal 'localhost' and for loopback, private and
        link-local addresses (IPv4 and IPv6), with no GeoIP lookup.
        'Unknown' for empty input, strings that are not IP addresses, and
        any address the GeoIP lookup cannot resolve.
        Otherwise the ISO code from ``Config.GEOIP_DB_PATH``.
    """
    if not ip:
        return 'Unknown'
    if ip == 'localhost':
        return 'Local'
    try:
        addr = ipaddress.ip_address(str(ip).strip())
    except ValueError:
        # Not an IP at all: the GeoIP reader would reject it anyway, so skip
        # opening the database file.
        return 'Unknown'
    # Covers 127.0.0.0/8, 10/8, 172.16/12, 192.168/16, ::1, fe80::/10, etc.
    if addr.is_loopback or addr.is_private or addr.is_link_local:
        return 'Local'
    try:
        with geoip2.database.Reader(Config.GEOIP_DB_PATH) as reader:
            response = reader.city(str(addr))
            return response.country.iso_code or 'Unknown'
    except Exception:
        # Deliberately best-effort (missing DB file, address not in DB, ...):
        # a lookup failure must not raise into callers such as log_event().
        return 'Unknown'
def log_event(ip, method, url, headers, payload, attack_type, score, action):
    """Append one request record to the ``logs`` table.

    ``headers`` and ``payload`` are stored via ``str()`` (a dict becomes its
    repr). ``country`` comes from get_country_from_ip(). The timestamp is
    local time formatted "%Y-%m-%d %H:%M:%S".
    """
    # Do the GeoIP file lookup and formatting before opening the DB, so the
    # connection is held only for the INSERT itself.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    country = get_country_from_ip(ip)
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        conn.execute(
            """INSERT INTO logs
            (timestamp, ip_address, method, url, headers, payload, attack_type, risk_score, action, country)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (timestamp, ip, method, url, str(headers), str(payload), attack_type, score, action, country),
        )
        conn.commit()
    finally:
        conn.close()
def ban_ip(ip, reason):
    """Ban ``ip`` for ``Config.BAN_DURATION`` seconds from now (local time).

    Uses INSERT OR REPLACE, so re-banning an IP overwrites its previous row
    and restarts the ban window.
    """
    now = datetime.datetime.now()
    expires = now + datetime.timedelta(seconds=Config.BAN_DURATION)
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        conn.execute(
            "INSERT OR REPLACE INTO bans (ip_address, banned_at, expires_at, reason) VALUES (?, ?, ?, ?)",
            (ip, now.strftime("%Y-%m-%d %H:%M:%S"), expires.strftime("%Y-%m-%d %H:%M:%S"), reason),
        )
        conn.commit()
    finally:
        conn.close()
def is_ip_banned(ip):
    """Return True while ``ip`` has a ban whose expiry is still in the future.

    A ban found to be expired is deleted on the spot (via unban_ip) and the
    IP is reported as not banned.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    ban_row = conn.execute("SELECT expires_at FROM bans WHERE ip_address = ?", (ip,)).fetchone()
    conn.close()

    if ban_row is None:
        return False
    expiry = datetime.datetime.strptime(ban_row[0], "%Y-%m-%d %H:%M:%S")
    if expiry > datetime.datetime.now():
        return True
    # The ban has lapsed: purge it lazily.
    unban_ip(ip)
    return False
def unban_ip(ip):
    """Remove any ban row for ``ip``; a no-op if the IP is not banned."""
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        conn.execute("DELETE FROM bans WHERE ip_address = ?", (ip,))
        conn.commit()
    finally:
        conn.close()
# --- THREAT INTEL REPUTATION CACHING ---
def get_cached_reputation(ip):
    """Return the cached reputation score for ``ip`` if it is still fresh.

    An entry counts as fresh for ``Config.ABUSEIPDB_CACHE_HOURS`` hours after
    its ``last_checked`` time. Stale or missing entries return None.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    cached = conn.execute(
        "SELECT score, last_checked FROM ip_reputation WHERE ip_address = ?", (ip,)
    ).fetchone()
    conn.close()

    if not cached:
        return None
    cached_score, checked_at = cached
    checked_time = datetime.datetime.strptime(checked_at, "%Y-%m-%d %H:%M:%S")
    fresh_until = checked_time + datetime.timedelta(hours=Config.ABUSEIPDB_CACHE_HOURS)
    return cached_score if datetime.datetime.now() < fresh_until else None
def cache_reputation(ip, score):
    """Store or overwrite the reputation score for ``ip``, stamped with now.

    get_cached_reputation() reads the timestamp (local time,
    "%Y-%m-%d %H:%M:%S") to decide freshness.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        conn.execute(
            "INSERT OR REPLACE INTO ip_reputation (ip_address, score, last_checked) VALUES (?, ?, ?)",
            (ip, score, now),
        )
        conn.commit()
    finally:
        conn.close()
# --- ADAPTIVE DEFENSE LOGIC ---
def suggest_rule(pattern, attack_type, confidence=85):
    """Record a learned pattern as an immediately active rule.

    Called by the WAF engine. There is no review step: new rows are inserted
    with status 'approved'. If ``pattern`` already exists (UNIQUE constraint),
    the existing row is left untouched whatever its status, so a previously
    'pending' or 'rejected' pattern is NOT activated here.
    """
    created_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        conn.execute(
            "INSERT INTO adaptive_rules (pattern, attack_type, confidence, status, created_at) VALUES (?, ?, ?, 'approved', ?)",
            (pattern, attack_type, confidence, created_at),
        )
        conn.commit()
    except sqlite3.IntegrityError:
        # Duplicate pattern: keep the stored rule as it is.
        pass
    finally:
        conn.close()
def get_active_custom_rules():
    """Return ``(pattern, attack_type)`` tuples for every 'approved' rule.

    These are the rules the WAF engine should load into active memory.
    """
    approved_query = "SELECT pattern, attack_type FROM adaptive_rules WHERE status = 'approved'"
    conn = sqlite3.connect(Config.DB_NAME)
    approved = conn.execute(approved_query).fetchall()
    conn.close()
    return approved
# --- NEW: Fetch all learned rules for the Settings Summary UI ---
def get_all_ai_rules():
    """List approved learned rules for the settings summary, newest id first.

    Each item is a ``(id, pattern, attack_type, confidence, created_at)`` tuple.
    """
    listing_query = (
        "SELECT id, pattern, attack_type, confidence, created_at FROM adaptive_rules "
        "WHERE status = 'approved' ORDER BY id DESC"
    )
    conn = sqlite3.connect(Config.DB_NAME)
    learned = conn.execute(listing_query).fetchall()
    conn.close()
    return learned
# --- NEW: Surgical removal of a single rule (False Positive correction) ---
def delete_ai_rule(rule_id):
    """Delete one learned rule by id, for example to correct a false positive.

    Returns:
        True if a row with ``rule_id`` was deleted, otherwise False.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        cur = conn.execute("DELETE FROM adaptive_rules WHERE id = ?", (rule_id,))
        conn.commit()
        return cur.rowcount > 0
    finally:
        conn.close()
# --- DATA FETCHING FOR DASHBOARD ---
def get_all_logs():
    """Return the 200 most recent ``logs`` rows (highest id first) as tuples.

    Despite the name, the result is capped at 200 rows.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    latest = conn.execute("SELECT * FROM logs ORDER BY id DESC LIMIT 200").fetchall()
    conn.close()
    return latest
def get_log_by_id(log_id):
    """Fetch one ``logs`` row as a ``{column_name: value}`` dict, or None."""
    conn = sqlite3.connect(Config.DB_NAME)
    # sqlite3.Row lets the row be converted to a dict keyed by column name.
    conn.row_factory = sqlite3.Row
    entry = conn.execute("SELECT * FROM logs WHERE id = ?", (log_id,)).fetchone()
    conn.close()
    return dict(entry) if entry else None
def get_all_bans():
    """Return every ``bans`` row as a tuple, most recently banned first.

    Rows are ``(ip_address, banned_at, expires_at, reason)``. Expired bans
    that have not yet been purged are included.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    ban_rows = conn.execute("SELECT * FROM bans ORDER BY banned_at DESC").fetchall()
    conn.close()
    return ban_rows
def get_stats():
    """Aggregate dashboard statistics from the ``logs`` and ``bans`` tables.

    Returns a dict with:
        total          -- number of logged requests
        blocked        -- logged requests whose action is 'BLOCKED'
        bans           -- bans whose expires_at is still in the future
        attacks        -- {attack_type: count} over non-'Normal' requests
        top_ips        -- up to 5 {ip: count} with the most non-'Normal' requests
        top_countries  -- up to 5 {country: count}, same filter
        top_endpoints  -- up to 5 {url: count}, same filter (threat heatmap)
        logs           -- the 10 most recent log rows (tuples, highest id first)
    """
    # ban_ip() stores expiry as "%Y-%m-%d %H:%M:%S". That format sorts
    # lexicographically, so comparing strings in SQL compares times.
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        c = conn.cursor()
        c.execute("SELECT COUNT(*) FROM logs")
        total_requests = c.fetchone()[0]
        c.execute("SELECT COUNT(*) FROM logs WHERE action='BLOCKED'")
        blocked_requests = c.fetchone()[0]
        # Expired rows stay in `bans` until is_ip_banned() purges them on a
        # lookup, so count only bans that have not yet expired.
        c.execute("SELECT COUNT(*) FROM bans WHERE expires_at > ?", (now,))
        active_bans = c.fetchone()[0]
        c.execute("SELECT attack_type, COUNT(*) FROM logs WHERE attack_type != 'Normal' GROUP BY attack_type")
        attack_dist = dict(c.fetchall())
        c.execute("SELECT ip_address, COUNT(*) as count FROM logs WHERE attack_type != 'Normal' GROUP BY ip_address ORDER BY count DESC LIMIT 5")
        top_ips = dict(c.fetchall())
        c.execute("SELECT country, COUNT(*) as count FROM logs WHERE attack_type != 'Normal' GROUP BY country ORDER BY count DESC LIMIT 5")
        top_countries = dict(c.fetchall())
        # Threat heatmap: most-attacked URLs.
        c.execute("SELECT url, COUNT(*) as count FROM logs WHERE attack_type != 'Normal' GROUP BY url ORDER BY count DESC LIMIT 5")
        top_endpoints = dict(c.fetchall())
        c.execute("SELECT * FROM logs ORDER BY id DESC LIMIT 10")
        recent_logs = c.fetchall()
    finally:
        conn.close()
    return {
        "total": total_requests,
        "blocked": blocked_requests,
        "bans": active_bans,
        "attacks": attack_dist,
        "top_ips": top_ips,
        "top_countries": top_countries,
        "top_endpoints": top_endpoints,
        "logs": recent_logs,
    }
# --- MAINTENANCE ---
def clear_database():
    """Wipe logs, bans and the reputation cache, and reset the logs id counter.

    Admin accounts and adaptive rules are kept. The DELETEs share one implicit
    transaction that is committed only at the end. If any statement fails,
    the connection is closed without committing and nothing is removed.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        c = conn.cursor()
        c.execute("DELETE FROM logs")
        c.execute("DELETE FROM bans")
        c.execute("DELETE FROM ip_reputation")
        # Reset the AUTOINCREMENT counter so new log ids start from 1 again.
        # (`bans` is keyed by ip_address TEXT and never appears in
        # sqlite_sequence, so it needs no reset.)
        c.execute("DELETE FROM sqlite_sequence WHERE name='logs'")
        conn.commit()
    finally:
        conn.close()
def _neutralize_csv_formula(value):
    """Prefix text cells that spreadsheet apps would evaluate as formulas."""
    if isinstance(value, str) and value.startswith(('=', '+', '-', '@', '\t', '\r')):
        return "'" + value
    return value


def export_logs_csv(*, neutralize_formulas=True):
    """Export the whole ``logs`` table as CSV text, newest timestamp first.

    Args:
        neutralize_formulas: When True (default), text cells starting with
            ``= + - @``, tab or CR get a leading ``'``. Logged URLs, headers
            and payloads are attacker-controlled, and such cells would
            otherwise run as formulas when the export is opened in a
            spreadsheet (CSV injection). Pass False to get the raw values.

    Returns:
        The CSV document (header row first), or ``""`` if there are no logs.
    """
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        c = conn.cursor()
        c.execute("SELECT * FROM logs ORDER BY timestamp DESC")
        rows = c.fetchall()
        headers = [description[0] for description in c.description] if c.description else []
    finally:
        conn.close()
    if not rows:
        return ""
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(headers)
    if neutralize_formulas:
        rows = ([_neutralize_csv_formula(value) for value in row] for row in rows)
    writer.writerows(rows)
    return output.getvalue()
def clear_ai_knowledge():
    """Delete every adaptive rule, whatever its status, and reset the rule id counter."""
    conn = sqlite3.connect(Config.DB_NAME)
    try:
        c = conn.cursor()
        c.execute("DELETE FROM adaptive_rules")
        # Reset the AUTOINCREMENT counter so new rule ids start from 1 again.
        c.execute("DELETE FROM sqlite_sequence WHERE name='adaptive_rules'")
        conn.commit()
    finally:
        conn.close()