diff --git a/fastclass/fc_download.py b/fastclass/fc_download.py index f62189a..3bc834a 100755 --- a/fastclass/fc_download.py +++ b/fastclass/fc_download.py @@ -66,27 +66,22 @@ def crawl( for c in crawlers: print(f" -> {c}") - if c == "GOOGLE": - google_crawler = GoogleImageCrawler( + if c == "BAIDU": + baidu_crawler = BaiduImageCrawler( downloader_cls=CustomDownloader, - parser_cls=GoogleParser, log_level=logging.CRITICAL, - feeder_threads=1, - parser_threads=1, - downloader_threads=4, storage={"root_dir": folder}, ) - - google_crawler.crawl( + baidu_crawler.crawl( keyword=search, offset=0, max_num=maxnum, min_size=(200, 200), max_size=None, - file_idx_offset=0, + file_idx_offset="auto", ) - if c == "BING": + elif c == "BING": bing_crawler = BingImageCrawler( downloader_cls=CustomDownloader, log_level=logging.CRITICAL, @@ -101,22 +96,7 @@ def crawl( file_idx_offset="auto", ) - if c == "BAIDU": - baidu_crawler = BaiduImageCrawler( - downloader_cls=CustomDownloader, - log_level=logging.CRITICAL, - storage={"root_dir": folder}, - ) - baidu_crawler.crawl( - keyword=search, - offset=0, - max_num=maxnum, - min_size=(200, 200), - max_size=None, - file_idx_offset="auto", - ) - - if c == "FLICKR": + elif c == "FLICKR": flick_api_key = os.environ.get("FLICKR_API_KEY") if not flick_api_key: print( @@ -139,15 +119,32 @@ def crawl( file_idx_offset="auto", ) + elif c == "GOOGLE": + google_crawler = GoogleImageCrawler( + downloader_cls=CustomDownloader, + parser_cls=GoogleParser, + log_level=logging.CRITICAL, + feeder_threads=1, + parser_threads=1, + downloader_threads=4, + storage={"root_dir": folder}, + ) + + google_crawler.crawl( + keyword=search, + offset=0, + max_num=maxnum, + min_size=(200, 200), + max_size=None, + file_idx_offset=0, + ) + return {k: v for k, v in CustomDownloader.registry.items() if k is not None} def main( infile: str, size: int, crawler: List[str], keep: bool, maxnum: int, outpath: str ): - SIZE = (size, size) - classes = [] - if "ALL" in crawler: crawler = ["GOOGLE", "BING"] @@ -156,7 +153,7 @@ def main( f'Directory "{outpath}" exists. Would you like to overwrite the directory? [y/n]' ) choice = input().lower() - while not (choice == "y" or "n"): + while choice != "y" and not "n": print("Please reply with 'y' or 'n'") choice = input().lower() if choice == "y": @@ -170,6 +167,8 @@ def main( print(f"INFO: final dataset will be located in {outpath}") with tempfile.TemporaryDirectory() as tmp: + classes = [] + for lcnt, line in enumerate(infile): if lcnt > 0: no_cols = line[:-1].count(",") + 1 @@ -180,6 +179,7 @@ def main( remove_terms = None classes.append((search_term, remove_terms)) + SIZE = (size, size) for i, (search_term, remove_terms) in enumerate(classes): print(f"[{i+1}/{len(classes)}] Searching: >> {search_term} <<") out_name = sanitize_searchstring(search_term, rstring=remove_terms)