-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathupdate.py
More file actions
119 lines (102 loc) · 3.96 KB
/
update.py
File metadata and controls
119 lines (102 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import pickle
import random
import subprocess
import sys
import warnings

import cv2
import numpy as np
import torch
from insightface.app import FaceAnalysis
# Silence library warnings so the progress output stays readable.
warnings.filterwarnings('ignore')
# Quiet TensorFlow's C++ logging. os.environ is inherited by the retraining
# subprocess, so this presumably targets train_arcface.py -- TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Path settings
known_people_dir = r'D:\DATA'  # per-person image folders live at <known_people_dir>/<pid>
train_emb_file = 'train_embeddings.pkl'
val_emb_file = 'val_embeddings.pkl'
label_to_id_file = 'label_to_id.pkl'

# Face detector / recognizer initialisation.
# torch seeing a GPU does not guarantee that ONNX Runtime's CUDA provider can be
# loaded, so CPU is listed as a fallback provider when CUDA is requested.
provider = 'CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'
providers = [provider] if provider == 'CPUExecutionProvider' else [provider, 'CPUExecutionProvider']
app = FaceAnalysis(name='buffalo_l', providers=providers)
app.prepare(ctx_id=0, det_size=(864, 576), det_thresh=0.5)
# Face embedding extraction
def get_embedding(img):
    """Return the embedding of the first face ``app.get`` reports in *img*.

    The face's ``normed_embedding`` is divided by its L2 norm once more before
    being returned. Returns None when no face is detected.
    """
    detected = app.get(img)
    if detected:
        vec = detected[0].normed_embedding
        return vec / np.linalg.norm(vec)
    return None
# Extract faces from every N-th frame of a video
def process_video(video_path, every_n=5):
    """Extract face embeddings from every ``every_n``-th frame of a video.

    Frames are passed to get_embedding() in OpenCV's native BGR order, exactly
    like process_image_folder() does with cv2.imread output; InsightFace does
    its own BGR->RGB swap, so converting to RGB first would corrupt embeddings.

    Args:
        video_path: Path to a video file readable by cv2.VideoCapture.
        every_n: Sampling stride; frames 0, every_n, 2*every_n, ... are used.

    Returns:
        List of embeddings for sampled frames in which a face was detected
        (empty when the video cannot be opened).

    Raises:
        ValueError: if every_n < 1.
    """
    # Validate before touching the file so a bad stride fails fast.
    if every_n < 1:
        raise ValueError(f"every_n must be >= 1, got {every_n}")
    embeddings = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"[오류] 영상을 열 수 없습니다: {video_path}")
            return embeddings
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"[INFO] 영상 처리 시작: {os.path.basename(video_path)} (총 {total}프레임 중 {every_n}프레임마다 추출)")
        idx = 0
        while True:
            if idx % every_n == 0:
                ret, frame = cap.read()
                if not ret:
                    break
                emb = get_embedding(frame)
                if emb is not None:
                    embeddings.append(emb)
            elif not cap.grab():
                # Skipped frames only need to be advanced past, not retrieved.
                break
            idx += 1
    finally:
        # Release the capture even if detection raises.
        cap.release()
    return embeddings
