Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions transparency-in-coverage/python/mrfutils/src/mrfutils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from itertools import chain
from pathlib import Path
from urllib.parse import urlparse
import zipfile
import io

import requests

from mrfutils.exceptions import InvalidMRF

log = logging.getLogger('mrfutils')
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


def prepend(value, iterator):
Expand All @@ -37,7 +39,7 @@ def peek(iterator):

class JSONOpen:
"""
Context manager for opening JSON(.gz) MRFs.
Context manager for opening JSON(.gz/.zip) MRFs.
Usage:
>>> with JSONOpen('localfile.json') as f:
or
Expand All @@ -58,44 +60,48 @@ def __init__(self, filename):

if not (
self.suffix.endswith('.json.gz') or
self.suffix.endswith('.json')
self.suffix.endswith('.json') or
self.suffix.endswith('.zip')
):
raise InvalidMRF(f'Suffix not JSON: {self.filename=} {self.suffix=}')
raise InvalidMRF(f'Suffix not JSON or ZIP: {self.filename=} {self.suffix=}')

self.is_remote = parsed_url.scheme in ('http', 'https')

def __enter__(self):
if (
self.is_remote
# endswith is used to protect against the case
# where the filename contains lots of dots
# insurer.stuff.json.gz
and self.suffix.endswith('.json.gz')
):
self.s = requests.Session()
self.r = self.s.get(self.filename, stream=True)
self.f = gzip.GzipFile(fileobj=self.r.raw)

elif (
self.is_remote
and self.suffix.endswith('.json')
):
self.s = requests.Session()
self.r = self.s.get(self.filename, stream=True)
self.r.raw.decode_content = True
self.f = self.r.raw

elif self.suffix == '.json.gz':
self.f = gzip.open(self.filename, 'rb')

if self.is_remote and self.suffix.endswith('.zip'):
# Download the zip file and store it in memory
response = requests.get(self.filename)
response.raise_for_status()
zip_data = io.BytesIO(response.content)

# Open the first file in the zip
with zipfile.ZipFile(zip_data) as zip_file:
inner_filename = zip_file.namelist()[0]
self.f = zip_file.open(inner_filename)
elif self.suffix.endswith('.json.gz'):
if self.is_remote:
self.s = requests.Session()
self.r = self.s.get(self.filename, stream=True)
self.f = gzip.GzipFile(fileobj=self.r.raw)
else:
self.f = gzip.open(self.filename, 'rb')
elif self.suffix.endswith('.json'):
if self.is_remote:
self.s = requests.Session()
self.r = self.s.get(self.filename, stream=True)
self.r.raw.decode_content = True
self.f = self.r.raw
else:
self.f = open(self.filename, 'rb')
else:
self.f = open(self.filename, 'rb')
raise InvalidMRF(f'Suffix not JSON or ZIP: {self.filename=} {self.suffix=}')

log.info(f'Opened file: {self.filename}')
return self.f


def __exit__(self, exc_type, exc_val, exc_tb):
if self.is_remote:
if self.is_remote and not self.suffix.endswith('.zip'):
self.s.close()
self.r.close()

Expand Down