-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathduplicate_scan_engine.py
More file actions
566 lines (496 loc) · 18.2 KB
/
duplicate_scan_engine.py
File metadata and controls
566 lines (496 loc) · 18.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
from __future__ import annotations
import hashlib
import logging
import os
import sqlite3
import time
from dataclasses import dataclass
from typing import Callable, Iterable, Sequence
import librosa
import numpy as np
from near_duplicate_detector import fingerprint_distance, _parse_fp
from simple_duplicate_finder import SUPPORTED_EXTS
import chromaprint_utils
logger = logging.getLogger(__name__)
LogCallback = Callable[[str], None]
@dataclass
class DuplicateScanConfig:
    """Tuning parameters for the three-stage duplicate scan."""

    # Decode sample rate for header/chroma analysis; a low rate keeps it fast.
    sample_rate: int = 11025
    # Cap on decoded audio length per file, in seconds.
    max_analysis_sec: float = 120.0
    # Stage 1: absolute duration window, in milliseconds.
    duration_tolerance_ms: int = 2000
    # Stage 1: proportional duration window (fraction of target duration);
    # the larger of the absolute and proportional windows is used.
    duration_tolerance_ratio: float = 0.01
    # Stage 1: optional RMS-level window in dB; None disables the filter.
    rms_tolerance_db: float | None = 6.0
    # Stage 1: optional spectral-centroid window; None disables the filter.
    centroid_tolerance: float | None = 1500.0
    # Stage 1: optional spectral-rolloff window; None disables the filter.
    rolloff_tolerance: float | None = None
    # Stage 2: number of LSH bands each fingerprint is split into.
    fp_bands: int = 8
    # Stage 2: minimum band-hash collisions to keep a candidate.
    min_band_collisions: int = 2
    # Stage 3 gate: maximum fingerprint distance before chroma comparison.
    fp_distance_threshold: float = 0.2
    # Stage 3: maximum frame shift tried when aligning chroma sequences.
    chroma_max_offset_frames: int = 12
    # Stage 3: alignment score at/above which a pair is recorded as "match".
    chroma_match_threshold: float = 0.82
    # Stage 3: alignment score at/above which a pair is recorded as "possible".
    chroma_possible_threshold: float = 0.72
@dataclass
class DuplicateScanSummary:
    """Counters returned by run_duplicate_scan describing the work performed."""

    # Number of supported audio files discovered under the library root.
    tracks_total: int
    # audio_header rows recomputed (new or changed files).
    headers_updated: int
    # audio_fingerprint rows recomputed (new or changed files).
    fingerprints_updated: int
    # Pairwise duplicate edges written to dup_edges.
    edges_written: int
    # Connected-component groups written to dup_groups.
    groups_written: int
def _log(log_callback: LogCallback | None, message: str) -> None:
if log_callback:
log_callback(message)
else:
logger.info(message)
def _connect(db_path: str) -> sqlite3.Connection:
conn = sqlite3.connect(db_path)
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def ensure_schema(conn: sqlite3.Connection) -> None:
    """Create the duplicate-scan tables and indexes if they do not exist.

    Tables (track_id is the file path throughout this module):
      audio_header      — cheap per-track features for stage-1 filtering.
      audio_fingerprint — chromaprint fingerprint blob per track (stage 2).
      fp_lsh            — LSH band hashes for fingerprint candidate lookup.
      dup_edges         — verified pairwise duplicate matches (stage 3).
      dup_groups        — connected components over dup_edges.

    Idempotent: safe to call on every scan.
    """
    conn.executescript(
        """
        CREATE TABLE IF NOT EXISTS audio_header (
            track_id TEXT PRIMARY KEY,
            mtime REAL,
            size INTEGER,
            duration_ms INTEGER,
            rms_db REAL,
            centroid_mean REAL,
            rolloff_mean REAL,
            updated_at REAL
        );
        CREATE TABLE IF NOT EXISTS audio_fingerprint (
            track_id TEXT PRIMARY KEY,
            mtime REAL,
            size INTEGER,
            fp_blob BLOB,
            fp_len INTEGER,
            fp_version TEXT
        );
        CREATE TABLE IF NOT EXISTS fp_lsh (
            band_hash TEXT NOT NULL,
            band_index INTEGER NOT NULL,
            track_id TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS dup_edges (
            track_id_a TEXT NOT NULL,
            track_id_b TEXT NOT NULL,
            score REAL NOT NULL,
            method TEXT NOT NULL,
            verified_at REAL NOT NULL
        );
        CREATE TABLE IF NOT EXISTS dup_groups (
            group_id TEXT NOT NULL,
            track_id TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_audio_header_duration ON audio_header(duration_ms);
        CREATE INDEX IF NOT EXISTS idx_audio_header_centroid ON audio_header(centroid_mean);
        CREATE INDEX IF NOT EXISTS idx_fp_lsh_band ON fp_lsh(band_hash, band_index);
        CREATE INDEX IF NOT EXISTS idx_fp_lsh_track ON fp_lsh(track_id);
        CREATE INDEX IF NOT EXISTS idx_dup_edges_pair ON dup_edges(track_id_a, track_id_b);
        CREATE INDEX IF NOT EXISTS idx_dup_groups_track ON dup_groups(track_id);
        """
    )