# Folder-based registration
def process_image_folder(image_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Extract one face embedding per readable image file in *image_dir*.

    Only regular files whose lower-cased name ends with one of *extensions*
    are used. They are visited in sorted name order so results are
    reproducible between runs.

    Args:
        image_dir: Directory to scan (not recursive).
        extensions: Filename suffixes (including the dot) to accept.

    Returns:
        List of embeddings for images in which a face was detected.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    embeddings = []
    for name in sorted(os.listdir(image_dir)):
        if not name.lower().endswith(suffixes):
            continue
        img_path = os.path.join(image_dir, name)
        if not os.path.isfile(img_path):
            continue  # e.g. a sub-directory named "x.jpg"
        # NOTE(review): cv2.imread is known to fail on non-ASCII paths on
        # Windows; if person IDs contain such characters, consider
        # cv2.imdecode(np.fromfile(...)) -- verify EXIF handling first.
        img = cv2.imread(img_path)
        if img is None:  # unreadable or not actually an image
            continue
        # NOTE(review): a fixed-size resize ignores the source aspect ratio and
        # may distort faces; kept so embeddings match the existing pipeline -- confirm.
        img = cv2.resize(img, (864, 576))
        emb = get_embedding(img)
        if emb is not None:
            embeddings.append(emb)
    return embeddings
def _load_pickle(path):
    """Load and return the object pickled in *path*."""
    # NOTE: pickle.load can execute arbitrary code; these must be trusted local files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _save_pickle_atomic(obj, path):
    """Pickle *obj* to *path* via a temp file so a crash never leaves it truncated."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'wb') as fh:
        pickle.dump(obj, fh)
    # Single rename over the target; atomic on POSIX within one filesystem.
    os.replace(tmp_path, path)


def _split_train_val(embs, val_ratio=0.2):
    """Shuffle *embs* in place and return ``(train, val)`` lists.

    With two or more samples, validation gets ``int(val_ratio * n)`` of them
    (at least 1) and training always keeps at least one.
    """
    random.shuffle(embs)
    if len(embs) < 2:
        # A lone sample is more useful for training than for validation.
        return list(embs), []
    n_val = min(max(1, int(val_ratio * len(embs))), len(embs) - 1)
    return embs[n_val:], embs[:n_val]


# Main entry point
def update_photos(pid=None, video_path=None, every_n_frames=5):
    """Register or extend a person's face embeddings, then run retraining.

    Embeddings come from a video (every ``every_n_frames``-th frame) when
    *video_path* is given, otherwise from ``known_people_dir/<pid>`` (created
    if missing). A new person gets a fresh label id and a random train/val
    split; for an existing person all new embeddings are appended to the
    training set. The three pickle files are rewritten, then
    ``train_arcface.py`` is run with the current interpreter.

    Args:
        pid: Person ID; prompted for interactively when None.
        video_path: Optional video file to use instead of the image folder.
        every_n_frames: Frame sampling stride for video input.

    Raises:
        subprocess.CalledProcessError: if the retraining script fails.
    """
    # Load current state
    train_emb = _load_pickle(train_emb_file)
    val_emb = _load_pickle(val_emb_file)
    label2id = _load_pickle(label_to_id_file)

    if pid is None:
        pid = input("Person ID (new or existing): ").strip()
    if not str(pid).strip():
        # An empty pid would make the folder branch scan the whole data root.
        print("[오류] Person ID가 비어 있습니다.")
        return

    # Video-based or folder-based extraction
    if video_path:
        new_embs = process_video(video_path, every_n_frames)
    else:
        new_dir = os.path.join(known_people_dir, pid)
        os.makedirs(new_dir, exist_ok=True)
        new_embs = process_image_folder(new_dir)
    if not new_embs:
        print("[오류] 유효한 얼굴 임베딩을 추출하지 못했습니다.")
        return

    # New vs. existing person
    if pid not in label2id:
        # max()+1 instead of len(): stays unique even if ids have gaps.
        # Assumes ids are ints, as the previous len()-based scheme produced.
        label2id[pid] = max(label2id.values(), default=-1) + 1
        train_emb[pid], val_emb[pid] = _split_train_val(new_embs)
    else:
        train_emb.setdefault(pid, []).extend(new_embs)

    # Save
    _save_pickle_atomic(train_emb, train_emb_file)
    _save_pickle_atomic(val_emb, val_emb_file)
    _save_pickle_atomic(label2id, label_to_id_file)
    print(f"[완료] {pid} 사용자 업데이트 완료. 총 등록 임베딩 수: {len(train_emb[pid])}")

    # Retrain with the same interpreter/venv that is running this script.
    subprocess.run([sys.executable, 'train_arcface.py'], check=True)
    print("✅ Update and retraining complete")
if __name__ == '__main__':
    import argparse

    def _positive_int(text):
        """argparse type: parse *text* as an int >= 1."""
        value = int(text)
        if value < 1:
            raise argparse.ArgumentTypeError(f'must be >= 1, got {value}')
        return value

    parser = argparse.ArgumentParser(
        description='Register or update a person, then run train_arcface.py.')
    # --pid is optional: when omitted, update_photos() prompts for it.
    parser.add_argument('--pid', type=str, help='Person ID to register (prompted for if omitted)')
    parser.add_argument('--video', type=str, help='Optional: path to video file (instead of folder)')
    parser.add_argument('--every', type=_positive_int, default=5, help='Use every Nth frame from video')
    args = parser.parse_args()
    update_photos(pid=args.pid, video_path=args.video, every_n_frames=args.every)