diff --git a/.gitignore b/.gitignore index 7229f90..65cf3c4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ *.tar *.tgz *.egg-info +*.DS_Store diff --git a/setup.py b/setup.py index b7260be..0bce48b 100644 --- a/setup.py +++ b/setup.py @@ -1,84 +1,198 @@ #!/usr/bin/env python +# -*- coding: utf-8 -*- -descr = """A collection of datasets available and associated tools""" +""" distribute- and pip-enabled setup.py """ -import sys +import logging import os -import shutil - -DISTNAME = 'skdata' -DESCRIPTION = '' -LONG_DESCRIPTION = open('README.rst').read() -MAINTAINER = 'James Bergstra' -MAINTAINER_EMAIL = 'bergstra@rowland.harvard.edu' -URL = '' -LICENSE = 'new BSD' -DOWNLOAD_URL = '' -VERSION = '0.1' - -import setuptools # we are using a setuptools namespace -from numpy.distutils.core import setup - - -if __name__ == "__main__": - - old_path = os.getcwd() - local_path = os.path.dirname(os.path.abspath(sys.argv[0])) - # python 3 compatibility stuff. - # Simplified version of scipy strategy: copy files into - # build/py3k, and patch them using lib2to3. - if sys.version_info[0] == 3: - try: - import lib2to3cache - except ImportError: - pass - local_path = os.path.join(local_path, 'build', 'py3k') - if os.path.exists(local_path): - shutil.rmtree(local_path) - print("Copying source tree into build/py3k for 2to3 transformation" - "...") - - import lib2to3.main - from io import StringIO - print("Converting to Python3 via 2to3...") - _old_stdout = sys.stdout - try: - sys.stdout = StringIO() # supress noisy output - res = lib2to3.main.main("lib2to3.fixes", - ['-x', 'import', '-w', local_path]) - finally: - sys.stdout = _old_stdout - - if res != 0: - raise Exception('2to3 failed, exiting ...') - - os.chdir(local_path) - sys.path.insert(0, local_path) - - setup(name=DISTNAME, - maintainer=MAINTAINER, - packages=setuptools.find_packages(), - include_package_data=True, - maintainer_email=MAINTAINER_EMAIL, - description=DESCRIPTION, - license=LICENSE, - url=URL, - version=VERSION, - download_url=DOWNLOAD_URL, - long_description=LONG_DESCRIPTION, - zip_safe=True, # the package can run out of an .egg file - install_requires=['numpy>=1.3.0'], # 'glumpy>=0.1.0' - classifiers=[ - 'Intended Audience :: Science/Research', - 'Intended Audience :: Developers', - 'License :: OSI Approved', - 'Programming Language :: C', - 'Programming Language :: Python', - 'Topic :: Software Development', - 'Topic :: Scientific/Engineering', - 'Operating System :: Microsoft :: Windows', - 'Operating System :: POSIX', - 'Operating System :: Unix', - 'Operating System :: MacOS' - ] - ) +import re + +# ----- overrides ----- + +# set these to anything but None to override the automatic defaults +packages = None +package_name = None +package_data = None +scripts = None +requirements_file = None +requirements = None +dependency_links = None + +# --------------------- + + +# ----- control flags ----- + +# fallback to setuptools if distribute isn't found +setup_tools_fallback = False + +# don't include subdir named 'tests' in package_data +skip_tests = True + +# print some extra debugging info +debug = True + +# ------------------------- + +if debug: logging.basicConfig(level=logging.DEBUG) +# distribute import and testing +try: + import distribute_setup + distribute_setup.use_setuptools() + logging.debug("distribute_setup.py imported and used") +except ImportError: + # fallback to setuptools? 
+ # distribute_setup.py was not in this directory + if not (setup_tools_fallback): + import setuptools + if not (hasattr(setuptools,'_distribute') and \ + setuptools._distribute): + raise ImportError("distribute was not found and fallback to setuptools was not allowed") + else: + logging.debug("distribute_setup.py not found, defaulted to system distribute") + else: + logging.debug("distribute_setup.py not found, defaulting to system setuptools") + +import setuptools + +def find_scripts(): + return [s for s in setuptools.findall('scripts/') if os.path.splitext(s)[1] != '.pyc'] + +def package_to_path(package): + """ + Convert a package (as found by setuptools.find_packages) + e.g. "foo.bar" to usable path + e.g. "foo/bar" + + No idea if this works on windows + """ + return package.replace('.','/') + +def find_subdirectories(package): + """ + Get the subdirectories within a package + This will include resources (non-submodules) and submodules + """ + try: + subdirectories = os.walk(package_to_path(package)).next()[1] + except StopIteration: + subdirectories = [] + return subdirectories + +def subdir_findall(dir, subdir): + """ + Find all files in a subdirectory and return paths relative to dir + + This is similar to (and uses) setuptools.findall + However, the paths returned are in the form needed for package_data + """ + strip_n = len(dir.split('/')) + path = '/'.join((dir, subdir)) + return ['/'.join(s.split('/')[strip_n:]) for s in setuptools.findall(path)] + +def find_package_data(packages): + """ + For a list of packages, find the package_data + + This function scans the subdirectories of a package and considers all + non-submodule subdirectories as resources, including them in + the package_data + + Returns a dictionary suitable for setup(package_data=) + """ + package_data = {} + for package in packages: + package_data[package] = [] + for subdir in find_subdirectories(package): + if '.'.join((package, subdir)) in packages: # skip submodules + logging.debug("skipping submodule %s/%s" % (package, subdir)) + continue + if skip_tests and (subdir == 'tests'): # skip tests + logging.debug("skipping tests %s/%s" % (package, subdir)) + continue + package_data[package] += subdir_findall(package_to_path(package), subdir) + return package_data + +def parse_requirements(file_name): + """ + from: + http://cburgmer.posterous.com/pip-requirementstxt-and-setuppy + """ + requirements = [] + with open(file_name, 'r') as f: + for line in f: + if re.match(r'(\s*#)|(\s*$)', line): continue + if re.match(r'\s*-e\s+', line): + requirements.append(re.sub(r'\s*-e\s+.*#egg=(.*)$',\ + r'\1', line).strip()) + elif re.match(r'\s*-f\s+', line): + pass + else: + requirements.append(line.strip()) + return requirements + +def parse_dependency_links(file_name): + """ + from: + http://cburgmer.posterous.com/pip-requirementstxt-and-setuppy + """ + dependency_links = [] + with open(file_name) as f: + for line in f: + if re.match(r'\s*-[ef]\s+', line): + dependency_links.append(re.sub(r'\s*-[ef]\s+',\ + '', line)) + return dependency_links + +# ----------- Override defaults here ---------------- +if packages is None: packages = setuptools.find_packages() + +if len(packages) == 0: raise Exception("No valid packages found") + +if package_name is None: package_name = packages[0] + +if package_data is None: package_data = find_package_data(packages) + +if scripts is None: scripts = find_scripts() + +if requirements_file is None: + requirements_file = 'requirements.txt' + +if os.path.exists(requirements_file): + if requirements 
is None: + requirements = parse_requirements(requirements_file) + if dependency_links is None: + dependency_links = parse_dependency_links(requirements_file) +else: + if requirements is None: + requirements = [] + if dependency_links is None: + dependency_links = [] + +if debug: + logging.debug("Module name: %s" % package_name) + for package in packages: + logging.debug("Package: %s" % package) + logging.debug("\tData: %s" % str(package_data[package])) + logging.debug("Scripts:") + for script in scripts: + logging.debug("\tScript: %s" % script) + logging.debug("Requirements:") + for req in requirements: + logging.debug("\t%s" % req) + logging.debug("Dependency links:") + for dl in dependency_links: + logging.debug("\t%s" % dl) + +setuptools.setup( + name = package_name, + version = 'dev', + packages = packages, + scripts = scripts, + + package_data = package_data, + include_package_data = True, + + install_requires = requirements, + dependency_links = dependency_links +) diff --git a/skdata/fbo.py b/skdata/fbo.py new file mode 100644 index 0000000..57a84ee --- /dev/null +++ b/skdata/fbo.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +"""A very simple Face Body Object dataset. + + Contains 60 grayscale images (20 each of monkey faces, monkey bodies, + and various objects) on pink noise backgrounds. +""" + +# Copyright (C) 2011 +# Authors: Elias Issa and Dan Yamins + +# License: Simplified BSD + + +import os +from os import path +import shutil +from glob import glob +import hashlib + +import numpy as np + +import larray +from data_home import get_data_home +from utils import download, extract, int_labels +from utils.image import ImgLoader + + +class BaseFaceBodyObject(object): + + def __init__(self, meta=None, seed=0, ntrain=10, ntest=10, num_splits=5): + + self.seed = seed + self.ntrain = ntrain + self.ntest = ntest + self.num_splits = num_splits + self.names = ['Face','Body','Object'] + + if meta is not None: + self._meta = meta + + self.name = self.__class__.__name__ + + try: + from joblib import Memory + mem = Memory(cachedir=self.home('cache')) + self._get_meta = mem.cache(self._get_meta) + except ImportError: + pass + + def home(self, *suffix_paths): + return path.join(get_data_home(), self.name, *suffix_paths) + + # ------------------------------------------------------------------------ + # -- Dataset Interface: fetch() + # ------------------------------------------------------------------------ + + def fetch(self, download_if_missing=True): + """Download and extract the dataset.""" + + home = self.home() + + if not download_if_missing: + raise IOError("'%s' exists!" 
% home) + + # download archive + url = self.URL + sha1 = self.SHA1 + basename = path.basename(url) + archive_filename = path.join(home, basename) + if not path.exists(archive_filename): + if not download_if_missing: + return + if not path.exists(home): + os.makedirs(home) + download(url, archive_filename, sha1=sha1) + + # extract it + if not path.exists(self.home(self.SUBDIR)): + extract(archive_filename, home, sha1=sha1, verbose=True) + + # ------------------------------------------------------------------------ + # -- Dataset Interface: meta + # ------------------------------------------------------------------------ + + @property + def meta(self): + if not hasattr(self, '_meta'): + self.fetch(download_if_missing=True) + self._meta = self._get_meta() + return self._meta + + def _get_meta(self): + + img_filenames = sorted(os.listdir(self.home(self.SUBDIR))) + img_filenames = [os.path.join(self.home(self.SUBDIR),x) for x in img_filenames] + + meta = [] + for img_filename in img_filenames: + img_data = open(img_filename, 'rb').read() + sha1 = hashlib.sha1(img_data).hexdigest() + ind = int(os.path.split(img_filename)[1].split('.')[0][2:]) + if ind < 21: + name = 'Face' + elif ind < 41: + name = 'Body' + else: + name = 'Object' + + data = dict(name=name, + id=ind, + filename=img_filename, + sha1=sha1) + + meta += [data] + + + return meta + + @property + def splits(self): + """ + generates splits and attaches them in the "splits" attribute + """ + if not hasattr(self, '_splits'): + seed = self.seed + ntrain = self.ntrain + ntest = self.ntest + num_splits = self.num_splits + self._splits = self.generate_splits(seed, ntrain, + ntest, num_splits) + return self._splits + + def generate_splits(self, seed, ntrain, ntest, num_splits, labelset=None, catfunc=None): + meta = self.meta + if labelset is not None: + assert catfunc is not None + else: + labelset = self.names + catfunc = lambda x : x['name'] + + ntrain = self.ntrain + ntest = self.ntest + rng = np.random.RandomState(seed) + splits = {} + for split_id in range(num_splits): + splits['train_' + str(split_id)] = [] + splits['test_' + str(split_id)] = [] + for label in labelset: + cat = [m for m in meta if catfunc(m) == label] + L = len(cat) + assert L >= ntrain + ntest, 'category %s too small' % name + perm = rng.permutation(L) + for ind in perm[:ntrain]: + splits['train_' + str(split_id)].append(cat[ind]['filename']) + for ind in perm[ntrain: ntrain + ntest]: + splits['test_' + str(split_id)].append(cat[ind]['filename']) + return splits + + # ------------------------------------------------------------------------ + # -- Dataset Interface: clean_up() + # ------------------------------------------------------------------------ + + def clean_up(self): + if path.isdir(self.home()): + shutil.rmtree(self.home()) + + # ------------------------------------------------------------------------ + # -- Standard Tasks + # ------------------------------------------------------------------------ + + def raw_classification_task(self, split=None): + """Return image_paths, labels""" + if split: + inds = self.splits[split] + else: + inds = xrange(len(self.meta)) + image_paths = [self.meta[ind]['filename'] for ind in inds] + names = np.asarray([self.meta[ind]['name'] for ind in inds]) + labels = int_labels(names) + return image_paths, labels + + def img_classification_task(self, dtype='uint8', split=None): + img_paths, labels = self.raw_classification_task(split=split) + imgs = larray.lmap(ImgLoader(ndim=2, shape=(400,400), dtype=dtype, mode='L'), + img_paths) + 
return imgs, labels + + + +class FaceBodyObject20110803(BaseFaceBodyObject): + URL = 'http://dicarlocox-datasets.s3.amazonaws.com/FaceBodyObject_2011_08_03.tar.gz' + SHA1 = '088387e08ac008a0b8326e7dec1f0a667c8b71d0' + SUBDIR = 'FaceBodyObject_2011_08_03' + diff --git a/skdata/larray.py b/skdata/larray.py index dce3d6e..49bc414 100644 --- a/skdata/larray.py +++ b/skdata/larray.py @@ -96,6 +96,7 @@ class lmap(larray): def __init__(self, fn, obj0, *objs, **kwargs): ragged = kwargs.pop('ragged', False) f_map = kwargs.pop('f_map', None) + verbose = kwargs.pop('verbose', False) if kwargs: raise TypeError('unrecognized kwarg', kwargs.keys()) @@ -103,6 +104,7 @@ def __init__(self, fn, obj0, *objs, **kwargs): self.objs = [obj0] + list(objs) self.ragged = ragged self.f_map = f_map + self.verbose = verbose if not ragged: for o in objs: if len(obj0) != len(o): @@ -134,18 +136,40 @@ def __getitem__(self, idx): if is_int_idx(idx): return self.fn(*[o[idx] for o in self.objs]) else: + if self.verbose: + if isinstance(idx, slice): + print ('Evaluating items from a slice.') + else: + num_items = len(idx) + print ('Evaluating %d items' % num_items) try: tmps = [o[idx] for o in self.objs] except TypeError: # advanced indexing failed, try one element at a time - return [self.fn(*[o[i] for o in self.objs]) - for i in idx] + #return [self.fn(*[o[i] for o in self.objs]) + # for i in idx] + vals = [] + for ind, i in enumerate(idx): + if (ind / 100) * 100 == ind: + print(ind, i) + tmp = [o[i] for o in self.objs] + vals.append(self.fn(*tmp)) + return vals # we loaded our args by advanced indexing if self.f_map: return self.f_map(*tmps) else: - return map(self.fn, *tmps) + if self.verbose: + vals = [] + for ind, i in enumerate(idx): + if (ind / 100) * 100 == ind: + print(ind, i) + tmp = [o[i] for o in self.objs] + vals.append(self.fn(*tmp)) + return vals + else: + return map(self.fn, *tmps) def __array__(self): #XXX: use self.batch_len to produce this more efficiently @@ -399,8 +423,9 @@ def inputs(self): return [self.obj] def __getitem__(self, item): + test = self.test if isinstance(item, (int, np.int)): - if self._valid[item]: + if self._valid[item] or (test is not None and item > test): return self._data[item] else: obj_item = self.obj[item] @@ -409,6 +434,11 @@ def __getitem__(self, item): self.rows_computed += 1 return self._data[item] else: + if test is not None: + if hasattr(item, '__getitem__'): + item = item[:test] + else: + return self._data[item] # could be a slice, an intlist, a tuple v = self._valid[item] assert v.ndim == 1 @@ -456,11 +486,12 @@ class cache_memmap(CacheMixin, larray): ROOT = os.path.join(get_data_home(), 'memmaps') - def __init__(self, obj, name, basedir=None, msg=None): + def __init__(self, obj, name, basedir=None, msg=None, test=None): """ If new files are created, then `msg` will be written to README.msg """ + self.test = test self.obj = obj if basedir is None: basedir = self.ROOT diff --git a/skdata/pubfig83.py b/skdata/pubfig83.py index e5373a8..461315f 100644 --- a/skdata/pubfig83.py +++ b/skdata/pubfig83.py @@ -18,6 +18,7 @@ # Dan Yamins # James Bergstra # Nicolas Pinto +# Giovani Chiachia # License: Simplified BSD @@ -30,10 +31,25 @@ from glob import glob import hashlib +import larray from data_home import get_data_home -from utils import download, extract +from utils import download, extract, int_labels, download_and_extract import utils import utils.image +from utils.image import ImgLoader + +from sklearn import cross_validation +import numpy as np + +DEFAULT_NTRAIN = 80 
+DEFAULT_NFOLDS = 5 +DEFAULT_NVALIDATE = 10 +DEFAULT_NTEST = 10 + + +class NotEnoughExamplesError(Exception): + def __init__(self, label, have, want): + self.msg = '%d wanted, but have only %d for %s' % (want, have, label) class PubFig83(object): @@ -79,18 +95,18 @@ class PubFig83(object): 'female', 'male', 'male', 'male', 'male', 'female', 'female', 'male', 'male', 'male'] - def __init__(self, meta=None): + def __init__(self, meta=None, + ntrain=DEFAULT_NTRAIN, + nvalidate=DEFAULT_NVALIDATE, + ntest=DEFAULT_NTEST, + nfolds=DEFAULT_NFOLDS): if meta is not None: self._meta = meta - self.name = self.__class__.__name__ - - try: - from joblib import Memory - mem = Memory(cachedir=self.home('cache')) - self._get_meta = mem.cache(self._get_meta) - except ImportError: - pass + self.ntrain = ntrain + self.nvalidate = nvalidate + self.ntest = ntest + self.nfolds = nfolds def home(self, *suffix_paths): return path.join(get_data_home(), self.name, *suffix_paths) @@ -112,16 +128,13 @@ def fetch(self, download_if_missing=True): sha1 = self.SHA1 basename = path.basename(url) archive_filename = path.join(home, basename) - if not path.exists(archive_filename): + if not path.exists(self.home('pubfig83')): if not download_if_missing: return if not path.exists(home): os.makedirs(home) - download(url, archive_filename, sha1=sha1) + download_and_extract(url, archive_filename, sha1=sha1, verbose=True) - # extract it - if not path.exists(self.home('pubfig83')): - extract(archive_filename, home, sha1=sha1, verbose=True) # ------------------------------------------------------------------------ # -- Dataset Interface: meta @@ -129,20 +142,18 @@ def fetch(self, download_if_missing=True): @property def meta(self): - if hasattr(self, '_meta'): - return self._meta - else: + if not hasattr(self, '_meta'): self.fetch(download_if_missing=True) self._meta = self._get_meta() - return self._meta + return self._meta def _get_meta(self): - names = sorted(os.listdir(self.home('pubfig83'))) + names2 = sorted(os.listdir(self.home('pubfig83'))) genders = self._GENDERS - assert len(names) == len(genders) + assert len(names2) == len(genders) meta = [] ind = 0 - for gender, name in zip(genders, names): + for gender, name in zip(genders, names2): img_filenames = sorted(glob(self.home('pubfig83', name, '*.jpg'))) for img_filename in img_filenames: img_data = open(img_filename, 'rb').read() @@ -150,8 +161,100 @@ def _get_meta(self): meta.append(dict(gender=gender, name=name, id=ind, filename=img_filename, sha1=sha1)) ind += 1 + return meta + @property + def names(self): + if not hasattr(self, '_names'): + self._names = np.array([self.meta[ind]['name'] for ind in xrange(len(self.meta))]) + return self._names + + @property + def classification_splits(self): + """ + """ + if not hasattr(self, '_classification_splits'): + self._classification_splits = \ + self._generate_classification_splits(self.ntrain, + self.nvalidate, + self.ntest, + self.nfolds) + + return self._classification_splits + + def _generate_classification_splits(self, ntrain, nvalidate, ntest, nfolds): + meta = self.meta + rng = np.random.RandomState(0) + classification_splits = {} + + splits = {} + labels = np.unique(self.names) + for label in labels: + samples_to_consider = (self.names == label) + samples_to_consider = np.where(samples_to_consider)[0] + if len(samples_to_consider) < ntrain + nvalidate + ntest: + raise NotEnoughExamplesError(label, + len(samples_to_consider), + ntrain + nvalidate + ntest) + p = rng.permutation(len(samples_to_consider)) + if 'Test' not in 
splits: + splits['Test'] = [] + splits['Test'].extend(samples_to_consider[p[: ntest]]) + remainder = samples_to_consider[p[ntest: ntest + ntrain + nvalidate]] + assert len(remainder) == ntrain + nvalidate + for _ind in range(nfolds): + p = rng.permutation(len(remainder)) + if 'Train%d' % _ind not in splits: + splits['Train%d' % _ind] = [] + splits['Train%d' % _ind].extend(remainder[p[: ntrain]].copy()) + if 'Validate%d' % _ind not in splits: + splits['Validate%d' % _ind] = [] + splits['Validate%d' % _ind].extend( + remainder[p[ntrain: ntrain + nvalidate]].copy()) + + return splits + + + def view2_classification_splits(self, ntrain_view2_additional, ntest_view2, nfolds_view2): + """ + """ + ntrain_screen = self.ntrain + nvalidate_screen = self.nvalidate + ntest_screen = self.ntest + nfolds_screen = self.nfolds + rng = np.random.RandomState(0) + + splits = self.classification_splits + all_train = sorted(splits['Train0'] + splits['Validate0']) + test = splits['Test'] + + assert ntrain_view2_additional + ntest_view2 <= len(test) + + view2_splits = {} + meta = self.meta + labels = np.unique(self.names) + for name in labels: + samples_to_consider = sorted(list(set((self.names == name).nonzero()[0]).intersection(test))) + p = rng.permutation(len(samples_to_consider)) + samples_to_consider = samples_to_consider[:ntrain_view2_additional + ntest_view2] + for _ind in range(nfolds_view2): + p = rng.permutation(len(samples_to_consider)) + train_set = samples_to_consider[:ntrain_view2_additional] + test_set = samples_to_consider[ntrain_view2_additional: ntrain_view2_additional + ntest_view2] + if 'Train%d' % _ind not in view2_splits: + view2_splits['Train%d' % _ind] = [] + view2_splits['Train%d' % _ind].extend(train_set) + if 'Validate%d' % _ind not in view2_splits: + view2_splits['Validate%d' % _ind] = [] + view2_splits['Validate%d' % _ind].extend(test_set) + + for _ind in range(nfolds): + view2_splits['Train%d' % _ind].extend(all_train) + + return view2_splits + + # ------------------------------------------------------------------------ # -- Dataset Interface: clean_up() # ------------------------------------------------------------------------ @@ -165,23 +268,44 @@ def clean_up(self): # ------------------------------------------------------------------------ def image_path(self, m): - return self.home('pubfig83', m['name'], m['jpgfile']) + return self.home('pubfig83', m['name'], m['filename']) + #return self.home('pubfig83', m['name'], m['jpgfile']) # ------------------------------------------------------------------------ # -- Standard Tasks # ------------------------------------------------------------------------ - def raw_recognition_task(self): - names = [m['name'] for m in self.meta] - paths = [self.image_path(m) for m in self.meta] - labels = utils.int_labels(names) - return paths, labels + def raw_classification_task(self, split=None): + """ + :param split: an integer from 0 to 9 inclusive. 
+ :param split_role: either 'train' or 'test' + + :returns: either all samples (when split_k=None) or the specific + train/test split + """ + + if split is not None: + inds = self.classification_splits[split] + else: + inds = range(len(self.meta)) + names = self.names[inds] + paths = [self.meta[ind]['filename'] for ind in inds] + labels = int_labels(names) + return paths, labels, inds def raw_gender_task(self): genders = [m['gender'] for m in self.meta] paths = [self.image_path(m) for m in self.meta] return paths, utils.int_labels(genders) + def img_classification_task(self, dtype='uint8', split=None): + img_paths, labels, inds = self.raw_classification_task(split=split) + imgs = larray.lmap(ImgLoader(shape=(100, 100, 3), + dtype=dtype, + mode='RGB'), + img_paths) + return imgs, labels + # ------------------------------------------------------------------------ # -- Drivers for skdata/bin executables diff --git a/skdata/tests/test_pubfig83.py b/skdata/tests/test_pubfig83.py new file mode 100644 index 0000000..a685095 --- /dev/null +++ b/skdata/tests/test_pubfig83.py @@ -0,0 +1,387 @@ +import numpy as np + +import skdata.pubfig83 as pubfig83 + + +def test_meta(): + dataset = pubfig83.PubFig83() + meta = dataset.meta + names = dataset.names.tolist() + assert names == sorted(names) + assert len(meta) == 13838 + assert len(names) == 13838 + assert np.unique(names).tolist() == NAMES + counts = [names.count(n) for n in NAMES] + assert counts == COUNTS + + +def test_classification_splits(): + classification_splits_base(nfolds=pubfig83.DEFAULT_NFOLDS, + ntrain=pubfig83.DEFAULT_NTRAIN, + ntest=pubfig83.DEFAULT_NTEST, + nvalidate=pubfig83.DEFAULT_NVALIDATE) + + classification_splits_base(nfolds=2, + ntrain=pubfig83.DEFAULT_NTRAIN, + ntest=pubfig83.DEFAULT_NTEST, + nvalidate=pubfig83.DEFAULT_NVALIDATE) + + classification_splits_base(nfolds=10, + ntrain=pubfig83.DEFAULT_NTRAIN, + ntest=pubfig83.DEFAULT_NTEST, + nvalidate=pubfig83.DEFAULT_NVALIDATE) + + classification_splits_base(nfolds=pubfig83.DEFAULT_NFOLDS, + ntrain=40, + ntest=pubfig83.DEFAULT_NTEST, + nvalidate=pubfig83.DEFAULT_NVALIDATE) + + classification_splits_base(nfolds=pubfig83.DEFAULT_NFOLDS, + ntrain=40, + ntest=20, + nvalidate=pubfig83.DEFAULT_NVALIDATE) + + classification_splits_base(nfolds=pubfig83.DEFAULT_NFOLDS, + ntrain=40, + ntest=20, + nvalidate=0) + + try: + classification_splits_base(nfolds=pubfig83.DEFAULT_NFOLDS, + ntrain=200, + ntest=20, + nvalidate=pubfig83.DEFAULT_NVALIDATE) + except pubfig83.NotEnoughExamplesError: + pass + else: + raise Exception('Should have raised exception') + + +def classification_splits_base(nfolds, ntrain, nvalidate, ntest): + """ + Test that there are Test and Train/Validate splits + """ + dataset = pubfig83.PubFig83(ntrain=ntrain, nfolds=nfolds, ntest=ntest, + nvalidate=nvalidate) + splits = dataset.classification_splits + assert set(splits.keys()) == set(correct_split_names(nfolds)) + assert len(np.unique(splits['Test'])) == 83 * ntest + names = dataset.names + assert (names[splits['Test']] == np.repeat(NAMES, ntest)).all() + all_ids = [] + for ind in range(nfolds): + assert len(np.unique(splits['Train%d' % ind])) == ntrain * 83 + assert len(np.unique(splits['Validate%d' % ind])) == nvalidate * 83 + assert (names[splits['Train%d' % ind]] + == np.repeat(NAMES, ntrain)).all() + assert (names[splits['Validate%d' % ind]] + == np.repeat(NAMES, nvalidate)).all() + #no intersections between test & train & validate) + assert set(splits['Test']).intersection( + splits['Train%d' % ind]) == set([]) + 
assert set(splits['Test']).intersection( + splits['Validate%d' % ind]) == set([]) + assert set(splits['Train%d' % ind]).intersection( + splits['Validate%d' % ind]) == set([]) + all_id = np.concatenate([splits['Train%d' % ind], + splits['Validate%d' % ind]]) + all_id.sort() + all_ids.append(all_id) + if len(all_ids) > 1: + assert all([(all_ids[0] == a).all() for a in all_ids[1:]]) + + +def test_classification_task(): + dataset = pubfig83.PubFig83() + paths, labels, inds = dataset.raw_classification_task() + assert (np.unique(labels) == range(83)).all() + assert (labels == np.repeat(range(83), COUNTS)).all() + + +def test_images(): + dataset = pubfig83.PubFig83() + I, labels = dataset.img_classification_task() + #there are 13838 100x100 rgb images + assert I.shape == (13838, 100, 100, 3) + #a random sampling of 100 having the right checksums + rng = np.random.RandomState(0) + inds = rng.randint(13838, size=(100,)) + assert [I[k].sum() for k in inds] == DATA_SUMS + + +def correct_split_names(nfolds): + split_names = ['Test'] + for ind in range(nfolds): + split_names.append('Train%d' % ind) + split_names.append('Validate%d' % ind) + return split_names + + +DATA_SUMS = [3183920, + 2381895, + 2497078, + 3700911, + 2533567, + 2829736, + 3957948, + 2941275, + 3663414, + 4175522, + 4545230, + 3176163, + 2652150, + 2471430, + 2144770, + 4740294, + 4572341, + 3974616, + 2226700, + 3629392, + 3943478, + 3403567, + 3691577, + 890403, + 4196404, + 3583622, + 5685328, + 3843063, + 4130031, + 3781745, + 3693937, + 4125657, + 1390701, + 1943173, + 3618524, + 4654504, + 4665718, + 3277141, + 3934313, + 3299852, + 4722088, + 3655136, + 4589109, + 3565381, + 4761454, + 4475485, + 3465789, + 3301557, + 5026760, + 3802993, + 2146857, + 2603939, + 4387492, + 1534471, + 2423450, + 2423450, + 3513793, + 2715373, + 3729230, + 2429881, + 3134908, + 3442861, + 2906273, + 3432838, + 2586067, + 3440165, + 2112775, + 2222274, + 3857327, + 3499149, + 3345367, + 2801360, + 3056529, + 4648178, + 5010990, + 3719158, + 2397436, + 4283887, + 5021188, + 3495208, + 2782718, + 3130212, + 3866963, + 4484351, + 3990490, + 4404666, + 4405524, + 4067704, + 2960773, + 2320882, + 3444839, + 3998430, + 4276482, + 3327832, + 4051843, + 4109509, + 4220971, + 3645995, + 3256999, + 3039119] + +NAMES = ['Adam Sandler', + 'Alec Baldwin', + 'Angelina Jolie', + 'Anna Kournikova', + 'Ashton Kutcher', + 'Avril Lavigne', + 'Barack Obama', + 'Ben Affleck', + 'Beyonce Knowles', + 'Brad Pitt', + 'Cameron Diaz', + 'Cate Blanchett', + 'Charlize Theron', + 'Christina Ricci', + 'Claudia Schiffer', + 'Clive Owen', + 'Colin Farrell', + 'Colin Powell', + 'Cristiano Ronaldo', + 'Daniel Craig', + 'Daniel Radcliffe', + 'David Beckham', + 'David Duchovny', + 'Denise Richards', + 'Drew Barrymore', + 'Dustin Hoffman', + 'Ehud Olmert', + 'Eva Mendes', + 'Faith Hill', + 'George Clooney', + 'Gordon Brown', + 'Gwyneth Paltrow', + 'Halle Berry', + 'Harrison Ford', + 'Hugh Jackman', + 'Hugh Laurie', + 'Jack Nicholson', + 'Jennifer Aniston', + 'Jennifer Lopez', + 'Jennifer Love Hewitt', + 'Jessica Alba', + 'Jessica Simpson', + 'Joaquin Phoenix', + 'John Travolta', + 'Julia Roberts', + 'Julia Stiles', + 'Kate Moss', + 'Kate Winslet', + 'Katherine Heigl', + 'Keira Knightley', + 'Kiefer Sutherland', + 'Leonardo DiCaprio', + 'Lindsay Lohan', + 'Mariah Carey', + 'Martha Stewart', + 'Matt Damon', + 'Meg Ryan', + 'Meryl Streep', + 'Michael Bloomberg', + 'Mickey Rourke', + 'Miley Cyrus', + 'Morgan Freeman', + 'Nicole Kidman', + 'Nicole Richie', + 'Orlando Bloom', + 'Reese 
Witherspoon', + 'Renee Zellweger', + 'Ricky Martin', + 'Robert Gates', + 'Sania Mirza', + 'Scarlett Johansson', + 'Shahrukh Khan', + 'Shakira', + 'Sharon Stone', + 'Silvio Berlusconi', + 'Stephen Colbert', + 'Steve Carell', + 'Tom Cruise', + 'Uma Thurman', + 'Victoria Beckham', + 'Viggo Mortensen', + 'Will Smith', + 'Zac Efron'] + +COUNTS = [108, + 103, + 214, + 171, + 101, + 299, + 268, + 117, + 126, + 300, + 246, + 160, + 195, + 143, + 122, + 134, + 145, + 112, + 168, + 168, + 246, + 187, + 149, + 200, + 152, + 100, + 130, + 135, + 115, + 227, + 102, + 253, + 110, + 150, + 157, + 168, + 101, + 230, + 129, + 107, + 175, + 300, + 108, + 132, + 132, + 132, + 153, + 134, + 257, + 195, + 135, + 199, + 354, + 102, + 108, + 154, + 210, + 146, + 102, + 119, + 367, + 108, + 185, + 188, + 260, + 157, + 133, + 143, + 100, + 128, + 273, + 152, + 201, + 206, + 121, + 124, + 166, + 197, + 167, + 134, + 112, + 128, + 193] diff --git a/skdata/utils/__init__.py b/skdata/utils/__init__.py index ad372e3..cac7c84 100644 --- a/skdata/utils/__init__.py +++ b/skdata/utils/__init__.py @@ -1,6 +1,6 @@ from my_path import get_my_path, get_my_path_basename from xml2x import xml2dict, xml2list -from download_and_extract import download, extract, download_and_extract +from download_and_extract import download, download_boto, extract, download_and_extract # -- old utils.py diff --git a/skdata/utils/download_and_extract.py b/skdata/utils/download_and_extract.py index 51e6fbb..aee9f92 100644 --- a/skdata/utils/download_and_extract.py +++ b/skdata/utils/download_and_extract.py @@ -2,11 +2,15 @@ # Authors: Nicolas Pinto # Nicolas Poilvert +# Daniel Yamins # License: BSD 3 clause from urllib2 import urlopen from os import path import hashlib +import urlparse + +import boto import archive @@ -57,6 +61,21 @@ def download(url, output_filename, sha1=None, verbose=True): verify_sha1(output_filename, sha1) +def download_boto(url, credentials, output_filename, sha1=None): + """Downloads file from S3 via boto at `url` and write it in `output_dirname`""" + + conn = boto.connect_s3(*credentials) + url = urlparse.urlparse(url) + bucketname = url.netloc.split('.')[0] + file = url.path.strip('/') + bucket = conn.get_bucket(bucketname) + key = bucket.get_key(file) + key.get_contents_to_filename(output_filename) + + if sha1 is not None: + verify_sha1(output_filename, sha1) + + def extract(archive_filename, output_dirname, sha1=None, verbose=True): """Extracts `archive_filename` in `output_dirname`.
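
A minimal usage sketch (not part of the patch) for the `download_boto` helper added to skdata/utils/download_and_extract.py and re-exported from skdata/utils/__init__.py above. The AWS key pair below is a placeholder assumption; the URL and SHA-1 are the FaceBodyObject values defined in skdata/fbo.py earlier in this patch, and boto's 2.x API (`connect_s3`, `get_bucket`, `get_key`, `get_contents_to_filename`) is what the helper itself relies on.

    # hedged sketch: fetch a dataset archive from S3 via the new boto-based helper
    from skdata.utils import download_boto

    # hypothetical credentials -- substitute a real (access_key_id, secret_access_key) pair
    credentials = ('AKIA_EXAMPLE_KEY_ID', 'example-secret-access-key')

    download_boto(
        'http://dicarlocox-datasets.s3.amazonaws.com/FaceBodyObject_2011_08_03.tar.gz',
        credentials,
        '/tmp/FaceBodyObject_2011_08_03.tar.gz',
        sha1='088387e08ac008a0b8326e7dec1f0a667c8b71d0')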