EXCLUDED_DIRS = {"not sorted", "playlists"}
def _is_excluded_path(path: str, root: str) -> bool:
rel_path = os.path.relpath(path, root)
parts = {part.lower() for part in rel_path.split(os.sep)}
return bool(EXCLUDED_DIRS & parts)
def _list_audio_files(root: str) -> list[str]:
    """Walk *root* and return a sorted list of supported audio file paths.

    Excluded directories (see EXCLUDED_DIRS) are pruned so os.walk never
    descends into them.
    """
    collected: list[str] = []
    for base, dirnames, filenames in os.walk(root):
        if _is_excluded_path(base, root):
            dirnames[:] = []  # stop descending under an excluded directory
            continue
        # Prune excluded subdirectories in-place so os.walk skips them.
        dirnames[:] = [d for d in dirnames if d.lower() not in EXCLUDED_DIRS]
        collected.extend(
            os.path.join(base, name)
            for name in filenames
            if os.path.splitext(name)[1].lower() in SUPPORTED_EXTS
        )
    return sorted(collected)
def _compute_audio_header(
    path: str,
    sample_rate: int,
    max_analysis_sec: float,
) -> tuple[int, float, float, float]:
    """Decode up to *max_analysis_sec* seconds of *path* and compute cheap features.

    Returns (duration_ms, rms_db, centroid_mean, rolloff_mean).

    Raises ValueError for empty audio; decoder/librosa errors propagate to the
    caller (handled in _update_headers).

    NOTE(review): duration is measured on the truncated signal, so every file
    longer than max_analysis_sec reports the same capped duration_ms — confirm
    this is intended for the stage-1 duration filter.
    """
    y, sr = librosa.load(path, sr=sample_rate, mono=True, duration=max_analysis_sec)
    if y.size == 0:
        raise ValueError("empty audio")
    duration_ms = int(round(librosa.get_duration(y=y, sr=sr) * 1000))
    rms = librosa.feature.rms(y=y)[0]
    # dB of the mean per-frame RMS, referenced to full scale (ref=1.0).
    rms_db = float(librosa.amplitude_to_db(np.mean(rms), ref=1.0))
    centroid = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
    centroid_mean = float(np.mean(centroid))
    rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
    rolloff_mean = float(np.mean(rolloff))
    return duration_ms, rms_db, centroid_mean, rolloff_mean
