diff --git a/videocr/api.py b/videocr/api.py
index d7dc4a0..4cdf6bf 100644
--- a/videocr/api.py
+++ b/videocr/api.py
@@ -4,7 +4,7 @@
 
 def get_subtitles(
         video_path: str, lang='eng', time_start='0:00', time_end='',
-        conf_threshold=65, sim_threshold=90, use_fullframe=False) -> str:
+        conf_threshold=65, sim_threshold=90, use_fullframe=False, skip_frames=0) -> str:
     utils.download_lang_data(lang)
 
     v = Video(video_path)
diff --git a/videocr/video.py b/videocr/video.py
index 27ebdc3..8e1c08c 100644
--- a/videocr/video.py
+++ b/videocr/video.py
@@ -29,12 +29,14 @@ def __init__(self, path: str):
             self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
     def run_ocr(self, lang: str, time_start: str, time_end: str,
-                conf_threshold: int, use_fullframe: bool) -> None:
+                conf_threshold: int, use_fullframe: bool, skip_frames: int) -> None:
         self.lang = lang
         self.use_fullframe = use_fullframe
         ocr_start = utils.get_frame_index(time_start, self.fps) if time_start else 0
         ocr_end = utils.get_frame_index(time_end, self.fps) if time_end else self.num_frames
+        # True -> sample about two frames per second; 0/False -> keep every frame
+        skip_frames = max(1, int(self.fps / 2) if skip_frames is True else 0 if skip_frames is False else skip_frames)
 
         if ocr_end < ocr_start:
             raise ValueError('time_start is later than time_end')
@@ -43,8 +45,11 @@ def run_ocr(self, lang: str, time_start: str, time_end: str,
         # get frames from ocr_start to ocr_end
         with Capture(self.path) as v, multiprocessing.Pool() as pool:
             v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
             frames = (v.read()[1] for _ in range(num_ocr_frames))
+            # OCR every skip_frames-th frame; the rest are read and dropped
+            frames = (frame for i, frame in enumerate(frames)
+                      if i % skip_frames == 0)
 
             # perform ocr to frames in parallel
             it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
             self.pred_frames = [
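
A minimal usage sketch of the new parameter, assuming get_subtitles is imported from videocr/api.py as defined above; the video path, time range, and skip_frames value shown here are placeholders, not values taken from the change.

    from videocr.api import get_subtitles

    # skip_frames=2 OCRs every second frame in the requested range;
    # skip_frames=True samples about two frames per second based on the
    # video's FPS, per the coercion logic in run_ocr above.
    subs = get_subtitles(
        'example.mp4',        # placeholder video path
        lang='eng',
        time_start='0:00',
        time_end='1:00',      # placeholder time range
        conf_threshold=65,
        sim_threshold=90,
        skip_frames=2,
    )
    print(subs)               # subtitle text returned as a string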