"""Download PMC article archives and collect the images they contain.

The input is a JSON-lines file in which every record provides the keys
``pmc_tar_url``, ``image_file_path`` and ``pair_id``.  Each archive is
downloaded into ``--pmc_output_path`` and unpacked there, then every
referenced image is copied to ``<--images_output_path>/<pair_id>.jpg``.
"""
# Provenance: llava/data/download_images.py, patch 147f9b4
# ("added data/download_images.py").

import os
import json
import shutil
import tarfile
import argparse
from urllib.error import HTTPError, URLError
import urllib.request

try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are purely cosmetic; keep the script usable without tqdm.
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is not installed: return *iterable* as is."""
        return iterable


def _read_samples(input_path):
    """Return the list of JSON records in *input_path*, skipping blank lines."""
    samples = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                samples.append(json.loads(line))
    return samples


def _archive_path(pmc_dir, url):
    """Return the local path under *pmc_dir* where the archive of *url* is stored."""
    return os.path.join(pmc_dir, os.path.basename(url))


def _download_archives(urls, pmc_dir):
    """Download every URL in *urls* into *pmc_dir*, reporting failures."""
    for url in tqdm(urls):
        try:
            urllib.request.urlretrieve(url, _archive_path(pmc_dir, url))
        except (HTTPError, URLError):
            # URLError also covers non-HTTP failures (DNS errors, refused
            # connections, truncated downloads); one bad URL must not abort
            # the whole run.
            print('Error downloading PMC article: {}'.format(url))


def _extract_archives(urls, pmc_dir):
    """Unpack the downloaded archive of every URL in *urls* into *pmc_dir*."""
    for url in tqdm(urls):
        fname = _archive_path(pmc_dir, url)
        if not os.path.isfile(fname):
            # The download failed earlier; nothing to extract.
            print('Skipping missing PMC archive: {}'.format(fname))
            continue
        try:
            with tarfile.open(fname, "r:gz") as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Archives come from the network: refuse members that
                    # would escape the destination directory.
                    tar.extractall(pmc_dir, filter='data')
                else:
                    tar.extractall(pmc_dir)
        except (tarfile.TarError, EOFError) as e:
            print('Error extracting PMC archive: {} ({})'.format(fname, e))


def _copy_images(samples, pmc_dir, images_dir):
    """Copy each sample's image from *pmc_dir* to ``<images_dir>/<pair_id>.jpg``."""
    for sample in tqdm(samples):
        src = os.path.join(pmc_dir, sample['image_file_path'])
        dst = os.path.join(images_dir, sample['pair_id'] + '.jpg')
        if not os.path.isfile(src):
            # Either the archive was unavailable or it lacks this file.
            print('Image not found, skipping: {}'.format(src))
            continue
        shutil.copyfile(src, dst)


def main(args):
    """Download, unpack and collect the images listed in ``args.input_path``.

    Args:
        args: Namespace with the attributes ``input_path`` (JSON-lines file),
            ``pmc_output_path`` (directory for archives and their extracted
            contents) and ``images_output_path`` (directory receiving the
            ``<pair_id>.jpg`` copies).

    Archives that cannot be downloaded or extracted, and images that are
    missing afterwards, are reported on stdout and skipped.
    """
    input_data = _read_samples(args.input_path)

    # urlretrieve and copyfile do not create their destination directories.
    os.makedirs(args.pmc_output_path, exist_ok=True)
    os.makedirs(args.images_output_path, exist_ok=True)

    # Several samples can point at the same article; fetch each archive once
    # while keeping first-seen order.
    urls = list(dict.fromkeys(sample['pmc_tar_url'] for sample in input_data))

    # Download all PMC articles
    print('Downloading PMC articles')
    _download_archives(urls, args.pmc_output_path)

    # Untar all PMC articles
    print('Untarring PMC articles')
    _extract_archives(urls, args.pmc_output_path)

    # Copy to images directory
    print('Copying images')
    _copy_images(input_data, args.pmc_output_path, args.images_output_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, default='data/llava_med_image_urls.jsonl')
    parser.add_argument('--pmc_output_path', type=str, default='data/pmc_articles/')
    parser.add_argument('--images_output_path', type=str, default='data/images/')
    args = parser.parse_args()
    main(args)