def _fingerprint_to_bytes(fp: str) -> bytes:
return fp.encode("utf-8")
def _band_hashes(total_len: int, bands: int, payload_for: Callable[[int, int], bytes]) -> list[tuple[int, str]]:
    """Split a length-*total_len* sequence into *bands* contiguous slices and SHA-1 each.

    *payload_for(start, end)* serializes the slice [start:end) to bytes. The
    last band absorbs any remainder; empty slices are skipped. Returns
    (band_index, hex_digest) pairs.
    """
    band_size = max(1, total_len // bands)
    hashes: list[tuple[int, str]] = []
    for idx in range(bands):
        start = idx * band_size
        if start >= total_len:
            break
        end = total_len if idx == bands - 1 else start + band_size
        payload = payload_for(start, end)
        if not payload:
            continue
        hashes.append((idx, hashlib.sha1(payload).hexdigest()))
    return hashes


def _fingerprint_band_hashes(fp: str, bands: int) -> list[tuple[int, str]]:
    """Compute LSH band hashes for a fingerprint string.

    Parses *fp* via _parse_fp, which yields either a sequence of ints or raw
    bytes; both forms are banded the same way (the original code duplicated
    the banding loop for each form — now shared via _band_hashes). Returns an
    empty list when the fingerprint cannot be parsed.
    """
    parsed = _parse_fp(fp)
    if not parsed:
        return []
    kind, data = parsed
    if kind == "ints":
        items: Sequence[int] = data  # type: ignore[assignment]
        return _band_hashes(
            len(items),
            bands,
            # Same serialization as before: comma-joined decimal ints.
            lambda start, end: ",".join(str(v) for v in items[start:end]).encode("utf-8"),
        )
    blob: bytes = data  # type: ignore[assignment]
    return _band_hashes(len(blob), bands, lambda start, end: blob[start:end])
def _chroma_sequence(
    path: str,
    sample_rate: int,
    max_analysis_sec: float,
    hop_length: int = 512,
) -> np.ndarray:
    """Load up to *max_analysis_sec* seconds of audio and return a chroma matrix.

    Each frame (column) is L2-normalized; a small epsilon avoids division by
    zero on silent frames. Raises ValueError for empty audio.
    """
    signal, rate = librosa.load(path, sr=sample_rate, mono=True, duration=max_analysis_sec)
    if signal.size == 0:
        raise ValueError("empty audio")
    chroma = librosa.feature.chroma_stft(y=signal, sr=rate, hop_length=hop_length)
    norms = np.linalg.norm(chroma, axis=0, keepdims=True)
    return chroma / (norms + 1e-8)
def _alignment_score(
seq_a: np.ndarray,
seq_b: np.ndarray,
max_offset_frames: int,
) -> float:
if seq_a.size == 0 or seq_b.size == 0:
return 0.0
max_score = 0.0
for offset in range(-max_offset_frames, max_offset_frames + 1):
if offset < 0:
a = seq_a[:, :offset]
b = seq_b[:, -offset:]
elif offset > 0:
a = seq_a[:, offset:]
b = seq_b[:, :-offset]
else:
a = seq_a
b = seq_b
if a.shape[1] == 0 or b.shape[1] == 0:
continue
n = min(a.shape[1], b.shape[1])
if n == 0:
continue
a_seg = a[:, :n]
b_seg = b[:, :n]
score = float(np.mean(np.sum(a_seg * b_seg, axis=0)))
if score > max_score:
max_score = score
return max_score
def _stage1_candidates(
conn: sqlite3.Connection,
target_id: str,
config: DuplicateScanConfig,
) -> set[str]:
row = conn.execute(
"SELECT duration_ms, rms_db, centroid_mean, rolloff_mean FROM audio_header WHERE track_id=?",
(target_id,),
).fetchone()
if not row:
return set()
duration_ms, rms_db, centroid_mean, rolloff_mean = row
if duration_ms is None:
return set()
duration_ms = int(duration_ms)
tolerance = max(
config.duration_tolerance_ms,
int(round(config.duration_tolerance_ratio * duration_ms)),
)
conditions = ["track_id != ?", "ABS(duration_ms - ?) <= ?"]
params: list[object] = [target_id, duration_ms, tolerance]
if config.centroid_tolerance is not None and centroid_mean is not None:
conditions.append("ABS(centroid_mean - ?) <= ?")
params.extend([centroid_mean, config.centroid_tolerance])
if config.rms_tolerance_db is not None and rms_db is not None:
conditions.append("ABS(rms_db - ?) <= ?")
params.extend([rms_db, config.rms_tolerance_db])
if config.rolloff_tolerance is not None and rolloff_mean is not None:
conditions.append("ABS(rolloff_mean - ?) <= ?")
params.extend([rolloff_mean, config.rolloff_tolerance])
query = f"SELECT track_id FROM audio_header WHERE {' AND '.join(conditions)}"
rows = conn.execute(query, params).fetchall()
return {r[0] for r in rows}
def _stage2_candidates(
    conn: sqlite3.Connection,
    target_id: str,
    stage1: set[str],
    config: DuplicateScanConfig,
) -> set[str]:
    """Find tracks sharing enough LSH fingerprint bands with the target.

    Counts band-hash collisions in fp_lsh, keeps tracks with at least
    config.min_band_collisions hits, and intersects with the stage-1 set when
    it is non-empty. Returns an empty set when the target has no stored
    fingerprint or it produces no band hashes.
    """
    row = conn.execute(
        "SELECT fp_blob FROM audio_fingerprint WHERE track_id=?",
        (target_id,),
    ).fetchone()
    if not row or row[0] is None:
        return set()
    band_hashes = _fingerprint_band_hashes(row[0].decode("utf-8"), config.fp_bands)
    if not band_hashes:
        return set()
    collisions: dict[str, int] = {}
    for band_index, band_hash in band_hashes:
        matches = conn.execute(
            "SELECT track_id FROM fp_lsh WHERE band_hash=? AND band_index=?",
            (band_hash, band_index),
        )
        for (other_id,) in matches:
            if other_id != target_id:
                collisions[other_id] = collisions.get(other_id, 0) + 1
    survivors = {
        other_id
        for other_id, hits in collisions.items()
        if hits >= config.min_band_collisions
    }
    return survivors & stage1 if stage1 else survivors
def _fetch_fingerprint(conn: sqlite3.Connection, track_id: str) -> str | None:
row = conn.execute(
"SELECT fp_blob FROM audio_fingerprint WHERE track_id=?",
(track_id,),
).fetchone()
if not row or row[0] is None:
return None
return row[0].decode("utf-8")
def _cleanup_missing(conn: sqlite3.Connection, track_ids: set[str]) -> None:
existing = {row[0] for row in conn.execute("SELECT track_id FROM audio_header")}
missing = existing - track_ids
if not missing:
return
for track_id in missing:
conn.execute("DELETE FROM audio_header WHERE track_id=?", (track_id,))
conn.execute("DELETE FROM audio_fingerprint WHERE track_id=?", (track_id,))
conn.execute("DELETE FROM fp_lsh WHERE track_id=?", (track_id,))
def _update_headers(
    conn: sqlite3.Connection,
    paths: Iterable[str],
    config: DuplicateScanConfig,
    log_callback: LogCallback | None,
) -> int:
    """Recompute audio_header rows for new or modified files.

    A file is skipped when its cached (mtime, size) pair matches os.stat;
    files that vanish or fail to decode are skipped (the latter with a log
    message) so one bad file cannot abort the scan. Returns the number of
    rows inserted or replaced.
    """
    updated = 0
    for path in paths:
        try:
            stat = os.stat(path)
        except OSError:
            # File disappeared or is unreadable; leave any cached row as-is.
            continue
        row = conn.execute(
            "SELECT mtime, size FROM audio_header WHERE track_id=?",
            (path,),
        ).fetchone()
        # Unchanged (mtime, size) -> cached features are still valid.
        if row and row[0] == stat.st_mtime and row[1] == stat.st_size:
            continue
        try:
            duration_ms, rms_db, centroid_mean, rolloff_mean = _compute_audio_header(
                path, config.sample_rate, config.max_analysis_sec
            )
        except Exception as exc:
            # Best-effort: log and move on to the next file.
            _log(log_callback, f"⚠ Skipped header for {path}: {exc}")
            continue
        conn.execute(
            """
            INSERT OR REPLACE INTO audio_header (
                track_id, mtime, size, duration_ms, rms_db, centroid_mean, rolloff_mean, updated_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                path,
                stat.st_mtime,
                stat.st_size,
                duration_ms,
                rms_db,
                centroid_mean,
                rolloff_mean,
                time.time(),
            ),
        )
        updated += 1
    return updated
def _update_fingerprints(
    conn: sqlite3.Connection,
    paths: Iterable[str],
    config: DuplicateScanConfig,
    log_callback: LogCallback | None,
) -> int:
    """Recompute chromaprint fingerprints and LSH band rows for changed files.

    Skips files whose cached (mtime, size) matches os.stat. On success the
    fingerprint row is replaced and the file's fp_lsh rows are rebuilt.
    Fingerprinting failures are logged and skipped. Returns the number of
    fingerprints written.
    """
    updated = 0
    for path in paths:
        try:
            stat = os.stat(path)
        except OSError:
            # File disappeared or is unreadable; leave any cached row as-is.
            continue
        row = conn.execute(
            "SELECT mtime, size FROM audio_fingerprint WHERE track_id=?",
            (path,),
        ).fetchone()
        # Unchanged (mtime, size) -> cached fingerprint is still valid.
        if row and row[0] == stat.st_mtime and row[1] == stat.st_size:
            continue
        try:
            fp = chromaprint_utils.fingerprint_fpcalc(
                path,
                trim=True,
                start_sec=0.0,
                duration_sec=config.max_analysis_sec,
            )
        except chromaprint_utils.FingerprintError as exc:
            _log(log_callback, f"⚠ Fingerprint failed for {path}: {exc}")
            continue
        if not fp:
            _log(log_callback, f"⚠ No fingerprint for {path}")
            continue
        fp_blob = _fingerprint_to_bytes(fp)
        # NOTE(review): assumes the fingerprint text is whitespace-separated
        # tokens; if fpcalc returns one compressed token this is always 1 —
        # verify against chromaprint_utils' output format.
        fp_len = len(fp.split())
        conn.execute(
            """
            INSERT OR REPLACE INTO audio_fingerprint (
                track_id, mtime, size, fp_blob, fp_len, fp_version
            ) VALUES (?, ?, ?, ?, ?, ?)
            """,
            (path, stat.st_mtime, stat.st_size, fp_blob, fp_len, "chromaprint-v1"),
        )
        # Rebuild this file's LSH rows from scratch to match the new fingerprint.
        conn.execute("DELETE FROM fp_lsh WHERE track_id=?", (path,))
        for band_index, band_hash in _fingerprint_band_hashes(fp, config.fp_bands):
            conn.execute(
                "INSERT INTO fp_lsh (band_hash, band_index, track_id) VALUES (?, ?, ?)",
                (band_hash, band_index, path),
            )
        updated += 1
    return updated
def _write_dup_groups(conn: sqlite3.Connection) -> int:
rows = conn.execute(
"SELECT track_id_a, track_id_b FROM dup_edges ORDER BY track_id_a"
).fetchall()
parent: dict[str, str] = {}
def find(x: str) -> str:
parent.setdefault(x, x)
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
def union(a: str, b: str) -> None:
ra = find(a)
rb = find(b)
if ra != rb:
parent[rb] = ra
for a, b in rows:
union(a, b)
groups: dict[str, list[str]] = {}
for track_id in parent:
root = find(track_id)
groups.setdefault(root, []).append(track_id)
conn.execute("DELETE FROM dup_groups")
group_id = 0
for members in groups.values():
if len(members) < 2:
continue
group_id += 1
gid = f"group-{group_id:04d}"
for track_id in members:
conn.execute(
"INSERT INTO dup_groups (group_id, track_id) VALUES (?, ?)",
(gid, track_id),
)
return group_id
def run_duplicate_scan(
    library_path: str,
    db_path: str,
    config: DuplicateScanConfig | None = None,
    log_callback: LogCallback | None = None,
) -> DuplicateScanSummary:
    """Run the full three-stage duplicate scan over *library_path*.

    Pipeline: list audio files, purge cache rows for deleted files, refresh
    audio headers (stage 1 features) and fingerprints + LSH rows (stage 2),
    then for each track gather stage-1/stage-2 candidates, gate them on
    fingerprint distance, and verify survivors with chroma alignment
    (stage 3). dup_edges is rebuilt from scratch each run, and dup_groups is
    rewritten from the resulting edges.

    Progress is reported via *log_callback* (or the module logger).
    Returns a DuplicateScanSummary with per-stage counters.

    NOTE(review): sqlite3's connection context manager commits/rolls back but
    does not close the connection — confirm that is acceptable here.
    """
    config = config or DuplicateScanConfig()
    paths = _list_audio_files(library_path)
    _log(log_callback, f"Found {len(paths)} audio files.")
    with _connect(db_path) as conn:
        ensure_schema(conn)
        _cleanup_missing(conn, set(paths))
        _log(log_callback, "Stage 1: updating audio headers...")
        headers_updated = _update_headers(conn, paths, config, log_callback)
        _log(log_callback, f"Header updates: {headers_updated}")
        _log(log_callback, "Stage 2: updating fingerprints + LSH...")
        fingerprints_updated = _update_fingerprints(conn, paths, config, log_callback)
        _log(log_callback, f"Fingerprints updated: {fingerprints_updated}")
        _log(log_callback, "Stage 3: scanning candidates...")
        # Edges are fully recomputed every run; only the feature caches persist.
        conn.execute("DELETE FROM dup_edges")
        # Cache decoded chroma sequences: a track may appear in many pairs.
        chroma_cache: dict[str, np.ndarray] = {}
        edges_written = 0
        for idx, target_id in enumerate(paths):
            stage1 = _stage1_candidates(conn, target_id, config)
            if not stage1:
                continue
            stage2 = _stage2_candidates(conn, target_id, stage1, config)
            if not stage2:
                continue
            fp_target = _fetch_fingerprint(conn, target_id)
            if not fp_target:
                continue
            for cand_id in sorted(stage2):
                # Process each unordered pair once (lexicographic ordering).
                if cand_id <= target_id:
                    continue
                fp_cand = _fetch_fingerprint(conn, cand_id)
                if not fp_cand:
                    continue
                dist = fingerprint_distance(fp_target, fp_cand)
                if dist > config.fp_distance_threshold:
                    continue
                try:
                    seq_a = chroma_cache.get(target_id)
                    if seq_a is None:
                        seq_a = _chroma_sequence(
                            target_id, config.sample_rate, config.max_analysis_sec
                        )
                        chroma_cache[target_id] = seq_a
                    seq_b = chroma_cache.get(cand_id)
                    if seq_b is None:
                        seq_b = _chroma_sequence(
                            cand_id, config.sample_rate, config.max_analysis_sec
                        )
                        chroma_cache[cand_id] = seq_b
                    score = _alignment_score(
                        seq_a, seq_b, config.chroma_max_offset_frames
                    )
                except Exception as exc:
                    # Best-effort: a failed decode/compare skips the pair only.
                    _log(log_callback, f"⚠ Stage 3 failed for {cand_id}: {exc}")
                    continue
                if score >= config.chroma_match_threshold:
                    verdict = "match"
                elif score >= config.chroma_possible_threshold:
                    verdict = "possible"
                else:
                    continue
                conn.execute(
                    """
                    INSERT INTO dup_edges (track_id_a, track_id_b, score, method, verified_at)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (target_id, cand_id, score, verdict, time.time()),
                )
                edges_written += 1
            # Periodic progress report.
            if (idx + 1) % 50 == 0:
                _log(log_callback, f"Scanned {idx + 1}/{len(paths)} tracks...")
        groups_written = _write_dup_groups(conn)
        _log(log_callback, f"Duplicate groups: {groups_written}")
    return DuplicateScanSummary(
        tracks_total=len(paths),
        headers_updated=headers_updated,
        fingerprints_updated=fingerprints_updated,
        edges_written=edges_written,
        groups_written=groups_written,
    )