From 4fc705d110e1118fc6864a1424b8aa84db12e6b5 Mon Sep 17 00:00:00 2001 From: Leo Chu Date: Fri, 13 Mar 2026 18:47:34 +0800 Subject: [PATCH 1/3] Add Apple Silicon MPS GPU support and fix PyTorch 2.6+ compatibility - Add MPS (Metal Performance Shaders) backend detection so Apple Silicon Macs use GPU acceleration instead of falling back to CPU - Add weights_only=False to all torch.load() calls for PyTorch 2.6+ which changed the default to weights_only=True - Add torch.no_grad() to inference paths (run_segment, run_pix2pix) to avoid unnecessary gradient computation - Fix Queue.qsize() NotImplementedError on macOS - Add prefetch thread for frame I/O in video fusion to overlap disk reads with model inference Made-with: Cursor --- cores/clean.py | 47 +++++++++++++++++++++++++++++++++++++------- cores/options.py | 7 +++++-- models/loadmodel.py | 10 +++++----- models/model_util.py | 23 +++++++++++++++------- models/runmodel.py | 6 ++++-- train/add/train.py | 2 +- util/data.py | 17 ++++++++++++---- 7 files changed, 84 insertions(+), 28 deletions(-) diff --git a/cores/clean.py b/cores/clean.py index 285542b..9982f29 100644 --- a/cores/clean.py +++ b/cores/clean.py @@ -33,7 +33,7 @@ def get_mosaic_positions(opt,netM,imagepaths,savemask=True): cv2.namedWindow('mosaic mask', cv2.WINDOW_NORMAL) print('Step:2/4 -- Find mosaic location') - img_read_pool = Queue(4) + img_read_pool = Queue(8) def loader(imagepaths): for imagepath in imagepaths: img_origin = impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepath)) @@ -194,21 +194,54 @@ def write_result(): t.setDaemon(True) t.start() + # Prefetch all needed frame indices into a cache via background thread + img_cache = {} + prefetch_queue = Queue(16) + def prefetch_loader(): + needed = set() + for i in range(length): + if i == 0: + for j in range(POOL_NUM): + needed.add(np.clip(i+j-LEFT_FRAME,0,length-1)) + else: + needed.add(np.clip(i+LEFT_FRAME,0,length-1)) + for idx in sorted(needed): + img = impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepaths[idx])) + prefetch_queue.put((idx, img)) + prefetch_t = Thread(target=prefetch_loader) + prefetch_t.setDaemon(True) + prefetch_t.start() + + def ensure_cached(idx): + while idx not in img_cache: + k, v = prefetch_queue.get() + img_cache[k] = v + for i,imagepath in enumerate(imagepaths,0): x,y,size = positions[i][0],positions[i][1],positions[i][2] input_stream = [] - # image read stream - if i==0 :# init + if i==0: for j in range(POOL_NUM): - img_pool.append(impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepaths[np.clip(i+j-LEFT_FRAME,0,len(imagepaths)-1)]))) - else: # load next frame + frame_idx = np.clip(i+j-LEFT_FRAME,0,length-1) + ensure_cached(frame_idx) + img_pool.append(img_cache[frame_idx]) + else: img_pool.pop(0) - img_pool.append(impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepaths[np.clip(i+LEFT_FRAME,0,len(imagepaths)-1)]))) + frame_idx = np.clip(i+LEFT_FRAME,0,length-1) + ensure_cached(frame_idx) + img_pool.append(img_cache[frame_idx]) + # Free frames we no longer need + old_idx = np.clip(i-LEFT_FRAME-1,0,length-1) + img_cache.pop(old_idx, None) img_origin = img_pool[LEFT_FRAME] # preview result and print if not opt.no_preview: - if show_pool.qsize()>3: + try: + pool_full = show_pool.qsize() > 3 + except NotImplementedError: + pool_full = not show_pool.empty() + if pool_full: cv2.imshow('clean',show_pool.get()) cv2.waitKey(1) & 0xFF diff --git a/cores/options.py b/cores/options.py index 9574a9b..c625561 100644 --- a/cores/options.py +++ b/cores/options.py @@ -61,9 +61,12 @@ def getparse(self, test_flag = False): self.opt.temp_dir = os.path.join(self.opt.temp_dir, 'DeepMosaics_temp') if self.opt.gpu_id != '-1': - os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.gpu_id) import torch - if not torch.cuda.is_available(): + if torch.cuda.is_available(): + os.environ["CUDA_VISIBLE_DEVICES"] = str(self.opt.gpu_id) + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + self.opt.gpu_id = 'mps' + else: self.opt.gpu_id = '-1' # else: # self.opt.gpu_id = '-1' diff --git a/models/loadmodel.py b/models/loadmodel.py index 7d9b391..44191bc 100755 --- a/models/loadmodel.py +++ b/models/loadmodel.py @@ -19,7 +19,7 @@ def pix2pix(opt): else: netG = pix2pix_G(3, 3, 64, opt.netG, norm='batch',use_dropout=True, init_type='normal', gpu_ids=[]) show_paramsnumber(netG,'netG') - netG.load_state_dict(torch.load(opt.model_path)) + netG.load_state_dict(torch.load(opt.model_path, weights_only=False)) netG = model_util.todevice(netG,opt.gpu_id) netG.eval() return netG @@ -37,7 +37,7 @@ def style(opt): netG = netG.module # if you are using PyTorch newer than 0.4 (e.g., built from # GitHub source), you can remove str() on self.device - state_dict = torch.load(opt.model_path, map_location='cpu') + state_dict = torch.load(opt.model_path, map_location='cpu', weights_only=False) if hasattr(state_dict, '_metadata'): del state_dict._metadata @@ -53,7 +53,7 @@ def style(opt): def video(opt): netG = video_G(N=2,n_blocks=4,gpu_id=opt.gpu_id) show_paramsnumber(netG,'netG') - netG.load_state_dict(torch.load(opt.model_path)) + netG.load_state_dict(torch.load(opt.model_path, weights_only=False)) netG = model_util.todevice(netG,opt.gpu_id) netG.eval() return netG @@ -65,9 +65,9 @@ def bisenet(opt,type='roi'): net = BiSeNet(num_classes=1, context_path='resnet18',train_flag=False) show_paramsnumber(net,'segment') if type == 'roi': - net.load_state_dict(torch.load(opt.model_path)) + net.load_state_dict(torch.load(opt.model_path, weights_only=False)) elif type == 'mosaic': - net.load_state_dict(torch.load(opt.mosaic_position_model_path)) + net.load_state_dict(torch.load(opt.mosaic_position_model_path, weights_only=False)) net = model_util.todevice(net,opt.gpu_id) net.eval() return net diff --git a/models/model_util.py b/models/model_util.py index 2aa7f9e..d54a5df 100644 --- a/models/model_util.py +++ b/models/model_util.py @@ -11,20 +11,29 @@ import torch.utils.model_zoo as model_zoo ################################## IO ################################## +def get_device(gpu_id): + """Return a torch.device based on gpu_id string.""" + if gpu_id == 'mps': + return torch.device('mps') + elif gpu_id != '-1': + return torch.device('cuda') + return torch.device('cpu') + def save(net,path,gpu_id): if isinstance(net, nn.DataParallel): torch.save(net.module.cpu().state_dict(),path) else: - torch.save(net.cpu().state_dict(),path) - if gpu_id != '-1': - net.cuda() + torch.save(net.cpu().state_dict(),path) + device = get_device(gpu_id) + if device.type != 'cpu': + net.to(device) def todevice(net,gpu_id): - if gpu_id != '-1' and len(gpu_id) == 1: - net.cuda() - elif gpu_id != '-1' and len(gpu_id) > 1: + device = get_device(gpu_id) + if device.type == 'cuda' and len(gpu_id) > 1 and gpu_id != 'mps': net = nn.DataParallel(net) - net.cuda() + if device.type != 'cpu': + net.to(device) return net # patch InstanceNorm checkpoints prior to 0.4 diff --git a/models/runmodel.py b/models/runmodel.py index 3e97fee..7600218 100755 --- a/models/runmodel.py +++ b/models/runmodel.py @@ -10,7 +10,8 @@ def run_segment(img,net,size = 360,gpu_id = '-1'): img = impro.resize(img,size) img = data.im2tensor(img,gpu_id = gpu_id, bgr2rgb = False, is0_1 = True) - mask = net(img) + with torch.no_grad(): + mask = net(img) mask = data.tensor2im(mask, gray=True, is0_1 = True) return mask @@ -20,7 +21,8 @@ def run_pix2pix(img,net,opt): else: img = impro.resize(img,128) img = data.im2tensor(img,gpu_id=opt.gpu_id) - img_fake = net(img) + with torch.no_grad(): + img_fake = net(img) img_fake = data.tensor2im(img_fake) return img_fake diff --git a/train/add/train.py b/train/add/train.py index a939ee3..9cef12c 100644 --- a/train/add/train.py +++ b/train/add/train.py @@ -107,7 +107,7 @@ def loadimage(imagepaths,maskpaths,opt,test_flag = False): opt.continue_train = False print('can not load last.pth, training on init weight.') if opt.continue_train: - net.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last.pth'))) + net.load_state_dict(torch.load(os.path.join(dir_checkpoint,'last.pth'), weights_only=False)) f = open(os.path.join(dir_checkpoint,'epoch_log.txt'),'r') opt.startepoch = int(f.read()) f.close() diff --git a/util/data.py b/util/data.py index 8a7865e..f322db3 100755 --- a/util/data.py +++ b/util/data.py @@ -8,10 +8,18 @@ from . import image_processing as impro from . import degradater +def _get_device(gpu_id): + if gpu_id == 'mps': + return torch.device('mps') + elif gpu_id != '-1': + return torch.device('cuda') + return torch.device('cpu') + def to_tensor(data,gpu_id): data = torch.from_numpy(data) - if gpu_id != '-1': - data = data.cuda() + device = _get_device(gpu_id) + if device.type != 'cpu': + data = data.to(device) return data def normalize(data): @@ -65,8 +73,9 @@ def im2tensor(image_numpy, gray=False,bgr2rgb = True, reshape = True, gpu_id = ' image_tensor = torch.from_numpy(image_numpy).float() if reshape: image_tensor = image_tensor.reshape(1,ch,h,w) - if gpu_id != '-1': - image_tensor = image_tensor.cuda() + device = _get_device(gpu_id) + if device.type != 'cpu': + image_tensor = image_tensor.to(device) return image_tensor def shuffledata(data,target): From 6b37c776ab9ef152f5162fcb84ef0dd1fd46205f Mon Sep 17 00:00:00 2001 From: Leo Chu Date: Sat, 14 Mar 2026 06:51:53 +0800 Subject: [PATCH 2/3] Add performance optimizations and fix process hang on exit Performance: - Batch segmentation: process 4 frames at once through BiSeNet - ThreadPoolExecutor for concurrent disk writes (masks + results) - Prefetch I/O thread for frame loading in video fusion - torch.compile() for CUDA model acceleration (skipped on MPS) - float16 autocast for CUDA inference paths - Contiguous tensor conversion to avoid stride mismatches - Hardware-accelerated h264_videotoolbox encoder on macOS Fixes: - Replace multiprocessing.Queue with threading queue.Queue to prevent process hang from undrained pipe buffers on exit - Guard all input() prompts with sys.stdin.isatty() so the process exits cleanly in non-interactive terminals - Auto-resume unfinished videos in non-interactive mode Made-with: Cursor --- cores/clean.py | 134 +++++++++++++++++++++++++++++--------------- cores/init.py | 6 +- cores/options.py | 17 ++++-- deepmosaic.py | 7 ++- models/loadmodel.py | 40 +++++++------ models/runmodel.py | 50 ++++++++++++++++- util/data.py | 3 +- util/ffmpeg.py | 20 ++++++- 8 files changed, 194 insertions(+), 83 deletions(-) diff --git a/cores/clean.py b/cores/clean.py index 9982f29..ba7c2f4 100644 --- a/cores/clean.py +++ b/cores/clean.py @@ -3,13 +3,17 @@ import numpy as np import cv2 import torch +from concurrent.futures import ThreadPoolExecutor from models import runmodel from util import data,util,ffmpeg,filt from util import image_processing as impro from .init import video_init -from multiprocessing import Queue, Process +from queue import Queue from threading import Thread +BATCH_SIZE = 4 +WRITE_WORKERS = 4 + ''' ---------------------Clean Mosaic--------------------- ''' @@ -33,7 +37,7 @@ def get_mosaic_positions(opt,netM,imagepaths,savemask=True): cv2.namedWindow('mosaic mask', cv2.WINDOW_NORMAL) print('Step:2/4 -- Find mosaic location') - img_read_pool = Queue(8) + img_read_pool = Queue(BATCH_SIZE * 3) def loader(imagepaths): for imagepath in imagepaths: img_origin = impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepath)) @@ -42,14 +46,38 @@ def loader(imagepaths): t.setDaemon(True) t.start() - for i,imagepath in enumerate(imagepaths,1): - img_origin = img_read_pool.get() - x,y,size,mask = runmodel.get_mosaic_position(img_origin,netM,opt) - positions.append([x,y,size]) - if savemask: - t = Thread(target=cv2.imwrite,args=(os.path.join(opt.temp_dir+'/mosaic_mask',imagepath), mask,)) - t.start() - if i%1000==0: + write_executor = ThreadPoolExecutor(max_workers=WRITE_WORKERS) if savemask else None + total = len(imagepaths) + + for batch_start in range(0, total, BATCH_SIZE): + batch_end = min(batch_start + BATCH_SIZE, total) + batch_imgs = [img_read_pool.get() for _ in range(batch_end - batch_start)] + batch_masks = runmodel.run_segment_batch(batch_imgs, netM, size=360, gpu_id=opt.gpu_id) + + for j, img_origin in enumerate(batch_imgs): + idx = batch_start + j + h, w = img_origin.shape[:2] + mask = batch_masks[j] + mask_proc = impro.mask_threshold(mask, ex_mun=int(min(h,w)/20), threshold=opt.mask_threshold) + if not opt.all_mosaic_area: + mask_proc = impro.find_mostlikely_ROI(mask_proc) + x,y,size,area = impro.boundingSquare(mask_proc, Ex_mul=opt.ex_mult) + rat = min(h,w)/360.0 + x,y,size = int(rat*x),int(rat*y),int(rat*size) + x,y = np.clip(x, 0, w),np.clip(y, 0, h) + size = np.clip(size, 0, min(w-x,h-y)) + positions.append([x,y,size]) + + if savemask: + path_out = os.path.join(opt.temp_dir+'/mosaic_mask', imagepaths[idx]) + write_executor.submit(cv2.imwrite, path_out, mask) + + if not opt.no_preview: + cv2.imshow('mosaic mask', mask) + cv2.waitKey(1) & 0xFF + + i = batch_end + if i % 1000 < BATCH_SIZE and i >= BATCH_SIZE: save_positions = np.array(positions) if continue_flag: save_positions = np.concatenate((pre_positions,save_positions),axis=0) @@ -57,12 +85,11 @@ def loader(imagepaths): step = {'step':2,'frame':i+resume_frame} util.savejson(os.path.join(opt.temp_dir,'step.json'),step) - #preview result and print - if not opt.no_preview: - cv2.imshow('mosaic mask',mask) - cv2.waitKey(1) & 0xFF t2 = time.time() - print('\r',str(i)+'/'+str(len(imagepaths)),util.get_bar(100*i/len(imagepaths),num=35),util.counttime(t1,t2,i,len(imagepaths)),end='') + print('\r',str(i)+'/'+str(total),util.get_bar(100*i/total,num=35),util.counttime(t1,t2,i,total),end='') + + if write_executor: + write_executor.shutdown(wait=True) if not opt.no_preview: cv2.destroyAllWindows() @@ -118,15 +145,24 @@ def cleanmosaic_video_byframe(opt,netG,netM): if not opt.no_preview: cv2.namedWindow('clean', cv2.WINDOW_NORMAL) - # clean mosaic print('Step:3/4 -- Clean Mosaic:') length = len(imagepaths) + write_exec = ThreadPoolExecutor(max_workers=WRITE_WORKERS) + + img_read_pool = Queue(8) + def frame_loader(paths): + for p in paths: + img_read_pool.put(impro.imread(os.path.join(opt.temp_dir+'/video2image', p))) + lt = Thread(target=frame_loader, args=(imagepaths,)) + lt.setDaemon(True) + lt.start() + for i,imagepath in enumerate(imagepaths,0): x,y,size = positions[i][0],positions[i][1],positions[i][2] - img_origin = impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepath)) + img_origin = img_read_pool.get() img_result = img_origin.copy() if size > 100: - try:#Avoid unknown errors + try: img_mosaic = img_origin[y-size:y+size,x-size:x+size] if opt.traditional: img_fake = runmodel.traditional_cleaner(img_mosaic,opt) @@ -136,17 +172,21 @@ def cleanmosaic_video_byframe(opt,netG,netM): img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather) except Exception as e: print('Warning:',e) - t = Thread(target=cv2.imwrite,args=(os.path.join(opt.temp_dir+'/replace_mosaic',imagepath), img_result,)) - t.start() - os.remove(os.path.join(opt.temp_dir+'/video2image',imagepath)) - - #preview result and print + + out_path = os.path.join(opt.temp_dir+'/replace_mosaic',imagepath) + rm_path = os.path.join(opt.temp_dir+'/video2image',imagepath) + def _write_and_rm(op, rp, img): + cv2.imwrite(op, img) + os.remove(rp) + write_exec.submit(_write_and_rm, out_path, rm_path, img_result) + if not opt.no_preview: cv2.imshow('clean',img_result) cv2.waitKey(1) & 0xFF t2 = time.time() print('\r',str(i+1)+'/'+str(length),util.get_bar(100*i/length,num=35),util.counttime(t1,t2,i+1,len(imagepaths)),end='') print() + write_exec.shutdown(wait=True) if not opt.no_preview: cv2.destroyAllWindows() print('Step:4/4 -- Convert images to video') @@ -176,23 +216,22 @@ def cleanmosaic_video_fusion(opt,netG,netM): # clean mosaic print('Step:3/4 -- Clean Mosaic:') length = len(imagepaths) - write_pool = Queue(4) + write_pool = Queue(8) show_pool = Queue(4) - def write_result(): - while True: - save_ori,imagepath,img_origin,img_fake,x,y,size = write_pool.get() - if save_ori: - img_result = img_origin - else: - mask = cv2.imread(os.path.join(opt.temp_dir+'/mosaic_mask',imagepath),0) - img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather) - if not opt.no_preview: - show_pool.put(img_result.copy()) - cv2.imwrite(os.path.join(opt.temp_dir+'/replace_mosaic',imagepath),img_result) - os.remove(os.path.join(opt.temp_dir+'/video2image',imagepath)) - t = Thread(target=write_result,args=()) - t.setDaemon(True) - t.start() + ac_dev = runmodel._get_autocast_device(opt.gpu_id) + + def write_single(save_ori, imagepath, img_origin, img_fake, x, y, size): + if save_ori: + img_result = img_origin + else: + mask = cv2.imread(os.path.join(opt.temp_dir+'/mosaic_mask',imagepath),0) + img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather) + if not opt.no_preview: + show_pool.put(img_result.copy()) + cv2.imwrite(os.path.join(opt.temp_dir+'/replace_mosaic',imagepath),img_result) + os.remove(os.path.join(opt.temp_dir+'/video2image',imagepath)) + + write_executor = ThreadPoolExecutor(max_workers=WRITE_WORKERS) # Prefetch all needed frame indices into a cache via background thread img_cache = {} @@ -246,7 +285,7 @@ def ensure_cached(idx): cv2.waitKey(1) & 0xFF if size>50: - try:#Avoid unknown errors + try: for pos in FRAME_POS: input_stream.append(impro.resize(img_pool[pos][y-size:y+size,x-size:x+size], INPUT_SIZE,interpolation=cv2.INTER_CUBIC)[:,:,::-1]) if init_flag: @@ -257,22 +296,27 @@ def ensure_cached(idx): input_stream = np.array(input_stream).reshape(1,T,INPUT_SIZE,INPUT_SIZE,3).transpose((0,4,1,2,3)) input_stream = data.to_tensor(data.normalize(input_stream),gpu_id=opt.gpu_id) with torch.no_grad(): - unmosaic_pred = netG(input_stream,previous_frame) + if ac_dev: + with torch.autocast(ac_dev, dtype=torch.float16): + unmosaic_pred = netG(input_stream,previous_frame) + else: + unmosaic_pred = netG(input_stream,previous_frame) img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = True) previous_frame = unmosaic_pred - write_pool.put([False,imagepath,img_origin.copy(),img_fake.copy(),x,y,size]) + write_executor.submit(write_single, False, imagepath, img_origin.copy(), img_fake.copy(), x, y, size) except Exception as e: init_flag = True print('Error:',e) else: - write_pool.put([True,imagepath,img_origin.copy(),-1,-1,-1,-1]) + write_executor.submit(write_single, True, imagepath, img_origin.copy(), -1, -1, -1, -1) init_flag = True t2 = time.time() print('\r',str(i+1)+'/'+str(length),util.get_bar(100*i/length,num=35),util.counttime(t1,t2,i+1,len(imagepaths)),end='') print() - write_pool.close() - show_pool.close() + write_executor.shutdown(wait=True) + while not show_pool.empty(): + show_pool.get_nowait() if not opt.no_preview: cv2.destroyAllWindows() print('Step:4/4 -- Convert images to video') diff --git a/cores/init.py b/cores/init.py index 5993c58..f9f95f7 100644 --- a/cores/init.py +++ b/cores/init.py @@ -1,4 +1,5 @@ import os +import sys from util import util,ffmpeg ''' @@ -13,7 +14,10 @@ def video_init(opt,path): if os.path.isfile(os.path.join(opt.temp_dir,'step.json')): step = util.loadjson(os.path.join(opt.temp_dir,'step.json')) if int(step['step'])>=1: - choose = input('There is an unfinished video. Continue it? [y/n] ') + if sys.stdin.isatty(): + choose = input('There is an unfinished video. Continue it? [y/n] ') + else: + choose = 'y' if choose.lower() =='yes' or choose.lower() == 'y': imagepaths = os.listdir(opt.temp_dir+'/video2image') imagepaths.sort() diff --git a/cores/options.py b/cores/options.py index c625561..5d5680b 100644 --- a/cores/options.py +++ b/cores/options.py @@ -74,11 +74,13 @@ def getparse(self, test_flag = False): if test_flag: if not os.path.exists(self.opt.media_path): print('Error: Media does not exist!') - input('Please press any key to exit.\n') + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) if not os.path.exists(self.opt.model_path): print('Error: Model does not exist!') - input('Please press any key to exit.\n') + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) if self.opt.mode == 'auto': @@ -90,7 +92,8 @@ def getparse(self, test_flag = False): self.opt.mode = 'style' else: print('Please check model_path!') - input('Please press any key to exit.\n') + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) if self.opt.output_size == 0 and self.opt.mode == 'style': @@ -110,7 +113,8 @@ def getparse(self, test_flag = False): self.opt.netG = 'video' else: print('Type of Generator error!') - input('Please press any key to exit.\n') + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) if self.opt.ex_mult == 'auto': @@ -126,8 +130,9 @@ def getparse(self, test_flag = False): if os.path.isfile(_path): self.opt.mosaic_position_model_path = _path else: - input('Please check mosaic_position_model_path!') - input('Please press any key to exit.\n') + print('Please check mosaic_position_model_path!') + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) return self.opt \ No newline at end of file diff --git a/deepmosaic.py b/deepmosaic.py index 5a59f19..8a0ed48 100644 --- a/deepmosaic.py +++ b/deepmosaic.py @@ -7,7 +7,8 @@ from models import loadmodel except Exception as e: print(e) - input('Please press any key to exit.\n') + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) opt = Options().getparse(test_flag = True) @@ -94,6 +95,6 @@ def main(): print(ex_val) for stack in traceback.extract_tb(ex_stack): print(stack) - input('Please press any key to exit.\n') - #util.clean_tempfiles(tmp_init = False) + if sys.stdin.isatty(): + input('Please press any key to exit.\n') sys.exit(0) \ No newline at end of file diff --git a/models/loadmodel.py b/models/loadmodel.py index 44191bc..8e1e1db 100755 --- a/models/loadmodel.py +++ b/models/loadmodel.py @@ -2,8 +2,6 @@ from . import model_util from .pix2pix_model import define_G as pix2pix_G from .pix2pixHD_model import define_G as pix2pixHD_G -# from .video_model import MosaicNet -# from .videoHD_model import MosaicNet as MosaicNet_HD from .BiSeNet_model import BiSeNet from .BVDNet import define_G as video_G @@ -12,17 +10,28 @@ def show_paramsnumber(net,netname='net'): parameters = round(parameters/1e6,2) print(netname+' parameters: '+str(parameters)+'M') +def _try_compile(net, gpu_id): + if gpu_id == 'mps' or gpu_id == '-1': + return net + try: + return torch.compile(net) + except Exception: + return net + +def _finalize(net, gpu_id): + net = model_util.todevice(net, gpu_id) + net.eval() + net = _try_compile(net, gpu_id) + return net + def pix2pix(opt): - # print(opt.model_path,opt.netG) if opt.netG == 'HD': netG = pix2pixHD_G(3, 3, 64, 'global' ,4) else: netG = pix2pix_G(3, 3, 64, opt.netG, norm='batch',use_dropout=True, init_type='normal', gpu_ids=[]) show_paramsnumber(netG,'netG') netG.load_state_dict(torch.load(opt.model_path, weights_only=False)) - netG = model_util.todevice(netG,opt.gpu_id) - netG.eval() - return netG + return _finalize(netG, opt.gpu_id) def style(opt): @@ -31,32 +40,23 @@ def style(opt): else: netG = pix2pix_G(3, 3, 64, 'resnet_9blocks', norm='instance',use_dropout=False, init_type='normal', gpu_ids=[]) - #in other to load old pretrain model - #https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/models/base_model.py if isinstance(netG, torch.nn.DataParallel): netG = netG.module - # if you are using PyTorch newer than 0.4 (e.g., built from - # GitHub source), you can remove str() on self.device state_dict = torch.load(opt.model_path, map_location='cpu', weights_only=False) if hasattr(state_dict, '_metadata'): del state_dict._metadata - # patch InstanceNorm checkpoints prior to 0.4 - for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + for key in list(state_dict.keys()): model_util.patch_instance_norm_state_dict(state_dict, netG, key.split('.')) netG.load_state_dict(state_dict) - netG = model_util.todevice(netG,opt.gpu_id) - netG.eval() - return netG + return _finalize(netG, opt.gpu_id) def video(opt): netG = video_G(N=2,n_blocks=4,gpu_id=opt.gpu_id) show_paramsnumber(netG,'netG') netG.load_state_dict(torch.load(opt.model_path, weights_only=False)) - netG = model_util.todevice(netG,opt.gpu_id) - netG.eval() - return netG + return _finalize(netG, opt.gpu_id) def bisenet(opt,type='roi'): ''' @@ -68,6 +68,4 @@ def bisenet(opt,type='roi'): net.load_state_dict(torch.load(opt.model_path, weights_only=False)) elif type == 'mosaic': net.load_state_dict(torch.load(opt.mosaic_position_model_path, weights_only=False)) - net = model_util.todevice(net,opt.gpu_id) - net.eval() - return net + return _finalize(net, opt.gpu_id) diff --git a/models/runmodel.py b/models/runmodel.py index 7600218..b3012cb 100755 --- a/models/runmodel.py +++ b/models/runmodel.py @@ -7,22 +7,60 @@ import torch import numpy as np +def _get_autocast_device(gpu_id): + """Only enable autocast on CUDA — MPS float16 has stride/op issues.""" + if gpu_id != '-1' and gpu_id != 'mps': + return 'cuda' + return None + def run_segment(img,net,size = 360,gpu_id = '-1'): img = impro.resize(img,size) img = data.im2tensor(img,gpu_id = gpu_id, bgr2rgb = False, is0_1 = True) + ac_dev = _get_autocast_device(gpu_id) with torch.no_grad(): - mask = net(img) + if ac_dev: + with torch.autocast(ac_dev, dtype=torch.float16): + mask = net(img) + else: + mask = net(img) mask = data.tensor2im(mask, gray=True, is0_1 = True) return mask +def run_segment_batch(imgs, net, size=360, gpu_id='-1'): + """Process multiple images through segmentation in one batch.""" + resized = [impro.resize(img, size) for img in imgs] + batch = np.stack([img.astype(np.float32) / 255.0 for img in resized]) + batch = batch.transpose((0, 3, 1, 2)) + batch_tensor = torch.from_numpy(batch).float() + device = data._get_device(gpu_id) + if device.type != 'cpu': + batch_tensor = batch_tensor.to(device) + ac_dev = _get_autocast_device(gpu_id) + with torch.no_grad(): + if ac_dev: + with torch.autocast(ac_dev, dtype=torch.float16): + masks_tensor = net(batch_tensor) + else: + masks_tensor = net(batch_tensor) + masks = [] + for i in range(masks_tensor.shape[0]): + mask = data.tensor2im(masks_tensor, gray=True, is0_1=True, batch_index=i) + masks.append(mask) + return masks + def run_pix2pix(img,net,opt): if opt.netG == 'HD': img = impro.resize(img,512) else: img = impro.resize(img,128) img = data.im2tensor(img,gpu_id=opt.gpu_id) + ac_dev = _get_autocast_device(opt.gpu_id) with torch.no_grad(): - img_fake = net(img) + if ac_dev: + with torch.autocast(ac_dev, dtype=torch.float16): + img_fake = net(img) + else: + img_fake = net(img) img_fake = data.tensor2im(img_fake) return img_fake @@ -58,7 +96,13 @@ def run_styletransfer(opt, net, img): img = data.im2tensor(img,gpu_id=opt.gpu_id,gray=True) else: img = data.im2tensor(img,gpu_id=opt.gpu_id) - img = net(img) + ac_dev = _get_autocast_device(opt.gpu_id) + with torch.no_grad(): + if ac_dev: + with torch.autocast(ac_dev, dtype=torch.float16): + img = net(img) + else: + img = net(img) img = data.tensor2im(img) return img diff --git a/util/data.py b/util/data.py index f322db3..6fe928c 100755 --- a/util/data.py +++ b/util/data.py @@ -16,6 +16,7 @@ def _get_device(gpu_id): return torch.device('cpu') def to_tensor(data,gpu_id): + data = np.ascontiguousarray(data) data = torch.from_numpy(data) device = _get_device(gpu_id) if device.type != 'cpu': @@ -69,7 +70,7 @@ def im2tensor(image_numpy, gray=False,bgr2rgb = True, reshape = True, gpu_id = ' image_numpy = image_numpy/255.0 else: image_numpy = (image_numpy/255.0-0.5)/0.5 - image_numpy = image_numpy.transpose((2, 0, 1)) + image_numpy = np.ascontiguousarray(image_numpy.transpose((2, 0, 1))) image_tensor = torch.from_numpy(image_numpy).float() if reshape: image_tensor = image_tensor.reshape(1,ch,h,w) diff --git a/util/ffmpeg.py b/util/ffmpeg.py index 6efd686..3f97d69 100755 --- a/util/ffmpeg.py +++ b/util/ffmpeg.py @@ -50,12 +50,26 @@ def video2voice(videopath, voicepath, start_time='00:00:00', last_time='00:00:00 args += [voicepath] run(args) +def _get_encoder(): + """Try hardware-accelerated encoder, fall back to libx264.""" + import platform, subprocess + if platform.system() == 'Darwin': + try: + r = subprocess.run(['ffmpeg','-hide_banner','-encoders'], capture_output=True, text=True) + if 'h264_videotoolbox' in r.stdout: + return 'h264_videotoolbox -b:v 10M' + except Exception: + pass + return 'libx264' + def image2video(fps,imagepath,voicepath,videopath): - os.system('ffmpeg -y -r '+str(fps)+' -i '+imagepath+' -vcodec libx264 '+os.path.split(voicepath)[0]+'/video_tmp.mp4') + encoder = _get_encoder() + tmp_dir = os.path.split(voicepath)[0] + os.system('ffmpeg -y -r '+str(fps)+' -i '+imagepath+' -vcodec '+encoder+' '+tmp_dir+'/video_tmp.mp4') if os.path.exists(voicepath): - os.system('ffmpeg -i '+os.path.split(voicepath)[0]+'/video_tmp.mp4'+' -i "'+voicepath+'" -vcodec copy -acodec aac '+videopath) + os.system('ffmpeg -i '+tmp_dir+'/video_tmp.mp4'+' -i "'+voicepath+'" -vcodec copy -acodec aac '+videopath) else: - os.system('ffmpeg -i '+os.path.split(voicepath)[0]+'/video_tmp.mp4 '+videopath) + os.system('ffmpeg -i '+tmp_dir+'/video_tmp.mp4 '+videopath) def get_video_infos(videopath): args = ['ffprobe -v quiet -print_format json -show_format -show_streams', '-i', '"'+videopath+'"'] From ba81e8eed470ae730d1bb3d590bf3d9cec1b900f Mon Sep 17 00:00:00 2001 From: Leo Chu Date: Sun, 15 Mar 2026 10:44:23 +0800 Subject: [PATCH 3/3] Add Real-ESRGAN enhancement for cleaned mosaic patches Integrate Real-ESRGAN as an optional post-processing step (--enhance flag) that sharpens the 256x256 cleaned mosaic patches before compositing them back into the frame. Also harden the prefetch thread with error handling and a sentinel to prevent silent hangs, and flush stdout for step progress. Made-with: Cursor --- cores/clean.py | 60 ++++++++++++++++++++++++++++++---------------- cores/options.py | 3 ++- models/enhancer.py | 51 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 21 deletions(-) create mode 100644 models/enhancer.py diff --git a/cores/clean.py b/cores/clean.py index ba7c2f4..00ba51c 100644 --- a/cores/clean.py +++ b/cores/clean.py @@ -5,6 +5,7 @@ import torch from concurrent.futures import ThreadPoolExecutor from models import runmodel +from models.enhancer import enhance_patch from util import data,util,ffmpeg,filt from util import image_processing as impro from .init import video_init @@ -118,6 +119,8 @@ def cleanmosaic_img(opt,netG,netM): img_fake = runmodel.traditional_cleaner(img_mosaic,opt) else: img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt) + if getattr(opt, 'enhance', False): + img_fake = enhance_patch(img_fake, opt.gpu_id) img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather) else: print('Do not find mosaic') @@ -132,6 +135,8 @@ def cleanmosaic_img_server(opt,img_origin,netG,netM): img_fake = runmodel.traditional_cleaner(img_mosaic,opt) else: img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt) + if getattr(opt, 'enhance', False): + img_fake = enhance_patch(img_fake, opt.gpu_id) img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather) return img_result @@ -168,6 +173,8 @@ def frame_loader(paths): img_fake = runmodel.traditional_cleaner(img_mosaic,opt) else: img_fake = runmodel.run_pix2pix(img_mosaic,netG,opt) + if getattr(opt, 'enhance', False): + img_fake = enhance_patch(img_fake, opt.gpu_id) mask = cv2.imread(os.path.join(opt.temp_dir+'/mosaic_mask',imagepath),0) img_result = impro.replace_mosaic(img_origin,img_fake,mask,x,y,size,opt.no_feather) except Exception as e: @@ -185,11 +192,11 @@ def _write_and_rm(op, rp, img): cv2.waitKey(1) & 0xFF t2 = time.time() print('\r',str(i+1)+'/'+str(length),util.get_bar(100*i/length,num=35),util.counttime(t1,t2,i+1,len(imagepaths)),end='') - print() - write_exec.shutdown(wait=True) + print(flush=True) + write_exec.shutdown(wait=True, cancel_futures=False) if not opt.no_preview: cv2.destroyAllWindows() - print('Step:4/4 -- Convert images to video') + print('Step:4/4 -- Convert images to video', flush=True) ffmpeg.image2video( fps, opt.temp_dir+'/replace_mosaic/output_%06d.'+opt.tempimage_type, opt.temp_dir+'/voice_tmp.mp3', @@ -233,27 +240,38 @@ def write_single(save_ori, imagepath, img_origin, img_fake, x, y, size): write_executor = ThreadPoolExecutor(max_workers=WRITE_WORKERS) - # Prefetch all needed frame indices into a cache via background thread img_cache = {} prefetch_queue = Queue(16) + _prefetch_error = [None] + def prefetch_loader(): - needed = set() - for i in range(length): - if i == 0: - for j in range(POOL_NUM): - needed.add(np.clip(i+j-LEFT_FRAME,0,length-1)) - else: - needed.add(np.clip(i+LEFT_FRAME,0,length-1)) - for idx in sorted(needed): - img = impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepaths[idx])) - prefetch_queue.put((idx, img)) - prefetch_t = Thread(target=prefetch_loader) - prefetch_t.setDaemon(True) + try: + needed = set() + for i in range(length): + if i == 0: + for j in range(POOL_NUM): + needed.add(np.clip(i+j-LEFT_FRAME,0,length-1)) + else: + needed.add(np.clip(i+LEFT_FRAME,0,length-1)) + for idx in sorted(needed): + img = impro.imread(os.path.join(opt.temp_dir+'/video2image',imagepaths[idx])) + prefetch_queue.put((idx, img)) + except Exception as e: + _prefetch_error[0] = e + finally: + prefetch_queue.put(None) + + prefetch_t = Thread(target=prefetch_loader, daemon=True) prefetch_t.start() def ensure_cached(idx): while idx not in img_cache: - k, v = prefetch_queue.get() + item = prefetch_queue.get(timeout=30) + if item is None: + if _prefetch_error[0]: + raise RuntimeError('Prefetch failed: %s' % _prefetch_error[0]) + return + k, v = item img_cache[k] = v for i,imagepath in enumerate(imagepaths,0): @@ -302,6 +320,8 @@ def ensure_cached(idx): else: unmosaic_pred = netG(input_stream,previous_frame) img_fake = data.tensor2im(unmosaic_pred,rgb2bgr = True) + if getattr(opt, 'enhance', False): + img_fake = enhance_patch(img_fake, opt.gpu_id) previous_frame = unmosaic_pred write_executor.submit(write_single, False, imagepath, img_origin.copy(), img_fake.copy(), x, y, size) except Exception as e: @@ -313,13 +333,13 @@ def ensure_cached(idx): t2 = time.time() print('\r',str(i+1)+'/'+str(length),util.get_bar(100*i/length,num=35),util.counttime(t1,t2,i+1,len(imagepaths)),end='') - print() - write_executor.shutdown(wait=True) + print(flush=True) + write_executor.shutdown(wait=True, cancel_futures=False) while not show_pool.empty(): show_pool.get_nowait() if not opt.no_preview: cv2.destroyAllWindows() - print('Step:4/4 -- Convert images to video') + print('Step:4/4 -- Convert images to video', flush=True) ffmpeg.image2video( fps, opt.temp_dir+'/replace_mosaic/output_%06d.'+opt.tempimage_type, opt.temp_dir+'/voice_tmp.mp3', diff --git a/cores/options.py b/cores/options.py index 5d5680b..21baac3 100644 --- a/cores/options.py +++ b/cores/options.py @@ -33,7 +33,8 @@ def initialize(self): self.parser.add_argument('--mosaic_size', type=int, default=0,help='mosaic size,if 0 auto size') self.parser.add_argument('--mask_extend', type=int, default=10,help='extend mosaic area') - #CleanMosaic + #CleanMosaic + self.parser.add_argument('--enhance', action='store_true', help='if specified, apply Real-ESRGAN to enhance cleaned mosaic patches') self.parser.add_argument('--mosaic_position_model_path', type=str, default='auto',help='name of model use to find mosaic position') self.parser.add_argument('--traditional', action='store_true', help='if specified, use traditional image processing methods to clean mosaic') self.parser.add_argument('--tr_blur', type=int, default=10, help='ksize of blur when using traditional method, it will affect final quality') diff --git a/models/enhancer.py b/models/enhancer.py new file mode 100644 index 0000000..07f711e --- /dev/null +++ b/models/enhancer.py @@ -0,0 +1,51 @@ +import cv2 +import numpy as np +import torch +from basicsr.archs.rrdbnet_arch import RRDBNet +from realesrgan import RealESRGANer + +_MODEL_URL = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth' +_instance = None + +def get_enhancer(gpu_id): + """Lazy-init a shared RealESRGANer instance (downloads weights on first call).""" + global _instance + if _instance is not None: + return _instance + + model = RRDBNet( + num_in_ch=3, num_out_ch=3, + num_feat=64, num_block=23, num_grow_ch=32, scale=4, + ) + + if gpu_id == 'mps': + device = torch.device('mps') + elif gpu_id != '-1': + device = torch.device('cuda') + else: + device = torch.device('cpu') + + _instance = RealESRGANer( + scale=4, + model_path=_MODEL_URL, + model=model, + tile=0, + pre_pad=10, + half=False, + device=device, + ) + return _instance + + +def enhance_patch(img_bgr, gpu_id, outscale=1): + """Enhance a small BGR image patch using Real-ESRGAN. + + outscale=1 means output same size as input (upscale 4x then downscale). + outscale=2 means output 2x the input size, etc. + """ + upsampler = get_enhancer(gpu_id) + try: + output, _ = upsampler.enhance(img_bgr, outscale=outscale) + except RuntimeError: + output = img_bgr + return output