From 7ab47c909c260310afaebded2be71cb467ad4603 Mon Sep 17 00:00:00 2001 From: wang shuxi Date: Fri, 7 May 2021 11:06:48 +0800 Subject: [PATCH 1/4] some env problems settled --- .../samples_with_missing_skeletons.txt | 302 ----------------- main.py | 6 +- processor/io.py | 303 +++++++++++------- processor/processor.py | 6 +- processor/recognition.py | 6 +- torchlight/__init__.py | 1 + .../__pycache__/__init__.cpython-36.pyc | Bin 0 -> 122 bytes torchlight/__pycache__/io.cpython-36.pyc | Bin 0 -> 7038 bytes torchlight/build/lib/torchlight/__init__.py | 1 + torchlight/build/lib/torchlight/gpu.py | 35 ++ torchlight/build/lib/torchlight/io.py | 203 ++++++++++++ torchlight/dist/torchlight-1.0-py3.6.egg | Bin 0 -> 8580 bytes torchlight/gpu.py | 35 ++ torchlight/io.py | 203 ++++++++++++ torchlight/torchlight.egg-info/PKG-INFO | 10 + torchlight/torchlight.egg-info/SOURCES.txt | 8 + .../torchlight.egg-info/dependency_links.txt | 1 + torchlight/torchlight.egg-info/top_level.txt | 1 + 18 files changed, 703 insertions(+), 418 deletions(-) delete mode 100644 data/NTU-RGB+D/samples_with_missing_skeletons.txt create mode 100644 torchlight/__init__.py create mode 100644 torchlight/__pycache__/__init__.cpython-36.pyc create mode 100644 torchlight/__pycache__/io.cpython-36.pyc create mode 100644 torchlight/build/lib/torchlight/__init__.py create mode 100644 torchlight/build/lib/torchlight/gpu.py create mode 100644 torchlight/build/lib/torchlight/io.py create mode 100644 torchlight/dist/torchlight-1.0-py3.6.egg create mode 100644 torchlight/gpu.py create mode 100644 torchlight/io.py create mode 100644 torchlight/torchlight.egg-info/PKG-INFO create mode 100644 torchlight/torchlight.egg-info/SOURCES.txt create mode 100644 torchlight/torchlight.egg-info/dependency_links.txt create mode 100644 torchlight/torchlight.egg-info/top_level.txt diff --git a/data/NTU-RGB+D/samples_with_missing_skeletons.txt b/data/NTU-RGB+D/samples_with_missing_skeletons.txt deleted file mode 100644 index 5ad472e..0000000 --- a/data/NTU-RGB+D/samples_with_missing_skeletons.txt +++ /dev/null @@ -1,302 +0,0 @@ -S001C002P005R002A008 -S001C002P006R001A008 -S001C003P002R001A055 -S001C003P002R002A012 -S001C003P005R002A004 -S001C003P005R002A005 -S001C003P005R002A006 -S001C003P006R002A008 -S002C002P011R002A030 -S002C003P008R001A020 -S002C003P010R002A010 -S002C003P011R002A007 -S002C003P011R002A011 -S002C003P014R002A007 -S003C001P019R001A055 -S003C002P002R002A055 -S003C002P018R002A055 -S003C003P002R001A055 -S003C003P016R001A055 -S003C003P018R002A024 -S004C002P003R001A013 -S004C002P008R001A009 -S004C002P020R001A003 -S004C002P020R001A004 -S004C002P020R001A012 -S004C002P020R001A020 -S004C002P020R001A021 -S004C002P020R001A036 -S005C002P004R001A001 -S005C002P004R001A003 -S005C002P010R001A016 -S005C002P010R001A017 -S005C002P010R001A048 -S005C002P010R001A049 -S005C002P016R001A009 -S005C002P016R001A010 -S005C002P018R001A003 -S005C002P018R001A028 -S005C002P018R001A029 -S005C003P016R002A009 -S005C003P018R002A013 -S005C003P021R002A057 -S006C001P001R002A055 -S006C002P007R001A005 -S006C002P007R001A006 -S006C002P016R001A043 -S006C002P016R001A051 -S006C002P016R001A052 -S006C002P022R001A012 -S006C002P023R001A020 -S006C002P023R001A021 -S006C002P023R001A022 -S006C002P023R001A023 -S006C002P024R001A018 -S006C002P024R001A019 -S006C003P001R002A013 -S006C003P007R002A009 -S006C003P007R002A010 -S006C003P007R002A025 -S006C003P016R001A060 -S006C003P017R001A055 -S006C003P017R002A013 -S006C003P017R002A014 -S006C003P017R002A015 -S006C003P022R002A013 -S007C001P018R002A050 
-S007C001P025R002A051 -S007C001P028R001A050 -S007C001P028R001A051 -S007C001P028R001A052 -S007C002P008R002A008 -S007C002P015R002A055 -S007C002P026R001A008 -S007C002P026R001A009 -S007C002P026R001A010 -S007C002P026R001A011 -S007C002P026R001A012 -S007C002P026R001A050 -S007C002P027R001A011 -S007C002P027R001A013 -S007C002P028R002A055 -S007C003P007R001A002 -S007C003P007R001A004 -S007C003P019R001A060 -S007C003P027R002A001 -S007C003P027R002A002 -S007C003P027R002A003 -S007C003P027R002A004 -S007C003P027R002A005 -S007C003P027R002A006 -S007C003P027R002A007 -S007C003P027R002A008 -S007C003P027R002A009 -S007C003P027R002A010 -S007C003P027R002A011 -S007C003P027R002A012 -S007C003P027R002A013 -S008C002P001R001A009 -S008C002P001R001A010 -S008C002P001R001A014 -S008C002P001R001A015 -S008C002P001R001A016 -S008C002P001R001A018 -S008C002P001R001A019 -S008C002P008R002A059 -S008C002P025R001A060 -S008C002P029R001A004 -S008C002P031R001A005 -S008C002P031R001A006 -S008C002P032R001A018 -S008C002P034R001A018 -S008C002P034R001A019 -S008C002P035R001A059 -S008C002P035R002A002 -S008C002P035R002A005 -S008C003P007R001A009 -S008C003P007R001A016 -S008C003P007R001A017 -S008C003P007R001A018 -S008C003P007R001A019 -S008C003P007R001A020 -S008C003P007R001A021 -S008C003P007R001A022 -S008C003P007R001A023 -S008C003P007R001A025 -S008C003P007R001A026 -S008C003P007R001A028 -S008C003P007R001A029 -S008C003P007R002A003 -S008C003P008R002A050 -S008C003P025R002A002 -S008C003P025R002A011 -S008C003P025R002A012 -S008C003P025R002A016 -S008C003P025R002A020 -S008C003P025R002A022 -S008C003P025R002A023 -S008C003P025R002A030 -S008C003P025R002A031 -S008C003P025R002A032 -S008C003P025R002A033 -S008C003P025R002A049 -S008C003P025R002A060 -S008C003P031R001A001 -S008C003P031R002A004 -S008C003P031R002A014 -S008C003P031R002A015 -S008C003P031R002A016 -S008C003P031R002A017 -S008C003P032R002A013 -S008C003P033R002A001 -S008C003P033R002A011 -S008C003P033R002A012 -S008C003P034R002A001 -S008C003P034R002A012 -S008C003P034R002A022 -S008C003P034R002A023 -S008C003P034R002A024 -S008C003P034R002A044 -S008C003P034R002A045 -S008C003P035R002A016 -S008C003P035R002A017 -S008C003P035R002A018 -S008C003P035R002A019 -S008C003P035R002A020 -S008C003P035R002A021 -S009C002P007R001A001 -S009C002P007R001A003 -S009C002P007R001A014 -S009C002P008R001A014 -S009C002P015R002A050 -S009C002P016R001A002 -S009C002P017R001A028 -S009C002P017R001A029 -S009C003P017R002A030 -S009C003P025R002A054 -S010C001P007R002A020 -S010C002P016R002A055 -S010C002P017R001A005 -S010C002P017R001A018 -S010C002P017R001A019 -S010C002P019R001A001 -S010C002P025R001A012 -S010C003P007R002A043 -S010C003P008R002A003 -S010C003P016R001A055 -S010C003P017R002A055 -S011C001P002R001A008 -S011C001P018R002A050 -S011C002P008R002A059 -S011C002P016R002A055 -S011C002P017R001A020 -S011C002P017R001A021 -S011C002P018R002A055 -S011C002P027R001A009 -S011C002P027R001A010 -S011C002P027R001A037 -S011C003P001R001A055 -S011C003P002R001A055 -S011C003P008R002A012 -S011C003P015R001A055 -S011C003P016R001A055 -S011C003P019R001A055 -S011C003P025R001A055 -S011C003P028R002A055 -S012C001P019R001A060 -S012C001P019R002A060 -S012C002P015R001A055 -S012C002P017R002A012 -S012C002P025R001A060 -S012C003P008R001A057 -S012C003P015R001A055 -S012C003P015R002A055 -S012C003P016R001A055 -S012C003P017R002A055 -S012C003P018R001A055 -S012C003P018R001A057 -S012C003P019R002A011 -S012C003P019R002A012 -S012C003P025R001A055 -S012C003P027R001A055 -S012C003P027R002A009 -S012C003P028R001A035 -S012C003P028R002A055 -S013C001P015R001A054 -S013C001P017R002A054 -S013C001P018R001A016 
-S013C001P028R001A040 -S013C002P015R001A054 -S013C002P017R002A054 -S013C002P028R001A040 -S013C003P008R002A059 -S013C003P015R001A054 -S013C003P017R002A054 -S013C003P025R002A022 -S013C003P027R001A055 -S013C003P028R001A040 -S014C001P027R002A040 -S014C002P015R001A003 -S014C002P019R001A029 -S014C002P025R002A059 -S014C002P027R002A040 -S014C002P039R001A050 -S014C003P007R002A059 -S014C003P015R002A055 -S014C003P019R002A055 -S014C003P025R001A048 -S014C003P027R002A040 -S015C001P008R002A040 -S015C001P016R001A055 -S015C001P017R001A055 -S015C001P017R002A055 -S015C002P007R001A059 -S015C002P008R001A003 -S015C002P008R001A004 -S015C002P008R002A040 -S015C002P015R001A002 -S015C002P016R001A001 -S015C002P016R002A055 -S015C003P008R002A007 -S015C003P008R002A011 -S015C003P008R002A012 -S015C003P008R002A028 -S015C003P008R002A040 -S015C003P025R002A012 -S015C003P025R002A017 -S015C003P025R002A020 -S015C003P025R002A021 -S015C003P025R002A030 -S015C003P025R002A033 -S015C003P025R002A034 -S015C003P025R002A036 -S015C003P025R002A037 -S015C003P025R002A044 -S016C001P019R002A040 -S016C001P025R001A011 -S016C001P025R001A012 -S016C001P025R001A060 -S016C001P040R001A055 -S016C001P040R002A055 -S016C002P008R001A011 -S016C002P019R002A040 -S016C002P025R002A012 -S016C003P008R001A011 -S016C003P008R002A002 -S016C003P008R002A003 -S016C003P008R002A004 -S016C003P008R002A006 -S016C003P008R002A009 -S016C003P019R002A040 -S016C003P039R002A016 -S017C001P016R002A031 -S017C002P007R001A013 -S017C002P008R001A009 -S017C002P015R001A042 -S017C002P016R002A031 -S017C002P016R002A055 -S017C003P007R002A013 -S017C003P008R001A059 -S017C003P016R002A031 -S017C003P017R001A055 -S017C003P020R001A059 diff --git a/main.py b/main.py index ee1a0a2..4575f4a 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ import argparse import sys import torchlight -from torchlight import import_class +from torchlight.io import import_class if __name__ == '__main__': @@ -10,7 +10,8 @@ processors = dict() processors['recognition'] = import_class('processor.recognition.REC_Processor') - processors['demo'] = import_class('processor.demo.Demo') + #processors['recognition'] = import_class('processor.processor.Processor.io.__init__') + #processors['demo'] = import_class('processor.demo.Demo') subparsers = parser.add_subparsers(dest='processor') for k, p in processors.items(): @@ -20,5 +21,6 @@ # start Processor = processors[arg.processor] + print(sys.argv[:]) p = Processor(sys.argv[2:]) p.start() diff --git a/processor/io.py b/processor/io.py index fb9e4f8..c753ca1 100644 --- a/processor/io.py +++ b/processor/io.py @@ -1,116 +1,203 @@ -import sys -import os +#!/usr/bin/env python import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict import yaml import numpy as np - +# torch import torch import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable -import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class - +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py class IO(): - - def __init__(self, argv=None): - - self.load_arg(argv) - self.init_environment() - self.load_model() - self.load_weights() - self.gpu() - - def load_arg(self, argv=None): - parser = self.get_parser() - - # load arg form config file - p = parser.parse_args(argv) - if p.config is not None: - # load config file - with open(p.config, 'r') as f: - default_arg = yaml.load(f) - - # update parser from 
config file - key = vars(p).keys() - for k in default_arg.keys(): - if k not in key: - print('Unknown Arguments: {}'.format(k)) - assert k in key - - parser.set_defaults(**default_arg) - - self.arg = parser.parse_args(argv) - - def init_environment(self): - self.save_dir = os.path.join(self.arg.work_dir, - self.arg.max_hop_dir, - self.arg.lamda_act_dir) - self.io = torchlight.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) - self.io.save_arg(self.arg) - - # gpu - if self.arg.use_gpu: - gpus = torchlight.visible_gpu(self.arg.device) - torchlight.occupy_gpu(gpus) - self.gpus = gpus - self.dev = "cuda:0" - else: - self.dev = "cpu" - - def load_model(self): - self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) - self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) - - def load_weights(self): - if self.arg.weights1: - self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) - self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) - - def gpu(self): - # move modules to gpu - self.model1 = self.model1.to(self.dev) - self.model2 = self.model2.to(self.dev) - for name, value in vars(self).items(): - cls_name = str(value.__class__) - if cls_name.find('torch.nn.modules') != -1: - setattr(self, name, value.to(self.dev)) - - # model parallel - if self.arg.use_gpu and len(self.gpus) > 1: - self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) - self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) - - def start(self): - self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) - - @staticmethod - def get_parser(add_help=False): - - #region arguments yapf: disable - # parameter priority: command line > config > default - parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') - - parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') - parser.add_argument('-c', '--config', default=None, help='path to the configuration file') - - # processor - parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') - parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') - - # visulize and debug - parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') - parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') - - # model - parser.add_argument('--model1', default=None, help='the model will be used') - parser.add_argument('--model2', default=None, help='the model will be used') - parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--weights', default=None, help='the weights for network initialization') - parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') - #endregion yapf: enable - - return parser + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, 
**kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for 
k in proportion:
+            self.print_log(
+                '\t[{}][{}]: {:.4f}'.format(k, proportion[k], self.split_timer[k])
+            )
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def str2dict(v):
+    return eval('dict({})'.format(v))  #pylint: disable=W0123
+
+
+def _import_class_0(name):
+    components = name.split('.')
+    mod = __import__(components[0])
+    for comp in components[1:]:
+        mod = getattr(mod, comp)
+    return mod
+
+
+def import_class(import_str):
+    mod_str, _sep, class_str = import_str.rpartition('.')
+    __import__(mod_str)
+    try:
+        return getattr(sys.modules[mod_str], class_str)
+    except AttributeError:
+        raise ImportError('Class %s cannot be found (%s)' %
+                          (class_str,
+                           traceback.format_exception(*sys.exc_info())))
+
+
+class DictAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        super(DictAction, self).__init__(option_strings, dest, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        input_dict = eval('dict({})'.format(values))  #pylint: disable=W0123
+        output_dict = getattr(namespace, self.dest)
+        for k in input_dict:
+            output_dict[k] = input_dict[k]
+        setattr(namespace, self.dest, output_dict)
diff --git a/processor/processor.py b/processor/processor.py
index 03fc2cf..690d846 100644
--- a/processor/processor.py
+++ b/processor/processor.py
@@ -8,9 +8,9 @@
 import torch.optim as optim
 
 import torchlight
-from torchlight import str2bool
-from torchlight import DictAction
-from torchlight import import_class
+from torchlight.io import str2bool
+from torchlight.io import DictAction
+from torchlight.io import import_class
 
 from .io import IO
 
diff --git a/processor/recognition.py b/processor/recognition.py
index d3905af..06fb642 100644
--- a/processor/recognition.py
+++ b/processor/recognition.py
@@ -13,9 +13,9 @@
 import torch.optim as optim
 
 import torchlight
-from torchlight import str2bool
-from torchlight import DictAction
-from torchlight import import_class
+from torchlight.io import str2bool
+from torchlight.io import DictAction
+from torchlight.io import import_class
 
 from .processor import Processor
 
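
A note on the helpers that processor/processor.py and processor/recognition.py
now pull from torchlight.io: import_class resolves a dotted path to a class
object at runtime, which is how main.py looks up its processors before calling
Processor(sys.argv[2:]). A minimal sketch, assuming the repository root is on
sys.path (the argument list here is hypothetical):

    from torchlight.io import import_class, str2bool

    # Resolve 'processor.recognition.REC_Processor' to the class object,
    # as main.py does:
    Processor = import_class('processor.recognition.REC_Processor')
    p = Processor(['--use_gpu', 'false'])  # hypothetical argv slice

    # str2bool is the argparse type used for flags such as --use_gpu:
    assert str2bool('false') is False
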
diff --git a/torchlight/__init__.py b/torchlight/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/torchlight/__init__.py
@@ -0,0 +1 @@
+
diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3c790e409f2747726cc89be56d654bae0ca6888
GIT binary patch
(base85 payload omitted: compiled CPython 3.6 bytecode, 122 bytes)

diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..967d1892bfc511e3070341c8edcd71a9781033d2
GIT binary patch
(base85 payload omitted: compiled CPython 3.6 bytecode, 7038 bytes)

diff --git a/torchlight/build/lib/torchlight/__init__.py b/torchlight/build/lib/torchlight/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/torchlight/build/lib/torchlight/__init__.py
@@ -0,0 +1 @@
+
diff --git a/torchlight/build/lib/torchlight/gpu.py b/torchlight/build/lib/torchlight/gpu.py
new file mode 100644
index 0000000..306c391
--- /dev/null
+++ b/torchlight/build/lib/torchlight/gpu.py
@@ -0,0 +1,35 @@
+import os
+import torch
+
+
+def visible_gpu(gpus):
+    """
+    set visible gpu.
+
+    can be a single id, or a list
+
+    return a list of new gpus ids
+    """
+    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
+    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus)))
+    return list(range(len(gpus)))
+
+
+def ngpu(gpus):
+    """
+    count how many gpus used.
+    """
+    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
+    return len(gpus)
+
+
+def occupy_gpu(gpus=None):
+    """
+    make program appear on nvidia-smi.
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/build/lib/torchlight/io.py b/torchlight/build/lib/torchlight/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/torchlight/build/lib/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = 
vars(arg)
+        if not os.path.exists(self.work_dir):
+            os.makedirs(self.work_dir)
+        with open(self.session_file, 'w') as f:
+            f.write('# command line: {}\n\n'.format(' '.join(sys.argv)))
+            yaml.dump(arg_dict, f, default_flow_style=False, indent=4)
+
+    def print_log(self, str, print_time=True):
+        if print_time:
+            # localtime = time.asctime(time.localtime(time.time()))
+            str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str
+
+        if self.print_to_screen:
+            print(str)
+        if self.save_log:
+            with open('{}/log.txt'.format(self.work_dir), 'a') as f:
+                print(str, file=f)
+
+    def init_timer(self, *name):
+        self.record_time()
+        self.split_timer = {k: 0.0000001 for k in name}
+
+    def check_time(self, name):
+        self.split_timer[name] += self.split_time()
+
+    def record_time(self):
+        self.cur_time = time.time()
+        return self.cur_time
+
+    def split_time(self):
+        split_time = time.time() - self.cur_time
+        self.record_time()
+        return split_time
+
+    def print_timer(self):
+        proportion = {
+            k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values()))))
+            for k, v in self.split_timer.items()
+        }
+        self.print_log('Time consumption:')
+        for k in proportion:
+            self.print_log(
+                '\t[{}][{}]: {:.4f}'.format(k, proportion[k], self.split_timer[k])
+            )
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def str2dict(v):
+    return eval('dict({})'.format(v))  #pylint: disable=W0123
+
+
+def _import_class_0(name):
+    components = name.split('.')
+    mod = __import__(components[0])
+    for comp in components[1:]:
+        mod = getattr(mod, comp)
+    return mod
+
+
+def import_class(import_str):
+    mod_str, _sep, class_str = import_str.rpartition('.')
+    __import__(mod_str)
+    try:
+        return getattr(sys.modules[mod_str], class_str)
+    except AttributeError:
+        raise ImportError('Class %s cannot be found (%s)' %
+                          (class_str,
+                           traceback.format_exception(*sys.exc_info())))
+
+
+class DictAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        super(DictAction, self).__init__(option_strings, dest, **kwargs)
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        input_dict = eval('dict({})'.format(values))  #pylint: disable=W0123
+        output_dict = getattr(namespace, self.dest)
+        for k in input_dict:
+            output_dict[k] = input_dict[k]
+        setattr(namespace, self.dest, output_dict)
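
For reference, a short sketch of how DictAction is used by the model1_args /
model2_args options in the processors: the values string is passed through
eval('dict(...)'), so it must be valid Python keyword-argument syntax (the
parser and values below are hypothetical):

    import argparse
    from torchlight.io import DictAction

    parser = argparse.ArgumentParser()
    parser.add_argument('--model1_args', action=DictAction, default=dict())
    args = parser.parse_args(['--model1_args', 'num_class=60, dropout=0.5'])
    # The parsed string is merged into the default dict:
    assert args.model1_args == {'num_class': 60, 'dropout': 0.5}
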
diff --git a/torchlight/dist/torchlight-1.0-py3.6.egg b/torchlight/dist/torchlight-1.0-py3.6.egg
new file mode 100644
index 0000000000000000000000000000000000000000..d7e16926c97cdb66afe323750bcd72b7fa41499c
GIT binary patch
(base85 payload omitted: packaged egg archive, 8580 bytes)
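
The gpu helpers vendored above (and duplicated next as torchlight/gpu.py)
are thin wrappers around CUDA_VISIBLE_DEVICES. A hedged usage sketch,
assuming a machine with at least four GPUs and that no CUDA context has been
created yet (the masking only takes effect before CUDA is initialized):

    from torchlight.gpu import visible_gpu, occupy_gpu, ngpu

    gpus = visible_gpu([2, 3])   # exports CUDA_VISIBLE_DEVICES=2,3
    assert gpus == [0, 1]        # ids are remapped to the visible range
    assert ngpu([2, 3]) == 2
    occupy_gpu(gpus)             # puts a tiny tensor on each device so the
                                 # process shows up in nvidia-smi
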
z$R5BzcrCo2gu-_i1y~@v5zxT5#FFp`MiOP-Lad-Iu~b?|x@D|yW|MEImqGUl8T~_Y z(=8rsop-$jivz)Qu~wWLo80S+SZp{D#+?+S=|!5Zs`0wY6S*%?&Nmn%GMD$(}wpFU|16)Sxt!bOv(!VaTXIB@D=`LecjIr}yqSI98>Qkv8P zg46+$yj)LNNZ`sK`vwh0b&B87Op|-NdpfnRUGO&M`EE||*}+lI>~shtfInd>mW)RA z)+3rs)*y@y4Qo^eFCHX5p_=T=mnZ@PPQRi}>-5R)W^Q^2+I9d>`*yjfJa`P1u$5Mn zeJox!)o5rg;AC0otd#Br(-Wwjzc~V7d|ka*dazu;Y%0|ZOB~G5*~(;`O=&QzbS^5F ziH%1WDyD)TI?&1S9+D)^yW0+MNl_^0wv!5(#FLota@Eh#`^IEErbbmA!-*?VA<8z9 zv5O`=*`@CRs4^wc96nZ+r7!lmBu-ZVhSFa+XMZIA0)f9DDM-fu7MvB+EgkZ(H#DX= z$uL6peQ3CEAMs$sZ0SWy*aT}N=t(t{t(=EJNU8E1?-j<{ z8rOcM=Nt#V=hP)5gyy!xy6O9VV%+&2$oGhbf5NGIJO z8S#0%RHaErH-94sNG$cCz*5oD%7R^6OO>ms8Ek#)>XTKCNMp*C5x!)7KnT4lUGKqS zi+uUIRmyDICBQp7y?w?zhCmu`8@1FU74unP#yc@Bzo~b^LR1GXDe~-%3_4no+XB!H z;bb&=dp>$om8F?5!PIh_Hk85mX<8rOJ>uY9Ji3y*UUORflN=7S9g4A-=9w{`x%)!jao@{;`%Ed2)~W37GkG`BL%#n} zqXObh8NQ;ZunMd!WfNJH1jDF>1(!#zbH;qwW0x!qMhx1zkKH5k0ugE)lyy#|vkMGm zPQr1S=5^GE*bBRXj*0;#|KuO+#@_1E9w<5{pD==##nwu>jVqdoCl7?2MqfECUO$b1 zFMI8PC~4>Rcke$j8eFfaGACY5xFfwPvdna>-#o<22zW$t$`!-zUCq@&Olx&G8$q#$ zoB?XX+m)sRb-z9d;j;oQJfFYHMtmS?%x$~((C#D$Awt2iO3ThW8cqk+TUT(lo0ye^(R4l~Q4 z!LC5HBA$m#m&i|%mC9@0g?YNf>ycnYI}DG;45g|Sgwo5kD=Q)q&|+1FQs)U(baTisXZ?vph8z^H}? z8r0)D_qA=emi6TN#WblyRp8%|g9?MgUYUJ(U`u~CzQf)kfQ<5QLQr|z;^}n|!PBm< zGN|G^sG~?0D%}uo7Lng(zLR}Lo9x9&xgIW)vS8N`zB9;}m8JF0g7KbC#%^EQkx||A zA+sw|B_*GWpV0elK1PAY><*U_|ZYN={eLCG>v@Txhp zVXx8$=fs-#7a)92DQPNS{8Vc3sb;LNA*{T!sXx~6khhzK3wN2Br8n6qxoBXdC7Wh5 zLK)=rg&rP8pM^RCXBv)*vMl#PT^=4v94m6DmX5E(kvqB`+*T&AHQ}BW04;QpwdPCI&4x5j_Cv7^-fT`?or=hSYv%d$ z%Mld<5)0~|W964~>i>-f|6gDKiEIBIGXIJHneYA!1^^oTn|{ZCOMCxR_&wbGmB#*i zH26!uWS9Rj@V~OzKc)VRX8)23`(5g9LG7O!zuWnh&i;Ee_=A7f_;*(OJO00-+Q0CJ z7yP#b_P5~nPmMow+rKo}UK(J(Y5Xh2{S*H)eEbWqeCeJ2hW}p_`8)l;`h$Pb5=eie z|3A8ge`^2P!~0A7`Jc7_yQBA~?4Ny>zhvuB|IzY)bzKx?U|@fdh%X=fmj|Ur{Ojuf E0UVex3;+NC literal 0 HcmV?d00001 diff --git a/torchlight/gpu.py b/torchlight/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/torchlight/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. 
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/io.py b/torchlight/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + 
os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/torchlight.egg-info/PKG-INFO b/torchlight/torchlight.egg-info/PKG-INFO new file mode 100644 index 0000000..53cafc2 --- /dev/null +++ b/torchlight/torchlight.egg-info/PKG-INFO @@ -0,0 +1,10 @@ +Metadata-Version: 1.0 +Name: torchlight +Version: 1.0 +Summary: A mini framework for pytorch +Home-page: UNKNOWN +Author: UNKNOWN +Author-email: UNKNOWN +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN diff --git a/torchlight/torchlight.egg-info/SOURCES.txt b/torchlight/torchlight.egg-info/SOURCES.txt new file mode 100644 index 0000000..1ee6009 --- /dev/null +++ b/torchlight/torchlight.egg-info/SOURCES.txt @@ -0,0 +1,8 @@ +setup.py +torchlight/__init__.py +torchlight/gpu.py +torchlight/io.py +torchlight.egg-info/PKG-INFO +torchlight.egg-info/SOURCES.txt +torchlight.egg-info/dependency_links.txt +torchlight.egg-info/top_level.txt \ No newline at end of file diff --git a/torchlight/torchlight.egg-info/dependency_links.txt b/torchlight/torchlight.egg-info/dependency_links.txt new file mode 100644 index 
0000000..8b13789 --- /dev/null +++ b/torchlight/torchlight.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/torchlight/torchlight.egg-info/top_level.txt b/torchlight/torchlight.egg-info/top_level.txt new file mode 100644 index 0000000..c600430 --- /dev/null +++ b/torchlight/torchlight.egg-info/top_level.txt @@ -0,0 +1 @@ +torchlight From ebbf43b8ffb5af2ade97a0e442e5f8be0bfa8058 Mon Sep 17 00:00:00 2001 From: wang shuxi Date: Fri, 7 May 2021 22:35:03 +0800 Subject: [PATCH 2/4] settled some bugs --- README.md | 6 +- data/readme.md | 1 - main.py | 2 - processor/gpu.py | 35 ++ processor/io.py | 303 +++++++----------- torchlight/__init__.py | 9 +- .../__pycache__/__init__.cpython-36.pyc | Bin 122 -> 392 bytes torchlight/__pycache__/io.cpython-36.pyc | Bin 7038 -> 7050 bytes torchlight/dist/torchlight-1.0-py3.6.egg | Bin 8580 -> 8583 bytes 9 files changed, 154 insertions(+), 202 deletions(-) delete mode 100644 data/readme.md create mode 100644 processor/gpu.py diff --git a/README.md b/README.md index 66256bf..601db34 100644 --- a/README.md +++ b/README.md @@ -43,9 +43,9 @@ Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml For Cross-View, ``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml -Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml +PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xview/train_aim.yaml +TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xview/train.yaml +Test: python main.py recognition -c config/as_gcn/ntu-xview/test.yaml ``` # Acknowledgement diff --git a/data/readme.md b/data/readme.md deleted file mode 100644 index 777d39a..0000000 --- a/data/readme.md +++ /dev/null @@ -1 +0,0 @@ -The filepath of data (NTU-RGB+D) diff --git a/main.py b/main.py index 4575f4a..bd402cb 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,6 @@ processors = dict() processors['recognition'] = import_class('processor.recognition.REC_Processor') - #processors['recognition'] = import_class('processor.processor.Processor.io.__init__') #processors['demo'] = import_class('processor.demo.Demo') subparsers = parser.add_subparsers(dest='processor') @@ -21,6 +20,5 @@ # start Processor = processors[arg.processor] - print(sys.argv[:]) p = Processor(sys.argv[2:]) p.start() diff --git a/processor/gpu.py b/processor/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/processor/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. 
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/processor/io.py b/processor/io.py index c753ca1..b8d11ca 100644 --- a/processor/io.py +++ b/processor/io.py @@ -1,203 +1,116 @@ -#!/usr/bin/env python -import argparse -import os import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict +import os +import argparse import yaml import numpy as np -# torch + import torch import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - - def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - 
self.session_file = '{}/config.yaml'.format(self.work_dir)
-
-        # save arg
-        arg_dict = vars(arg)
-        if not os.path.exists(self.work_dir):
-            os.makedirs(self.work_dir)
-        with open(self.session_file, 'w') as f:
-            f.write('# command line: {}\n\n'.format(' '.join(sys.argv)))
-            yaml.dump(arg_dict, f, default_flow_style=False, indent=4)
-
-    def print_log(self, str, print_time=True):
-        if print_time:
-            # localtime = time.asctime(time.localtime(time.time()))
-            str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str
-
-        if self.print_to_screen:
-            print(str)
-        if self.save_log:
-            with open('{}/log.txt'.format(self.work_dir), 'a') as f:
-                print(str, file=f)
-
-    def init_timer(self, *name):
-        self.record_time()
-        self.split_timer = {k: 0.0000001 for k in name}
-
-    def check_time(self, name):
-        self.split_timer[name] += self.split_time()
-
-    def record_time(self):
-        self.cur_time = time.time()
-        return self.cur_time
-
-    def split_time(self):
-        split_time = time.time() - self.cur_time
-        self.record_time()
-        return split_time
-
-    def print_timer(self):
-        proportion = {
-            k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values()))))
-            for k, v in self.split_timer.items()
-        }
-        self.print_log('Time consumption:')
-        for k in proportion:
-            self.print_log(
-                '\t[{}][{}]: {:.4f}'.format(k, proportion[k], self.split_timer[k])
-            )
-
-
-def str2bool(v):
-    if v.lower() in ('yes', 'true', 't', 'y', '1'):
-        return True
-    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
-        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected.')
-
-
-def str2dict(v):
-    return eval('dict({})'.format(v))  #pylint: disable=W0123
-
-
-def _import_class_0(name):
-    components = name.split('.')
-    mod = __import__(components[0])
-    for comp in components[1:]:
-        mod = getattr(mod, comp)
-    return mod
-
-
-def import_class(import_str):
-    mod_str, _sep, class_str = import_str.rpartition('.')
-    __import__(mod_str)
-    try:
-        return getattr(sys.modules[mod_str], class_str)
-    except AttributeError:
-        raise ImportError('Class %s cannot be found (%s)' %
-                          (class_str,
-                           traceback.format_exception(*sys.exc_info())))
-
-
-class DictAction(argparse.Action):
-    def __init__(self, option_strings, dest, nargs=None, **kwargs):
-        if nargs is not None:
-            raise ValueError("nargs not allowed")
-        super(DictAction, self).__init__(option_strings, dest, **kwargs)
-
-    def __call__(self, parser, namespace, values, option_string=None):
-        input_dict = eval('dict({})'.format(values))  #pylint: disable=W0123
-        output_dict = getattr(namespace, self.dest)
-        for k in input_dict:
-            output_dict[k] = input_dict[k]
-        setattr(namespace, self.dest, output_dict)
+
+    def __init__(self, argv=None):
+
+        self.load_arg(argv)
+        self.init_environment()
+        self.load_model()
+        self.load_weights()
+        self.gpu()
+
+    def load_arg(self, argv=None):
+        parser = self.get_parser()
+
+        # load arg from config file
+        p = parser.parse_args(argv)
+        if p.config is not None:
+            # load config file
+            with open(p.config, 'r') as f:
+                default_arg = yaml.load(f)
+
+            # update parser from config file
+            key = vars(p).keys()
+            for k in default_arg.keys():
+                if k not in key:
+                    print('Unknown Arguments: {}'.format(k))
+                    assert k in key
+
+            parser.set_defaults(**default_arg)
+
+        self.arg = parser.parse_args(argv)
+
+    def init_environment(self):
+        self.save_dir = os.path.join(self.arg.work_dir,
+                                     self.arg.max_hop_dir,
+                                     self.arg.lamda_act_dir)
+        self.io = torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log)
+        self.io.save_arg(self.arg)
+
+        # gpu
+        if self.arg.use_gpu:
+            gpus = torchlight.gpu.visible_gpu(self.arg.device)
+            #torchlight.occupy_gpu(gpus)
+            self.gpus = gpus
+            self.dev = "cuda:0"
+        else:
+            self.dev = "cpu"
+
+    def load_model(self):
+        self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args))
+        self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args))
+
+    def load_weights(self):
+        if self.arg.weights1:
+            self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights)
+            self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights)
+
+    def gpu(self):
+        # move modules to gpu
+        self.model1 = self.model1.to(self.dev)
+        self.model2 = self.model2.to(self.dev)
+        for name, value in vars(self).items():
+            cls_name = str(value.__class__)
+            if cls_name.find('torch.nn.modules') != -1:
+                setattr(self, name, value.to(self.dev))
+
+        # model parallel
+        if self.arg.use_gpu and len(self.gpus) > 1:
+            self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus)
+            self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus)
+
+    def start(self):
+        self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))
+
+    @staticmethod
+    def get_parser(add_help=False):
+
+        #region arguments yapf: disable
+        # parameter priority: command line > config > default
+        parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor')
+
+        parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results')
+        parser.add_argument('-c', '--config', default=None, help='path to the configuration file')
+
+        # processor
+        parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
+        parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')
+
+        # visualize and debug
+        parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not')
+        parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not')
+
+        # model
+        parser.add_argument('--model1', default=None, help='the model to be used')
+        parser.add_argument('--model2', default=None, help='the model to be used')
+        parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of the model')
+        parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of the model')
+        parser.add_argument('--weights', default=None, help='the weights for network initialization')
+        parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization')
+        #endregion yapf: enable
+
+        return parser
diff --git a/torchlight/__init__.py b/torchlight/__init__.py
index 8b13789..07e70f1 100644
--- a/torchlight/__init__.py
+++ b/torchlight/__init__.py
@@ -1 +1,8 @@
-
+from .io import IO
+from .io import str2bool
+from .io import str2dict
+from .io import DictAction
+from .io import import_class
+from .gpu import visible_gpu
+from .gpu import occupy_gpu
+from .gpu import ngpu
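
With these re-exports in place, torchlight.IO and torchlight.io.IO name the
same class, and a plain "import torchlight" also exposes the io and gpu
submodules, which is what the restored processor/io.py above relies on. A
small sketch (the work directory is hypothetical):

    import torchlight
    from torchlight.io import IO

    assert torchlight.IO is IO
    io = torchlight.io.IO('./work_dir/tmp', save_log=False, print_log=True)
    io.print_log('hello')  # printed with a "[mm.dd.yy|HH:MM:SS]" prefix
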
znC7Azl(vk^y)iJ)9)%GH*|6gKgK7o{8-Svf5R5>rFN6EpMv}VwL8t6vx&Q#zZeHW9P$Z%C0(9 QkFvRTcQj{^IsX5C13uPUZ~y=R delta 93 zcmeBRt_m>b<>k8YcX9$F0|UcjAcg~wfCCU0vjB+{hF}IwM!%H|MId1W@k?DlBR@A) Wzr46Y-!WL%-Pvz)0HYd62Lk~4KMDZb<=Yr9qMI?vn0ZI=<)e zo%z0bzw^)aJ7?ycnS16wf81B7T!b9HmMR)LDF6V#0w`8}Nx-i|`xiY4{{n7^BKDWb zD3X6UgTjrTRmQ3b$3_MH4Rz9}UMT-b-_cUBu>QHEqZ@VvhlBAXq2Fsugi{iNg4TJ_ zhymEb05@LT&u=WTY{zU6Z~L{-i4}=DU~m<0B!b-eCP6BQ%}vk+pX$aY5{Dh)Hn;WXOVXdBg&-IkCKwa&Ulr(ls>21oM+E@Z zG0*@2AON1gKncfWB!(;EGeh;qSp+lHg;suEUKS&g9w}c15}haOa^HW-Nt+L4L3}dI z0WX4bckX^wA0SV62m|gH8Mny66kJ1BJf0+YtXZ~2AY{EQ_p ziJ-~G2wBi_C#N-OplgI*dYH3dJl$9p24`2d>?ls$m1q1)c&j&Z;_ zZA(`Y3QW#&sk`TZZ{Q{HzuINrPkH}g#!K=faF_0g{ulyXIygS2Jvuq0Rm{g@#oo`M}}C#yK+>oeCTj&N$mht~Z|8ThPs#B-*Y~_PZvn-38fnU2l2oGhM9t>ypIPOPSmoz- z!q(qqp_g0!CR$kuCCJ9ENpRtu@{wNKhJBJ^c760$^Xflac#&-fTVK{XuX+eO-Mjnf zspLYMi)BzMu6-8yQ-RUyr+a^%zBn{i2+5Z1F5cXQT9|>XU@>~qDK+tK4ea0aJ7oIu z^1Xw5u3y^K-8;9SVPhp<%VKG1!)#3aGGBu-FQOYqO5L1-E*1k+{V%wmJfU_)KEu04xuLrr(7e-2})9qWOdPZ*X^+oC(uORkf=Cyz^%yU4adc^Oiv zcLl;itsLFYi#5}|^_vS|j~S(rFA%c|`jrLH=M}TPq6XD>l~C*VzT8mkN5%?>IMZ>{ zzBsUhJB&NvOrx6CdSCo~znK$cjWx_n2&-6gbfIeQB>cR-jZ+jvvdn7luecFhYDN~$ zBehMG!t6|B;b6sBnZbrl4#6qlWcSe_8PK-wjEyeZ z>)Z_c*tx?@kMM-)Lb6-=w^}aRV~$EZ&d$!vm1Zhgn_6`I3Ysh$Q=;dYAvKwmaW>qM znQn8ig^8<>-e^BLdQ|Gc5~}3F@jUX1TEsx%jCOqc!7X2zJT#CL*_}7PCo^}eqor%7+!BxJNAXf9h0LT(` zvjR8qM>P^xr>dmyA`WxMuEkr98Jwd^j4xjo{MObJBZ5x8lLTwm+nqGI>sAy=y>pCb ztOn+7bS5NoWu~w7!ETm53zgWWURwpoowkbz{8R_3S~ zcyVab`U-#5XmtA)+mde}xojChZ-K;b*kQRW62PbgI_?udn|d@VbXTg?UjOqtL!(hE zv7y+npgU@=Cm6hwgc906To&WOH9D}L^>PGw;1SqE2s!5Q|EQc>sKGhD%WY)g-<#B! zUDb~(6@Ks3Cpqt>cktWQ$dAt{tJZL0AfhO$I&UxdU^$bX8YCHlL|ArwPIbdincnf3 z5qMk>uCJ*R0m^>a3|jS3d@FG$bzeapRn;oQyI9EUJY>J5dJUE7jvt{MPCX7mS*oEW zP$FC%P@|<@6XPagJ z*eQAuGnq>|Db%j0vJ&8;oyF&oCE9vZM64M_7c1@{S&cCMzuAkvfx>5I;xdX?)DCVz0JiO_EI;24rbv>7=vNMVfhLwq`1~?! 
zU9a#;#bVS8S+sLOw+*(3Sl=VGUR+gJL?PZX&P4n#H2YB5s1%i*R;&+0RJ3@KBY9cR zi=OY>L7?t}hy1)c&aH)W4EGo$YU?SsTMg;hOKSl!HICjcik#q6P7g;3i-;z#M+T>J zbsV}im#5hAGP|zJ2aJew>iuFtm{fkaJDzeNvw1(*_*7qS>ZIe$qo5&g?mJGb`Vy@6 zLCC!w2Pn(d9;h~s7YWD_K0~*`jOgUy5cHsFgu2r-TK_NyxU(0pXBO0)R+Wvbp4r*m zXUi~&$rXxGQcd$yVb8E8g_t-Ht|iolDur_rbJ^sfpVr=5a(@fP%w{Wog~(iz<)m6xCc zkt>DJ@@~up%dtIA>MvF6(v){L7t>TWP;qCJFA0@5SBx674~esM5=;XJ;Z2P98n}{6 zhj`uAd52{xooMUN+yXq65A~VH3R7UjOz6gNVR$eUm02+gKEe#J0am5wj8M1}(IgP1} z_>D2(O;%^*ek!TfsDIjyggnX%{vXG9xL@ue{3WNygikl-Pmj@%lV8Na;;Cbw z8QJu%D`@1>x{34|*iuq?UYW;cm2p0Jv4xq{6%gw1`CYw)nmbTMBdZ%q49nO@ve5zBp^w5&>g`T7HYXE$4O^ zHQH0rPi0_c`lO~#44ZC|yva&9EVw}id!}lZ3qzE{M=6y?{Jbo3&8N9+hV)s!#tV`w zRmky;WRerfjJ8{R1XS5lL5+SNsi|?41Rzi{%)b#B{@D}9kSIdqo1lcM_&6qt<@zV% zd0)*8X^nCYFs}Xj-MxpqzvH;Po}E3EJW8V!r(~lDu?=4~3LQ%ukR{d21>_9G zl20Zw2?bi~_`5%?2>m>~1a;~8zs%K8^RciRbX5bt=cjY@1JX2L@>xsya{bED&L^f} z)`h0}Me?>J|3K1zNa(*6G-ISe7lx2tb+0|zw8?|T9*8K9h0IRiPn2{h5I7wbD^&PB zK7zVM*>2o3*gqlTZpuHGN`+#Zu*S_Y-thO>ENhY)ut@)KK@+9B0w_sqE za~qjhOT!eR_*+>0TBwnP35(IK3@>_wpR3N2c?{WZ2ZksJ`$L`iy;nMQdlh&wyE&W8QPcj7=%9aSh#%r z-H+&kfCGJ_vbJ@OP@bgkr>;HI~@T`v&0 zqE3O|Vna$Zz$;^q(v-Yl2j8bjwP;iNhP>~$9}U}>s&UbsN)4d~8_Mu3E+^MXX*p(aKgYWG`yB9kYB5 z-R_kL!(;8GWoAfp6;+OjVk>LMI73|iax<<;UDd5TPzI({HFnH8(qp}Hy8W8Bms<~4 zrlBL-F}-8YfLh+X*rGy>4Qj;E6=?SL`vVW4p@6D#UOB6DlcaLhtvZqn$GjC4Yw({i zM(C%kOeAj^S0XO44fD}~zzvcD-M%#AT)PI^Rx61-Cppj6398S9(chDsnT=tzV@b)5X)pE<4%8*;fCO6qB-Yr&p>`kjN~rAJ9lGW> z8lf{9A|FyDj>fy2hUnYeaz zqQ8v;rx0^&`j*2uRN6+At6!0R>dIm>mS;*!DheKXE5d|Fq0K|pSAKXpur55pLY{IP zse)A!VtXM)HlmD-_zoB0plH+hS@*j4mJ50$GhK7|E;u=NGKHgTX;HGuoCd4rPXFt( ziQSG`eg(OH>k3w6CSO65R6mW{#8XI7@lQ8Z%p&I-iHwxRb&EBQg_ouv$fUKHG^%U# zHJ{FQ-A<`kiz3mQ_F6I$wk3dYUe-s}%?K0x_{RJy-9yKvY2z{wB*s_fHTBL)(p={w ze_Y?h>bQsUuYA-m(YQF}FTMd!E^v%P7oxebMxrjHdw$Dp$c2F;xS#Y(1!z2j40w0> zTzX0$nK{Im<(>P0Pf}W{D5#`p|LaZQHgw{D4M`!L)L(AXz4=QXdYQlcKrj9ub5c() z{hx46FZQ1x#31^A-sB^L9`?Ty<*&tn&oF2({JZ=ABj-O;=igNyR8RO9%Tv25JiM5`ZqKw;`-tIm)_xJ5)%GdGNkBFWI~xgkVm)(RO{F} z+S&$;zrLr4OQvWwj>~RH7e+5ee`_j26B%hytPT8-d zJw7@xV0kpVMUPTMIBT9=(WTuX%c?k?3TSX@)S#L=RewQfBEtBbcuKv_3lR^2gJo!I zE1csq+M85uC4|Q4BAJ)ht({nNNckVg==e&H+|xofGOgx4aKI3R4QZE)gwCXS)LZpQ zYqexWbK7ZhK$iIKCf-V_u#*zKOMF-QxVt59vXl8n9UF~ZV!piTgsEd?V9^FIRoX3F zsK<)x-FXr&D29VdWYa%08GeZLZ87+aX?a0)DrVBr+=KiFcReeavIV$$fvTXL+G&;J znu+%(U{qW*=NUWuc-FQOa;c-vuJK{ZB_OvH)3D1jx6v>E#^7jx!Q?aPtnABaR5uv| zLS|xeUNYm6k*_9yKxbl%Y~<9aO1McywC$6s?5kcV1Q^Dd!C- zD}-jfACn|99!sX|>P9B!63>qaJamy|IfkmeFQAe{$g_5C?Jv^iLA1ZF=%1EgBae5n zfxham7z+qZD`X8$-VCX-+7Cf7g(k9y1c?jDJg97YA8tw1#XbmpXINo%*=2{>gdMPz zgWR2=1+X{d(f@dbezZ^zva3~8G|PKPS{Yry0OVr;fa}kf@$rD7=E&s`!q1md59;X_8y!b^ zma3mEZtyq=Q#_P=k4rLc3nRb+LZRs_mVBM;3)@?%Ms@OUp_))kU(c0|=bV*yagzLA z>aU)|EF8ZM@wTa?to*~HYitkt`eQn+e~2XZdoNn#XRq6?4Sny?#-m%J1lut3H|4NH zuerzhML{#7CLV!^jJ;Qrronx@g%SiqMVnGq2+LE=D*x`rJzi_o3*Or*Gn10J-`QXqOsF~5H*}QJRPptY_NvCyGWI4lMS`)u4WX;Fg0Up_)VyOGQeKpLJR7O)@gb^K zlxEXrLuVTZh8vTUu=d0d9ZB0=HT$l!L%$_WOv&o|Bb_HFOE1gdquFjQh1*8dwUjx* z-V;KNL9@yQFG;HSi7kE*MkOTFM$Bo&dee;FmVL(z##%&&cvJYH=28yuy+p}!3@mR# zlEGdc-&JfK7K}@vVhrc@Pcy6oXQdCO{a@IC(npdI#nyCUGIE4VOi}&JVh(qbb(XQT z*Wrmiip_jQ8k#qB@M}wu#PQY``bQ!R&kJUDUfeGH9e;BZ&4j>S?Chgge|D}cd2ubl z!ma%Udsavum!Jln#bG$#Us^wLu1SFhDBX5AD{%LoxwBc$?Py7tophdOm0I1yUAIKW zyyHeVmk+vN+|y`&CRLbu_;`m}XS3W81iz|K;lyVQSAXT8p(7rm@b)Pp928Z&19#H> zNS1aXRiH%qaNoPnA$--!el5(pNV(2fRM@J`9hAQ6Ni*8M0BkOPURQ-G0e|M#DXtyd zFEs>7%V9r=$!pcg5Ly;jW(1?Spfq!2gFc89JyV(=3)tG_S)N_JtVNX(SK^9!f5*}> zC(xC!gS^C6OXUJb4|t8D%R@Ios(qcW^+r4aVrr$|Gf-yQD+%!s4pL}J^e^ABJV=pf za)3-YMQ)mvJ>2=?H00c1lfCFIP!p_js7n&;Z}FCi+|Pghq}Cay+C}t!XL>oAoEd_1 
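Why `init_environment` can hard-code `self.dev = "cuda:0"` deserves a note: `visible_gpu` sets `CUDA_VISIBLE_DEVICES` before CUDA is initialized, so the selected physical GPUs are renumbered 0..n-1 inside the process. A sketch using the same function body as `torchlight/gpu.py`:
```
import os

def visible_gpu(gpus):
    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpus))
    return list(range(len(gpus)))

gpus = visible_gpu([1, 2])                 # ask for physical GPUs 1 and 2
print(os.environ['CUDA_VISIBLE_DEVICES'])  # "1,2"
print(gpus)                                # [0, 1] -- the ids torch now sees
# hence "cuda:0" maps to physical GPU 1, and
# nn.DataParallel(model, device_ids=gpus) spans both visible devices
```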
From 56a2e1d9f1f014c032422446720c580a5cd36dc3 Mon Sep 17 00:00:00 2001
From: wsx
Date: Sat, 8 May 2021 19:33:03 +0800
Subject: [PATCH 3/4] the working version with the env yaml file

---
 README.md                                      |  13 +-
 data_gen/__pycache__/__init__.cpython-36.pyc   | Bin 0 -> 106 bytes
 .../__pycache__/preprocess.cpython-36.pyc      | Bin 0 -> 1811 bytes
 data_gen/__pycache__/rotation.cpython-36.pyc   | Bin 0 -> 1789 bytes
 data_gen/gpu.py                                |  35 ++
 data_gen/io.py                                 | 203 ++++++++++++
 data_gen/ntu_gen_preprocess.py                 |   3 +-
 environment.yml                                |  88 ++++++++
 log/data_tree.log                              |  16 ++
 log/train_aim.log                              |  21 ++
 .../__pycache__/__init__.cpython-36.pyc        | Bin 392 -> 376 bytes
 torchlight/__pycache__/io.cpython-36.pyc       | Bin 7050 -> 7034 bytes
 torchlight/gpu.py                              |   1 +
 13 files changed, 376 insertions(+), 4 deletions(-)
 create mode 100644 data_gen/__pycache__/__init__.cpython-36.pyc
 create mode 100644 data_gen/__pycache__/preprocess.cpython-36.pyc
 create mode 100644 data_gen/__pycache__/rotation.cpython-36.pyc
 create mode 100644 data_gen/gpu.py
 create mode 100644 data_gen/io.py
 create mode 100644 environment.yml
 create mode 100644 log/data_tree.log
 create mode 100644 log/train_aim.log

diff --git a/README.md b/README.md
index 601db34..411cfdb 100644
--- a/README.md
+++ b/README.md
@@ -16,10 +16,16 @@ In this repo, we show the example of model on NTU-RGB+D dataset.
 # Environments
 We use the similar input/output interface and system configuration like ST-GCN, where the torchlight module should be set up.
+```
+cd torchlight
+cp torchlight/torchlight/__init__.py gpu.py io.py ../
+```
+Change all occurrences of "from torchlight import ..." to
+"from torchlight.io import ...".
 
 Run
 ```
-cd torchlight, python setup.py, cd ..
+cd torchlight, python setup.py install, cd ..
 ```
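The import rewrite requested in the hunk above, shown as a before/after sketch; the names are exactly the ones this patch re-exports from `torchlight.io` and `torchlight.gpu`:
```
# before (upstream ST-GCN style):
#   from torchlight import str2bool, DictAction, import_class
# after (the form this patch standardizes on):
from torchlight.io import str2bool, DictAction, import_class
from torchlight.gpu import visible_gpu, occupy_gpu, ngpu
```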
@@ -30,13 +36,14 @@ For NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu
 ```
 Then, run the preprocessing program to generate the input data, which is very important.
 ```
-python ./data_gen/ntu_gen_preprocess.py
+cd data_gen
+python ntu_gen_preprocess.py
 ```
 
 # Training and Testing
 With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipeline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D,
 ```
-PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml
+PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2
 TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml
 Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml
 ```
diff --git a/data_gen/__pycache__/__init__.cpython-36.pyc b/data_gen/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..600c7f33ca097297a747c4e361f68e8da10a3d08
GIT binary patch
[base85 payloads for the committed data_gen/__pycache__/__init__ and preprocess bytecode omitted]
diff --git a/data_gen/__pycache__/rotation.cpython-36.pyc b/data_gen/__pycache__/rotation.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..722593f2ca1ecb0d44a88bf840ab7c3a85f35edd
GIT binary patch
[base85 payload omitted]
diff --git a/data_gen/gpu.py b/data_gen/gpu.py
new file mode 100644
index 0000000..306c391
--- /dev/null
+++ b/data_gen/gpu.py
@@ -0,0 +1,35 @@
+import os
+import torch
+
+
+def visible_gpu(gpus):
+    """
+    set visible gpu.
+
+    can be a single id, or a list
+
+    return a list of new gpus ids
+    """
+    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
+    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus)))
+    return list(range(len(gpus)))
+
+
+def ngpu(gpus):
+    """
+    count how many gpus used.
+    """
+    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
+    return len(gpus)
+
+
+def occupy_gpu(gpus=None):
+    """
+    make program appear on nvidia-smi.
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/data_gen/io.py b/data_gen/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/data_gen/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + 
os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/data_gen/ntu_gen_preprocess.py b/data_gen/ntu_gen_preprocess.py index 6323b30..9bc8423 100644 --- a/data_gen/ntu_gen_preprocess.py +++ b/data_gen/ntu_gen_preprocess.py @@ -140,4 +140,5 @@ def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set if not os.path.exists(out_path): os.makedirs(out_path) print(b, sn) - gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) \ No newline at end of file + #gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) + gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, set_name=sn) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..9d0ea0f --- /dev/null +++ b/environment.yml @@ -0,0 +1,88 @@ +name: asgcn +channels: + - pytorch + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ + - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ + - 
defaults +dependencies: + - _libgcc_mutex=0.1=main + - blas=1.0=mkl + - ca-certificates=2021.4.13=h06a4308_1 + - certifi=2020.12.5=py36h06a4308_0 + - cffi=1.14.5=py36h261ae71_0 + - cuda90=1.0=h6433d27_0 + - cudatoolkit=10.0.130=0 + - cudnn=7.6.5=cuda10.0_0 + - cycler=0.10.0=py36_0 + - dbus=1.13.18=hb2f20db_0 + - expat=2.3.0=h2531618_2 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.68.1=h36276a3_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - icu=58.2=he6710b0_3 + - intel-openmp=2019.4=243 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.1=py36h2531618_0 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.2.0=h3942068_0 + - libuuid=1.0.3=h1bed415_2 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.10=hb55368b_3 + - lz4-c=1.9.3=h2531618_0 + - matplotlib=3.3.2=h06a4308_0 + - matplotlib-base=3.3.2=py36h817c723_0 + - mkl=2018.0.3=1 + - mkl_fft=1.0.6=py36h7dd41cf_0 + - mkl_random=1.0.1=py36h4414c95_1 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py36hff7bd54_0 + - olefile=0.46=py36_0 + - openssl=1.1.1k=h27cfd23_0 + - pcre=8.44=he6710b0_0 + - pillow=8.1.2=py36he98fc37_0 + - pip=21.0.1=py36h06a4308_0 + - pycparser=2.20=py_2 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py36h05f1152_2 + - python=3.6.13=hdb3f193_0 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py36h06a4308_0 + - sip=4.19.8=py36hf484d3e_0 + - six=1.15.0=py36h06a4308_0 + - sqlite=3.35.1=hdfb4753_0 + - tbb=2021.2.0=hff7bd54_0 + - tbb4py=2021.2.0=py36hff7bd54_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py36h27cfd23_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - argparse==1.4.0 + - cached-property==1.5.2 + - dataclasses==0.8 + - h5py==3.1.0 + - imageio==2.9.0 + - numpy==1.19.5 + - opencv-python==4.5.1.48 + - pyyaml==5.4.1 + - scikit-video==1.1.11 + - scipy==1.5.4 + - torch==1.7.1 + - torchvision==0.9.0 + - tqdm==4.60.0 + - typing-extensions==3.7.4.3 diff --git a/log/data_tree.log b/log/data_tree.log new file mode 100644 index 0000000..3bfbe88 --- /dev/null +++ b/log/data_tree.log @@ -0,0 +1,16 @@ +data +|-- NTU-RGB+D +| `-- samples_with_missing_skeletons.txt +`-- nturgb_d + |-- xsub + | |-- train_data_joint_pad.npy + | |-- train_label.pkl + | |-- val_data_joint_pad.npy + | `-- val_label.pkl + `-- xview + |-- train_data_joint_pad.npy + |-- train_label.pkl + |-- val_data_joint_pad.npy + `-- val_label.pkl + +4 directories, 9 files diff --git a/log/train_aim.log b/log/train_aim.log new file mode 100644 index 0000000..1f84ef7 --- /dev/null +++ b/log/train_aim.log @@ -0,0 +1,21 @@ +$python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 + +/root/AS-GCN/processor/io.py:34: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details. + default_arg = yaml.load(f) +/root/AS-GCN/net/utils/adj_learn.py:18: UserWarning: This overload of nonzero is deprecated: + nonzero() +Consider using one of the following signatures instead: + nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) 
+ offdiag_indices = (ones - eye).nonzero().t() +/root/anaconda3/envs/stgcn/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported. + warnings.warn("Setting attributes on ParameterList is not supported.") +/root/AS-GCN/net/utils/adj_learn.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. + soft_max_1d = F.softmax(trans_input) + +[05.08.21|19:25:22] Parameters: +{'work_dir': './work_dir/recognition/ntu-xsub/AS_GCN', 'config': 'config/as_gcn/ntu-xsub/train_aim.yaml', 'phase': 'train', 'save_result': False, 'start_epoch': 0, 'num_epoch': 10, 'use_gpu': True, 'device': [0, 1, 2], 'log_interval': 100, 'save_interval': 1, 'eval_interval': 5, 'save_log': True, 'print_log': True, 'pavi_log': False, 'feeder': 'feeder.feeder.Feeder', 'num_worker': 4, 'train_feeder_args': {'data_path': './data/nturgb_d/xsub/train_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/train_label.pkl', 'random_move': True, 'repeat_pad': True, 'down_sample': True, 'debug': False}, 'test_feeder_args': {'data_path': './data/nturgb_d/xsub/val_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/val_label.pkl', 'random_move': False, 'repeat_pad': True, 'down_sample': True}, 'batch_size': 32, 'test_batch_size': 32, 'debug': False, 'model1': 'net.as_gcn.Model', 'model2': 'net.utils.adj_learn.AdjacencyLearn', 'model1_args': {'in_channels': 3, 'num_class': 60, 'dropout': 0.5, 'edge_importance_weighting': True, 'graph_args': {'layout': 'ntu-rgb+d', 'strategy': 'spatial', 'max_hop': 4}}, 'model2_args': {'n_in_enc': 150, 'n_hid_enc': 128, 'edge_types': 3, 'n_in_dec': 3, 'n_hid_dec': 128, 'node_num': 25}, 'weights1': None, 'weights2': None, 'ignore_weights': [], 'show_topk': [1, 5], 'base_lr1': 0.1, 'base_lr2': 0.0005, 'step': [50, 70, 90], 'optimizer': 'SGD', 'nesterov': True, 'weight_decay': 0.0001, 'max_hop_dir': 'max_hop_4', 'lamda_act': 0.5, 'lamda_act_dir': 'lamda_05'} + +[05.08.21|19:25:22] Training epoch: 0 +[05.08.21|19:25:29] Iter 0 Done. | loss2: 876.8732 | loss_nll: 832.9621 | loss_kl: 43.9111 | lr: 0.000500 +[05.08.21|19:26:56] Iter 100 Done. | loss2: 118.9051 | loss_nll: 110.3876 | loss_kl: 8.5176 | lr: 0.000500 +[05.08.21|19:28:14] Iter 200 Done. 
| loss2: 76.1775 | loss_nll: 71.7404 | loss_kl: 4.4371 | lr: 0.000500
diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc
index d75de414854592b86f7c3e4bf36812c1d7880aaf..9f75a6bb9658ea632019f7c35cb9c2b299d1129d 100644
GIT binary patch
[base85 delta omitted]
diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc
index 6643cdd031af6d93a7ca34b6d5c5652079d894fa..4b8a85cc437e91d0559b5d5b476091e9e1f1731d 100644
GIT binary patch
[base85 delta omitted]
Date: Mon, 21 Jun 2021 18:13:55 +0800
Subject: [PATCH 4/4] loss decrease

---
 LICENSE                                        |  46 +-
 README.md                                      | 144 ++--
 config/as_gcn/ntu-xsub/__init__.py             |   2 +-
 config/as_gcn/ntu-xsub/test.yaml               |  96 +--
 config/as_gcn/ntu-xsub/train.yaml              | 108 +--
 config/as_gcn/ntu-xsub/train_aim.yaml          | 102 +--
 config/as_gcn/ntu-xview/__init__.py            |   2 +-
 config/as_gcn/ntu-xview/test.yaml              |  96 +--
 config/as_gcn/ntu-xview/train.yaml             | 108 +--
 config/as_gcn/ntu-xview/train_aim.yaml         | 102 +--
 data_gen/__init__.py                           |   2 +-
 data_gen/gpu.py                                |  70 +-
 data_gen/io.py                                 | 406 +++++-----
 data_gen/ntu_gen_preprocess.py                 | 288 ++++----
 data_gen/preprocess.py                         | 134 ++--
 data_gen/rotation.py                           |  84 +--
 environment.yml                                | 176 ++---
 epoch_loss_class_eval.png                      | Bin 0 -> 23988 bytes
 epoch_loss_class_train.png                     | Bin 0 -> 23976 bytes
 epoch_loss_class_train.txt                     |  90 +++
 feeder/__init__.py                             |   2 +-
 feeder/feeder.py                               | 178 ++---
 feeder/tools.py                                | 486 ++++++------
 img/readme.md                                  |   2 +-
 .../epoch_loss_class_eval.png                  | Bin 0 -> 23988 bytes
 .../epoch_loss_class_train.png                 | Bin 0 -> 23976 bytes
 .../epoch_loss_class_train.txt                 |  90 +++
 log/data_tree.log                              |  32 +-
 log/train_aim.log                              |  42 +-
 main.py                                        |  48 +-
 net/__init__.py                                |   2 +-
 net/as_gcn.py                                  | 615 ++++++++--------
 net/model_poseformer.py                        | 223 ++++++
 net/utils/__init__.py                          |   2 +-
 net/utils/adj_learn.py                         | 567 +++++++-------
 net/utils/graph.py                             | 256 +++----
 net/utils/utils_adj.py                         |  94 +--
 pip_req.txt                                    |  15 +
 processor/__init__.py                          |   2 +-
 processor/gpu.py                               |  70 +-
 processor/io.py                                | 234 +++---
 processor/processor.py                         | 398 +++++-----
 processor/recognition.py                       | 691 ++++++++++--------
 requirement.txt                                | 327 +++++++++
 torchlight/__init__.py                         |  16 +-
 .../__pycache__/__init__.cpython-36.pyc        | Bin 376 -> 391 bytes
 torchlight/__pycache__/io.cpython-36.pyc       | Bin 7034 -> 7049 bytes
 torchlight/build/lib/torchlight/__init__.py    |   2 +-
 torchlight/build/lib/torchlight/gpu.py         |  70 +-
 torchlight/build/lib/torchlight/io.py          | 406 +++++-----
 torchlight/gpu.py                              |  72 +-
 torchlight/io.py                               | 406 +++++-----
 torchlight/setup.py                            |  16 +-
 torchlight/torchlight.egg-info/PKG-INFO        |  20 +-
 torchlight/torchlight.egg-info/SOURCES.txt     |  14 +-
 .../torchlight.egg-info/dependency_links.txt   |   2 +-
 torchlight/torchlight.egg-info/top_level.txt   |   2 +-
 torchlight/torchlight/__init__.py              |   2 +-
 torchlight/torchlight/gpu.py                   |  70 +-
 torchlight/torchlight/io.py                    | 406 +++++-----
 60 files changed, 4386 insertions(+), 3550 deletions(-)
 create mode 100644 epoch_loss_class_eval.png
 create mode 100644 epoch_loss_class_train.png
 create mode 100644 epoch_loss_class_train.txt
 create mode 100644 log/best_performance/epoch_loss_class_eval.png
 create mode 100644 log/best_performance/epoch_loss_class_train.png
 create mode 100644 log/best_performance/epoch_loss_class_train.txt
 create mode 100644 net/model_poseformer.py
 create mode 100644 pip_req.txt
 create mode 100644 requirement.txt

diff --git a/LICENSE b/LICENSE
index 
ec0a6fd..e272c92 100644 --- a/LICENSE +++ b/LICENSE @@ -1,23 +1,23 @@ -Copyright (c) 2019, Cooperative Medianet Innovation Center, Shanghai Jiao Tong University -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Copyright (c) 2019, Cooperative Medianet Innovation Center, Shanghai Jiao Tong University +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index 411cfdb..2f9fec4 100644 --- a/README.md +++ b/README.md @@ -1,71 +1,73 @@ -This repository contains the implementation of: -Actional-Structural Graph Convolutional Networks for Skeleton-based Action Recognition. [Paper](https://arxiv.org/pdf/1904.12659.pdf) - -![image](https://github.com/limaosen0/AS-GCN/blob/master/img/pipeline.png) - -Abstract: Action recognition with skeleton data has recently attracted much attention in computer vision. Previous studies are mostly based on fixed skeleton graphs, only capturing local physical dependencies among joints, which may miss implicit joint correlations. 
To capture richer dependencies, we introduce an encoder-decoder structure, called A-link inference module, to capture action-specific latent dependencies, i.e. actional links, directly from actions. We also extend the existing skeleton graphs to represent higherorder dependencies, i.e. structural links. Combing the two types of links into a generalized skeleton graph, we further propose the actional-structural graph convolution network (AS-GCN), which stacks actional-structural graph convolution and temporal convolution as a basic building block, to learn both spatial and temporal features for action recognition. A future pose prediction head is added in parallel to the recognition head to help capture more detailed action patterns through self-supervision. We validate AS-GCN in action recognition using two skeleton data sets, NTU-RGB+D and Kinetics. The proposed AS-GCN achieves consistently large improvement compared to the state-of-the-art methods. As a side product, AS-GCN also shows promising results for future pose prediction. - -In this repo, we show the example of model on NTU-RGB+D dataset. - -# Experiment Requirement -* Python 3.6 -* Pytorch 0.4.1 -* pyyaml -* argparse -* numpy - -# Environments -We use the similar input/output interface and system configuration like ST-GCN, where the torchlight module should be set up. -``` -cd torchlight -cp torchlight/torchlight/_init__.py gpu.py io.py ../ -``` -change all "from torchlight import ..." to -"from torchlight.io import ..." - -Run -``` -cd torchlight, python setup.py install, cd .. -``` - - -# Data Preparing -For NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp). And put the dataset in the file path: -``` -'./data/NTU-RGB+D/nturgb+d_skeletons/' -``` -Then, run the preprocessing program to generate the input data, which is very important. -``` -cd data_gen -python ntu_gen_preprocess.py -``` - -# Training and Testing -With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipleline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D, -``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml -Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml -``` - -For Cross-View, -``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xview/train_aim.yaml -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xview/train.yaml -Test: python main.py recognition -c config/as_gcn/ntu-xview/test.yaml -``` - -# Acknowledgement -Thanks for the framework provided by 'yysijie/st-gcn', which is source code of the published work [ST-GCN](https://aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17135) in AAAI-2018. The github repo is [ST-GCN code](https://github.com/yysijie/st-gcn). We borrow the framework and interface from the code. 
-
-# Citation
-If you use this code, please cite our paper:
-```
-@InProceedings{Li_2019_CVPR,
-author = {Li, Maosen and Chen, Siheng and Chen, Xu and Zhang, Ya and Wang, Yanfeng and Tian, Qi},
-title = {Actional-Structural Graph Convolutional Networks for Skeleton-Based Action Recognition},
-booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
-month = {June},
-year = {2019}
-}
-```
+This repository contains the implementation of:
+Actional-Structural Graph Convolutional Networks for Skeleton-based Action Recognition. [Paper](https://arxiv.org/pdf/1904.12659.pdf)
+
+![image](https://github.com/limaosen0/AS-GCN/blob/master/img/pipeline.png)
+
+Abstract: Action recognition with skeleton data has recently attracted much attention in computer vision. Previous studies are mostly based on fixed skeleton graphs, only capturing local physical dependencies among joints, which may miss implicit joint correlations. To capture richer dependencies, we introduce an encoder-decoder structure, called A-link inference module, to capture action-specific latent dependencies, i.e. actional links, directly from actions. We also extend the existing skeleton graphs to represent higher-order dependencies, i.e. structural links. Combining the two types of links into a generalized skeleton graph, we further propose the actional-structural graph convolution network (AS-GCN), which stacks actional-structural graph convolution and temporal convolution as a basic building block, to learn both spatial and temporal features for action recognition. A future pose prediction head is added in parallel to the recognition head to help capture more detailed action patterns through self-supervision. We validate AS-GCN in action recognition using two skeleton data sets, NTU-RGB+D and Kinetics. The proposed AS-GCN achieves consistently large improvement compared to the state-of-the-art methods. As a side product, AS-GCN also shows promising results for future pose prediction.
+
+In this repo, we show an example of the model on the NTU-RGB+D dataset.
+
+# Experiment Requirement
+* Python 3.6
+* Pytorch 0.4.1
+* pyyaml
+* argparse
+* numpy
+* torch 1.7.1
+
+# Environments
+We use a similar input/output interface and system configuration to ST-GCN, where the torchlight module should be set up.
+```
+cd torchlight
+cp torchlight/torchlight/__init__.py gpu.py io.py ../
+```
+Change all occurrences of "from torchlight import ..." to
+"from torchlight.io import ...".
+
+Run
+```
+cd torchlight, python setup.py install, cd ..
+```
+
+
+# Data Preparing
+For the NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp). Put the dataset in the file path:
+```
+'./data/NTU-RGB+D/nturgb+d_skeletons/'
+```
+Then, run the preprocessing program to generate the input data, which is very important.
+```
+cd data_gen
+python ntu_gen_preprocess.py
+```
+
+# Training and Testing
+With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipeline of AS-GCN.
For the recommended benchmark of Cross-Subject in NTU-RGB+D,
+```
+PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2
+TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml --device 0 --batch_size 4
+# only one GPU can be used here, otherwise training fails with "Caught RuntimeError in replica 0 on device 0"
+Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml
+```
+
+For Cross-View,
+```
+PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xview/train_aim.yaml
+TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xview/train.yaml
+Test: python main.py recognition -c config/as_gcn/ntu-xview/test.yaml
+```
+
+# Acknowledgement
+Thanks for the framework provided by 'yysijie/st-gcn', which is the source code of the published work [ST-GCN](https://aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17135) in AAAI-2018. The github repo is [ST-GCN code](https://github.com/yysijie/st-gcn). We borrow the framework and interface from the code.
+
+# Citation
+If you use this code, please cite our paper:
+```
+@InProceedings{Li_2019_CVPR,
+author = {Li, Maosen and Chen, Siheng and Chen, Xu and Zhang, Ya and Wang, Yanfeng and Tian, Qi},
+title = {Actional-Structural Graph Convolutional Networks for Skeleton-Based Action Recognition},
+booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+month = {June},
+year = {2019}
+}
+```
diff --git a/config/as_gcn/ntu-xsub/__init__.py b/config/as_gcn/ntu-xsub/__init__.py
index 8b13789..d3f5a12 100644
--- a/config/as_gcn/ntu-xsub/__init__.py
+++ b/config/as_gcn/ntu-xsub/__init__.py
@@ -1 +1 @@
-
+ 
diff --git a/config/as_gcn/ntu-xsub/test.yaml b/config/as_gcn/ntu-xsub/test.yaml
index 30fbde9..f81fd70 100644
--- a/config/as_gcn/ntu-xsub/test.yaml
+++ b/config/as_gcn/ntu-xsub/test.yaml
@@ -1,48 +1,48 @@
-work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN
-weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt
-weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt
-
-feeder: feeder.feeder.Feeder
-train_feeder_args:
-  data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy
-  label_path: ./data/nturgb_d/xsub/train_label.pkl
-  random_move: True
-  repeat_pad: True
-  down_sample: True
-test_feeder_args:
-  data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy
-  label_path: ./data/nturgb_d/xsub/val_label.pkl
-  random_move: False
-  repeat_pad: True
-  down_sample: True
-
-model1: net.as_gcn.Model
-model1_args:
-  in_channels: 3
-  num_class: 60
-  dropout: 0.5
-  edge_importance_weighting: True
-  graph_args:
-    layout: 'ntu-rgb+d'
-    strategy: 'spatial'
-    max_hop: 4
-
-model2: net.utils.adj_learn.AdjacencyLearn
-model2_args:
-  n_in_enc: 150
-  n_hid_enc: 128
-  edge_types: 3
-  n_in_dec: 3
-  n_hid_dec: 128
-  node_num: 25
-
-device: [0,1,2,3]
-batch_size: 32
-test_batch_size: 32
-num_worker: 4
-
-max_hop_dir: max_hop_4
-lamda_act_dir: lamda_05
-lamda_act: 0.5
-
-phase: test
+work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN
+weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt
+weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt
+
+feeder: feeder.feeder.Feeder
+train_feeder_args:
+  data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy
+  label_path: ./data/nturgb_d/xsub/train_label.pkl
+  random_move: True
+  repeat_pad: True
+  down_sample: True
+test_feeder_args:
+  data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy
+  label_path: 
./data/nturgb_d/xsub/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 + +phase: test diff --git a/config/as_gcn/ntu-xsub/train.yaml b/config/as_gcn/ntu-xsub/train.yaml index 2d43f64..939f3a3 100644 --- a/config/as_gcn/ntu-xsub/train.yaml +++ b/config/as_gcn/ntu-xsub/train.yaml @@ -1,54 +1,54 @@ -work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN - -weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt -weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [50, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 10 -num_epoch: 100 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN + +weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt +weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.0076 +base_lr2: 0.0005 +step: [50, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 10 +num_epoch: 100 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/config/as_gcn/ntu-xsub/train_aim.yaml b/config/as_gcn/ntu-xsub/train_aim.yaml index c4c54b7..d74e1cf 100644 --- a/config/as_gcn/ntu-xsub/train_aim.yaml +++ b/config/as_gcn/ntu-xsub/train_aim.yaml @@ -1,51 +1,51 @@ -work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN - 
-feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [50, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 0 -num_epoch: 10 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.1 +base_lr2: 0.0005 +step: [50, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 0 +num_epoch: 10 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/config/as_gcn/ntu-xview/__init__.py b/config/as_gcn/ntu-xview/__init__.py index 8b13789..d3f5a12 100644 --- a/config/as_gcn/ntu-xview/__init__.py +++ b/config/as_gcn/ntu-xview/__init__.py @@ -1 +1 @@ - + diff --git a/config/as_gcn/ntu-xview/test.yaml b/config/as_gcn/ntu-xview/test.yaml index 50a1400..1496724 100644 --- a/config/as_gcn/ntu-xview/test.yaml +++ b/config/as_gcn/ntu-xview/test.yaml @@ -1,48 +1,48 @@ -work_dir: ./work_dir/recognition/ntu-xview/AS_GCN -weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt -weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -num_worker: 4 - 
-max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 - -phase: test +work_dir: ./work_dir/recognition/ntu-xview/AS_GCN +weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt +weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 + +phase: test diff --git a/config/as_gcn/ntu-xview/train.yaml b/config/as_gcn/ntu-xview/train.yaml index da040ed..85344f7 100644 --- a/config/as_gcn/ntu-xview/train.yaml +++ b/config/as_gcn/ntu-xview/train.yaml @@ -1,54 +1,54 @@ -work_dir: ./work_dir/recognition/ntu-xview/AS_GCN - -weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt -weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [50, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 10 -num_epoch: 100 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xview/AS_GCN + +weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt +weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn 
+model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.1 +base_lr2: 0.0005 +step: [50, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 10 +num_epoch: 100 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/config/as_gcn/ntu-xview/train_aim.yaml b/config/as_gcn/ntu-xview/train_aim.yaml index ac9e2ff..ec8aaf7 100644 --- a/config/as_gcn/ntu-xview/train_aim.yaml +++ b/config/as_gcn/ntu-xview/train_aim.yaml @@ -1,51 +1,51 @@ -work_dir: ./work_dir/recognition/ntu-xview/AS_GCN - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [40, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 0 -num_epoch: 10 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xview/AS_GCN + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.1 +base_lr2: 0.0005 +step: [40, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 0 +num_epoch: 10 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/data_gen/__init__.py b/data_gen/__init__.py index 8b13789..d3f5a12 100644 --- a/data_gen/__init__.py +++ b/data_gen/__init__.py @@ -1 +1 @@ - + diff --git a/data_gen/gpu.py b/data_gen/gpu.py index 306c391..e086d4c 100644 --- a/data_gen/gpu.py +++ b/data_gen/gpu.py @@ -1,35 +1,35 @@ -import os -import torch - - -def visible_gpu(gpus): - """ - set visible gpu. - - can be a single id, or a list - - return a list of new gpus ids - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) - return list(range(len(gpus))) - - -def ngpu(gpus): - """ - count how many gpus used. - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - return len(gpus) - - -def occupy_gpu(gpus=None): - """ - make program appear on nvidia-smi. 
- """ - if gpus is None: - torch.zeros(1).cuda() - else: - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - for g in gpus: - torch.zeros(1).cuda(g) +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/data_gen/io.py b/data_gen/io.py index c753ca1..5b43720 100644 --- a/data_gen/io.py +++ b/data_gen/io.py @@ -1,203 +1,203 @@ -#!/usr/bin/env python -import argparse -import os -import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict -import yaml -import numpy as np -# torch -import torch -import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py - -class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - 
- def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - self.session_file = '{}/config.yaml'.format(self.work_dir) - - # save arg - arg_dict = vars(arg) - if not os.path.exists(self.work_dir): - os.makedirs(self.work_dir) - with open(self.session_file, 'w') as f: - f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) - yaml.dump(arg_dict, f, default_flow_style=False, indent=4) - - def print_log(self, str, print_time=True): - if print_time: - # localtime = time.asctime(time.localtime(time.time())) - str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str - - if self.print_to_screen: - print(str) - if self.save_log: - with open('{}/log.txt'.format(self.work_dir), 'a') as f: - print(str, file=f) - - def init_timer(self, *name): - self.record_time() - self.split_timer = {k: 0.0000001 for k in name} - - def check_time(self, name): - self.split_timer[name] += self.split_time() - - def record_time(self): - self.cur_time = time.time() - return self.cur_time - - def split_time(self): - split_time = time.time() - self.cur_time - self.record_time() - return split_time - - def print_timer(self): - proportion = { - k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) - for k, v in self.split_timer.items() - } - self.print_log('Time consumption:') - for k in proportion: - self.print_log( - '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) - ) - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def str2dict(v): - return eval('dict({})'.format(v)) #pylint: disable=W0123 - - -def _import_class_0(name): - components = name.split('.') - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def import_class(import_str): - mod_str, _sep, class_str = import_str.rpartition('.') - __import__(mod_str) - try: - return getattr(sys.modules[mod_str], class_str) - except AttributeError: - raise ImportError('Class %s cannot be found (%s)' % - (class_str, - traceback.format_exception(*sys.exc_info()))) - - -class DictAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(DictAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 - output_dict = getattr(namespace, self.dest) - for k in input_dict: - output_dict[k] = input_dict[k] - setattr(namespace, self.dest, output_dict) +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + 
warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + 
self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/data_gen/ntu_gen_preprocess.py b/data_gen/ntu_gen_preprocess.py index 9bc8423..99b27a4 100644 --- a/data_gen/ntu_gen_preprocess.py +++ b/data_gen/ntu_gen_preprocess.py @@ -1,144 +1,144 @@ -import argparse -import pickle -from tqdm import tqdm -import sys - -sys.path.extend(['../']) -from data_gen.preprocess import pre_normalization - -training_subjects = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38] -training_cameras = [2, 3] -max_body_true = 2 -max_body_kinect = 4 -num_joint = 25 -max_frame = 300 - -import numpy as np -import os - - -def read_skeleton_filter(file): - with open(file, 'r') as f: - skeleton_sequence = {} - skeleton_sequence['numFrame'] = int(f.readline()) - skeleton_sequence['frameInfo'] = [] - for t in range(skeleton_sequence['numFrame']): - frame_info = {} - frame_info['numBody'] = int(f.readline()) - frame_info['bodyInfo'] = [] - - for m in range(frame_info['numBody']): - body_info = {} - body_info_key = ['bodyID', 'clipedEdges', 'handLeftConfidence', - 'handLeftState', 'handRightConfidence', 'handRightState', - 'isResticted', 'leanX', 'leanY', 'trackingState'] - body_info = {k: float(v) for k, v in zip(body_info_key, f.readline().split())} - body_info['numJoint'] = int(f.readline()) - body_info['jointInfo'] = [] - for v in range(body_info['numJoint']): - joint_info_key = ['x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY', - 'orientationW', 'orientationX', 'orientationY', - 'orientationZ', 'trackingState'] - joint_info = {k: float(v) for k, v in zip(joint_info_key, f.readline().split())} - body_info['jointInfo'].append(joint_info) - frame_info['bodyInfo'].append(body_info) - 
skeleton_sequence['frameInfo'].append(frame_info) - - return skeleton_sequence - - -def get_nonzero_std(s): - index = s.sum(-1).sum(-1) != 0 - s = s[index] - if len(s) != 0: - s = s[:, :, 0].std() + s[:, :, 1].std() + s[:, :, 2].std() - else: - s = 0 - return s - - -def read_xyz(file, max_body=4, num_joint=25): - seq_info = read_skeleton_filter(file) - data = np.zeros((max_body, seq_info['numFrame'], num_joint, 3)) - for n, f in enumerate(seq_info['frameInfo']): - for m, b in enumerate(f['bodyInfo']): - for j, v in enumerate(b['jointInfo']): - if m < max_body and j < num_joint: - data[m, n, j, :] = [v['x'], v['y'], v['z']] - else: - pass - - energy = np.array([get_nonzero_std(x) for x in data]) - index = energy.argsort()[::-1][0:max_body_true] - data = data[index] - - data = data.transpose(3, 1, 2, 0) - return data - - -def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set_name='val'): - if ignored_sample_path != None: - with open(ignored_sample_path, 'r') as f: - ignored_samples = [line.strip() + '.skeleton' for line in f.readlines()] - else: - ignored_samples = [] - sample_name = [] - sample_label = [] - for filename in os.listdir(data_path): - if filename in ignored_samples: - continue - action_class = int(filename[filename.find('A') + 1:filename.find('A') + 4]) - subject_id = int(filename[filename.find('P') + 1:filename.find('P') + 4]) - camera_id = int(filename[filename.find('C') + 1:filename.find('C') + 4]) - - if benchmark == 'xview': - istraining = (camera_id in training_cameras) - elif benchmark == 'xsub': - istraining = (subject_id in training_subjects) - else: - raise ValueError() - - if set_name == 'train': - issample = istraining - elif set_name == 'val': - issample = not (istraining) - else: - raise ValueError() - - if issample: - sample_name.append(filename) - sample_label.append(action_class - 1) - print(len(sample_label)) - - with open('{}/{}_label.pkl'.format(out_path, set_name), 'wb') as f: - pickle.dump((sample_name, list(sample_label)), f) - - fp = np.zeros((len(sample_label), 3, max_frame, num_joint, max_body_true), dtype=np.float32) - - for i, s in enumerate(tqdm(sample_name)): - print(s) - data = read_xyz(os.path.join(data_path, s), max_body=max_body_kinect, num_joint=num_joint) - fp[i, :, 0:data.shape[1], :, :] = data - - fp = pre_normalization(fp) - np.save('{}/{}_data_joint_pad.npy'.format(out_path, set_name), fp) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') - parser.add_argument('--data_path', default='../data/NTU-RGB+D/nturgb+d_skeletons/') - parser.add_argument('--ignored_sample_path', default='../data/NTU-RGB+D/samples_with_missing_skeletons.txt') - parser.add_argument('--out_folder', default='../data/nturgb_d/') - - benchmark = ['xsub', 'xview'] - set_name = ['train', 'val'] - arg = parser.parse_args() - - for b in benchmark: - for sn in set_name: - out_path = os.path.join(arg.out_folder, b) - if not os.path.exists(out_path): - os.makedirs(out_path) - print(b, sn) - #gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) - gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, set_name=sn) +import argparse +import pickle +from tqdm import tqdm +import sys + +sys.path.extend(['../']) +from data_gen.preprocess import pre_normalization + +training_subjects = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38] +training_cameras = [2, 3] +max_body_true = 2 +max_body_kinect = 4 +num_joint = 25 +max_frame = 300 + 
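+# (Reviewer note, not in the original file.) NTU RGB+D sample names follow
+# the fixed-width pattern SsssCcccPpppRrrrAaaa (setup, camera, performer,
+# replication, action), e.g. S001C002P003R002A013. gendata() below slices
+# the camera, subject and action ids out of those positions to build the
+# cross-subject (xsub) and cross-view (xview) splits. The Kinect tracker
+# reports up to max_body_kinect (4) bodies per frame; read_xyz() keeps only
+# the max_body_true (2) most active ones, ranked by the joint-coordinate
+# standard deviation computed in get_nonzero_std().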
+import numpy as np +import os + + +def read_skeleton_filter(file): + with open(file, 'r') as f: + skeleton_sequence = {} + skeleton_sequence['numFrame'] = int(f.readline()) + skeleton_sequence['frameInfo'] = [] + for t in range(skeleton_sequence['numFrame']): + frame_info = {} + frame_info['numBody'] = int(f.readline()) + frame_info['bodyInfo'] = [] + + for m in range(frame_info['numBody']): + body_info = {} + body_info_key = ['bodyID', 'clipedEdges', 'handLeftConfidence', + 'handLeftState', 'handRightConfidence', 'handRightState', + 'isResticted', 'leanX', 'leanY', 'trackingState'] + body_info = {k: float(v) for k, v in zip(body_info_key, f.readline().split())} + body_info['numJoint'] = int(f.readline()) + body_info['jointInfo'] = [] + for v in range(body_info['numJoint']): + joint_info_key = ['x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY', + 'orientationW', 'orientationX', 'orientationY', + 'orientationZ', 'trackingState'] + joint_info = {k: float(v) for k, v in zip(joint_info_key, f.readline().split())} + body_info['jointInfo'].append(joint_info) + frame_info['bodyInfo'].append(body_info) + skeleton_sequence['frameInfo'].append(frame_info) + + return skeleton_sequence + + +def get_nonzero_std(s): + index = s.sum(-1).sum(-1) != 0 + s = s[index] + if len(s) != 0: + s = s[:, :, 0].std() + s[:, :, 1].std() + s[:, :, 2].std() + else: + s = 0 + return s + + +def read_xyz(file, max_body=4, num_joint=25): + seq_info = read_skeleton_filter(file) + data = np.zeros((max_body, seq_info['numFrame'], num_joint, 3)) + for n, f in enumerate(seq_info['frameInfo']): + for m, b in enumerate(f['bodyInfo']): + for j, v in enumerate(b['jointInfo']): + if m < max_body and j < num_joint: + data[m, n, j, :] = [v['x'], v['y'], v['z']] + else: + pass + + energy = np.array([get_nonzero_std(x) for x in data]) + index = energy.argsort()[::-1][0:max_body_true] + data = data[index] + + data = data.transpose(3, 1, 2, 0) + return data + + +def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set_name='val'): + if ignored_sample_path != None: + with open(ignored_sample_path, 'r') as f: + ignored_samples = [line.strip() + '.skeleton' for line in f.readlines()] + else: + ignored_samples = [] + sample_name = [] + sample_label = [] + for filename in os.listdir(data_path): + if filename in ignored_samples: + continue + action_class = int(filename[filename.find('A') + 1:filename.find('A') + 4]) + subject_id = int(filename[filename.find('P') + 1:filename.find('P') + 4]) + camera_id = int(filename[filename.find('C') + 1:filename.find('C') + 4]) + + if benchmark == 'xview': + istraining = (camera_id in training_cameras) + elif benchmark == 'xsub': + istraining = (subject_id in training_subjects) + else: + raise ValueError() + + if set_name == 'train': + issample = istraining + elif set_name == 'val': + issample = not (istraining) + else: + raise ValueError() + + if issample: + sample_name.append(filename) + sample_label.append(action_class - 1) + print(len(sample_label)) + + with open('{}/{}_label.pkl'.format(out_path, set_name), 'wb') as f: + pickle.dump((sample_name, list(sample_label)), f) + + fp = np.zeros((len(sample_label), 3, max_frame, num_joint, max_body_true), dtype=np.float32) + + for i, s in enumerate(tqdm(sample_name)): + print(s) + data = read_xyz(os.path.join(data_path, s), max_body=max_body_kinect, num_joint=num_joint) + fp[i, :, 0:data.shape[1], :, :] = data + + fp = pre_normalization(fp) + np.save('{}/{}_data_joint_pad.npy'.format(out_path, set_name), fp) + + +if __name__ == 
'__main__': + parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') + parser.add_argument('--data_path', default='../data/NTU-RGB+D/nturgb+d_skeletons/') + parser.add_argument('--ignored_sample_path', default='../data/NTU-RGB+D/samples_with_missing_skeletons.txt') + parser.add_argument('--out_folder', default='../data/nturgb_d/') + + benchmark = ['xsub', 'xview'] + set_name = ['train', 'val'] + arg = parser.parse_args() + + for b in benchmark: + for sn in set_name: + out_path = os.path.join(arg.out_folder, b) + if not os.path.exists(out_path): + os.makedirs(out_path) + print(b, sn) + #gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) + gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, set_name=sn) diff --git a/data_gen/preprocess.py b/data_gen/preprocess.py index 86810aa..4e6ed9f 100644 --- a/data_gen/preprocess.py +++ b/data_gen/preprocess.py @@ -1,68 +1,68 @@ -import sys - -sys.path.extend(['../']) -from data_gen.rotation import * -from tqdm import tqdm - - -def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]): - N, C, T, V, M = data.shape - s = np.transpose(data, [0, 4, 2, 3, 1]) - - print('sub the center joint') - for i_s, skeleton in enumerate(tqdm(s)): - if skeleton.sum() == 0: - continue - main_body_center = skeleton[0][:, 1:2, :].copy() - for i_p, person in enumerate(skeleton): - if person.sum() == 0: - continue - mask = (person.sum(-1) != 0).reshape(T, V, 1) - s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask - - - print('parallel the torso bone') - for i_s, skeleton in enumerate(tqdm(s)): - if skeleton.sum() == 0: - continue - joint_bottom = skeleton[0, 0, zaxis[0]] - joint_top = skeleton[0, 0, zaxis[1]] - axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) - angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) - matrix_z = rotation_matrix(axis, angle) - for i_p, person in enumerate(skeleton): - if person.sum() == 0: - continue - for i_f, frame in enumerate(person): - if frame.sum() == 0: - continue - for i_j, joint in enumerate(frame): - s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint) - - - print('parallel the shoulder bone') - for i_s, skeleton in enumerate(tqdm(s)): - if skeleton.sum() == 0: - continue - joint_rshoulder = skeleton[0, 0, xaxis[0]] - joint_lshoulder = skeleton[0, 0, xaxis[1]] - axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) - angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0]) - matrix_x = rotation_matrix(axis, angle) - for i_p, person in enumerate(skeleton): - if person.sum() == 0: - continue - for i_f, frame in enumerate(person): - if frame.sum() == 0: - continue - for i_j, joint in enumerate(frame): - s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint) - - data = np.transpose(s, [0, 4, 2, 3, 1]) - return data - - -if __name__ == '__main__': - data = np.load('../data/NTU-RGB+D/xsub/train_data.npy') - pre_normalization(data) +import sys + +sys.path.extend(['../']) +from data_gen.rotation import * +from tqdm import tqdm + + +def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]): + N, C, T, V, M = data.shape + s = np.transpose(data, [0, 4, 2, 3, 1]) + + print('sub the center joint') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + main_body_center = skeleton[0][:, 1:2, :].copy() + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + mask = (person.sum(-1) != 0).reshape(T, V, 1) + s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask + + + print('parallel the torso bone') + for i_s, skeleton in 
enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + joint_bottom = skeleton[0, 0, zaxis[0]] + joint_top = skeleton[0, 0, zaxis[1]] + axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) + angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) + matrix_z = rotation_matrix(axis, angle) + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + for i_f, frame in enumerate(person): + if frame.sum() == 0: + continue + for i_j, joint in enumerate(frame): + s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint) + + + print('parallel the shoulder bone') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + joint_rshoulder = skeleton[0, 0, xaxis[0]] + joint_lshoulder = skeleton[0, 0, xaxis[1]] + axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) + angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0]) + matrix_x = rotation_matrix(axis, angle) + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + for i_f, frame in enumerate(person): + if frame.sum() == 0: + continue + for i_j, joint in enumerate(frame): + s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint) + + data = np.transpose(s, [0, 4, 2, 3, 1]) + return data + + +if __name__ == '__main__': + data = np.load('../data/NTU-RGB+D/xsub/train_data.npy') + pre_normalization(data) np.save('../data/nturgb_d/xsub/data_train_pre.npy', data) \ No newline at end of file diff --git a/data_gen/rotation.py b/data_gen/rotation.py index f82e6b8..9da8497 100644 --- a/data_gen/rotation.py +++ b/data_gen/rotation.py @@ -1,43 +1,43 @@ -import numpy as np -import math - - -def rotation_matrix(axis, theta): - if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: - return np.eye(3) - axis = np.asarray(axis) - axis = axis / math.sqrt(np.dot(axis, axis)) - a = math.cos(theta / 2.0) - b, c, d = -axis * math.sin(theta / 2.0) - aa, bb, cc, dd = a * a, b * b, c * c, d * d - bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) - - -def unit_vector(vector): - return vector / np.linalg.norm(vector) - - -def angle_between(v1, v2): - if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: - return 0 - v1_u = unit_vector(v1) - v2_u = unit_vector(v2) - return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) - - -def x_rotation(vector, theta): - R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) - return np.dot(R, vector) - - -def y_rotation(vector, theta): - R = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) - return np.dot(R, vector) - - -def z_rotation(vector, theta): - R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) +import numpy as np +import math + + +def rotation_matrix(axis, theta): + if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: + return np.eye(3) + axis = np.asarray(axis) + axis = axis / math.sqrt(np.dot(axis, axis)) + a = math.cos(theta / 2.0) + b, c, d = -axis * math.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + + +def unit_vector(vector): + return vector / np.linalg.norm(vector) + + 
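+# (Reviewer note, not in the original file.) rotation_matrix() above is the
+# quaternion form of Rodrigues' rotation formula: it returns the 3x3 matrix
+# that rotates vectors by `theta` radians about `axis`, and degenerates to
+# the identity for a near-zero axis or angle. A quick sanity check:
+#
+#     >>> R = rotation_matrix([0, 0, 1], np.pi / 2)   # +90 degrees about z
+#     >>> np.allclose(np.dot(R, [1, 0, 0]), [0, 1, 0])
+#     True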
+def angle_between(v1, v2): + if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: + return 0 + v1_u = unit_vector(v1) + v2_u = unit_vector(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + +def x_rotation(vector, theta): + R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) + return np.dot(R, vector) + + +def y_rotation(vector, theta): + R = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) + return np.dot(R, vector) + + +def z_rotation(vector, theta): + R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) return np.dot(R, vector) \ No newline at end of file diff --git a/environment.yml b/environment.yml index 9d0ea0f..7e6873e 100644 --- a/environment.yml +++ b/environment.yml @@ -1,88 +1,88 @@ -name: asgcn -channels: - - pytorch - - https://mirrors.ustc.edu.cn/anaconda/pkgs/main - - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ - - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ - - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - blas=1.0=mkl - - ca-certificates=2021.4.13=h06a4308_1 - - certifi=2020.12.5=py36h06a4308_0 - - cffi=1.14.5=py36h261ae71_0 - - cuda90=1.0=h6433d27_0 - - cudatoolkit=10.0.130=0 - - cudnn=7.6.5=cuda10.0_0 - - cycler=0.10.0=py36_0 - - dbus=1.13.18=hb2f20db_0 - - expat=2.3.0=h2531618_2 - - fontconfig=2.13.1=h6c09931_0 - - freetype=2.10.4=h5ab3b9f_0 - - glib=2.68.1=h36276a3_0 - - gst-plugins-base=1.14.0=h8213a91_2 - - gstreamer=1.14.0=h28cd5cc_2 - - icu=58.2=he6710b0_3 - - intel-openmp=2019.4=243 - - jpeg=9b=h024ee3a_2 - - kiwisolver=1.3.1=py36h2531618_0 - - lcms2=2.11=h396b838_0 - - ld_impl_linux-64=2.33.1=h53a641e_7 - - libffi=3.3=he6710b0_2 - - libgcc-ng=9.1.0=hdf63c60_0 - - libgfortran-ng=7.3.0=hdf63c60_0 - - libpng=1.6.37=hbc83047_0 - - libstdcxx-ng=9.1.0=hdf63c60_0 - - libtiff=4.2.0=h3942068_0 - - libuuid=1.0.3=h1bed415_2 - - libwebp-base=1.2.0=h27cfd23_0 - - libxcb=1.14=h7b6447c_0 - - libxml2=2.9.10=hb55368b_3 - - lz4-c=1.9.3=h2531618_0 - - matplotlib=3.3.2=h06a4308_0 - - matplotlib-base=3.3.2=py36h817c723_0 - - mkl=2018.0.3=1 - - mkl_fft=1.0.6=py36h7dd41cf_0 - - mkl_random=1.0.1=py36h4414c95_1 - - ncurses=6.2=he6710b0_1 - - ninja=1.10.2=py36hff7bd54_0 - - olefile=0.46=py36_0 - - openssl=1.1.1k=h27cfd23_0 - - pcre=8.44=he6710b0_0 - - pillow=8.1.2=py36he98fc37_0 - - pip=21.0.1=py36h06a4308_0 - - pycparser=2.20=py_2 - - pyparsing=2.4.7=pyhd3eb1b0_0 - - pyqt=5.9.2=py36h05f1152_2 - - python=3.6.13=hdb3f193_0 - - python-dateutil=2.8.1=pyhd3eb1b0_0 - - qt=5.9.7=h5867ecd_1 - - readline=8.1=h27cfd23_0 - - setuptools=52.0.0=py36h06a4308_0 - - sip=4.19.8=py36hf484d3e_0 - - six=1.15.0=py36h06a4308_0 - - sqlite=3.35.1=hdfb4753_0 - - tbb=2021.2.0=hff7bd54_0 - - tbb4py=2021.2.0=py36hff7bd54_0 - - tk=8.6.10=hbc83047_0 - - tornado=6.1=py36h27cfd23_0 - - wheel=0.36.2=pyhd3eb1b0_0 - - xz=5.2.5=h7b6447c_0 - - zlib=1.2.11=h7b6447c_3 - - zstd=1.4.5=h9ceee32_0 - - pip: - - argparse==1.4.0 - - cached-property==1.5.2 - - dataclasses==0.8 - - h5py==3.1.0 - - imageio==2.9.0 - - numpy==1.19.5 - - opencv-python==4.5.1.48 - - pyyaml==5.4.1 - - scikit-video==1.1.11 - - scipy==1.5.4 - - torch==1.7.1 - - torchvision==0.9.0 - - tqdm==4.60.0 - - typing-extensions==3.7.4.3 +name: asgcn +channels: + - pytorch + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ + - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ + 
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - blas=1.0=mkl
+  - ca-certificates=2021.4.13=h06a4308_1
+  - certifi=2020.12.5=py36h06a4308_0
+  - cffi=1.14.5=py36h261ae71_0
+  - cuda90=1.0=h6433d27_0
+  - cudatoolkit=10.0.130=0
+  - cudnn=7.6.5=cuda10.0_0
+  - cycler=0.10.0=py36_0
+  - dbus=1.13.18=hb2f20db_0
+  - expat=2.3.0=h2531618_2
+  - fontconfig=2.13.1=h6c09931_0
+  - freetype=2.10.4=h5ab3b9f_0
+  - glib=2.68.1=h36276a3_0
+  - gst-plugins-base=1.14.0=h8213a91_2
+  - gstreamer=1.14.0=h28cd5cc_2
+  - icu=58.2=he6710b0_3
+  - intel-openmp=2019.4=243
+  - jpeg=9b=h024ee3a_2
+  - kiwisolver=1.3.1=py36h2531618_0
+  - lcms2=2.11=h396b838_0
+  - ld_impl_linux-64=2.33.1=h53a641e_7
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libgfortran-ng=7.3.0=hdf63c60_0
+  - libpng=1.6.37=hbc83047_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - libtiff=4.2.0=h3942068_0
+  - libuuid=1.0.3=h1bed415_2
+  - libwebp-base=1.2.0=h27cfd23_0
+  - libxcb=1.14=h7b6447c_0
+  - libxml2=2.9.10=hb55368b_3
+  - lz4-c=1.9.3=h2531618_0
+  - matplotlib=3.3.2=h06a4308_0
+  - matplotlib-base=3.3.2=py36h817c723_0
+  - mkl=2018.0.3=1
+  - mkl_fft=1.0.6=py36h7dd41cf_0
+  - mkl_random=1.0.1=py36h4414c95_1
+  - ncurses=6.2=he6710b0_1
+  - ninja=1.10.2=py36hff7bd54_0
+  - olefile=0.46=py36_0
+  - openssl=1.1.1k=h27cfd23_0
+  - pcre=8.44=he6710b0_0
+  - pillow=8.1.2=py36he98fc37_0
+  - pip=21.0.1=py36h06a4308_0
+  - pycparser=2.20=py_2
+  - pyparsing=2.4.7=pyhd3eb1b0_0
+  - pyqt=5.9.2=py36h05f1152_2
+  - python=3.6.13=hdb3f193_0
+  - python-dateutil=2.8.1=pyhd3eb1b0_0
+  - qt=5.9.7=h5867ecd_1
+  - readline=8.1=h27cfd23_0
+  - setuptools=52.0.0=py36h06a4308_0
+  - sip=4.19.8=py36hf484d3e_0
+  - six=1.15.0=py36h06a4308_0
+  - sqlite=3.35.1=hdfb4753_0
+  - tbb=2021.2.0=hff7bd54_0
+  - tbb4py=2021.2.0=py36hff7bd54_0
+  - tk=8.6.10=hbc83047_0
+  - tornado=6.1=py36h27cfd23_0
+  - wheel=0.36.2=pyhd3eb1b0_0
+  - xz=5.2.5=h7b6447c_0
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.4.5=h9ceee32_0
+  - pip:
+    - argparse==1.4.0
+    - cached-property==1.5.2
+    - dataclasses==0.8
+    - h5py==3.1.0
+    - imageio==2.9.0
+    - numpy==1.19.5
+    - opencv-python==4.5.1.48
+    - pyyaml==5.4.1
+    - scikit-video==1.1.11
+    - scipy==1.5.4
+    - torch==1.7.1
+    - torchvision==0.9.0
+    - tqdm==4.60.0
+    - typing-extensions==3.7.4.3
diff --git a/epoch_loss_class_eval.png b/epoch_loss_class_eval.png
new file mode 100644
index 0000000000000000000000000000000000000000..246c771b3465f15b5e60cde15008fc8eaa862b84
GIT binary patch
literal 23988
[base85-encoded binary image data omitted: epoch_loss_class_eval.png, by its name a per-epoch plot of the classification loss on the evaluation split]

diff --git a/epoch_loss_class_train.png b/epoch_loss_class_train.png
new file mode 100644
index 0000000000000000000000000000000000000000..51e019edc4c558c878d4b2abbe1f87c89aab39e5
GIT binary patch
literal 23976
[base85-encoded binary image data omitted: epoch_loss_class_train.png, by its name the corresponding per-epoch classification-loss plot for the training split]
zj=Eb5>;oYX%VgJ=U@R1aA>!cB-grAS)St!|R`)t`TV;z^r7tSTd8Oy+0Oo=x>Bx$+ zXZAy-rE3B&RUT8!J#3^RlQZ=fC65uy-ERXM9?a5GVd&tJR)&gdfqe|bYT{ml3}gvg z!PDN)ShA0qEan>ALi?K_V{k{A0P-Pi03tLiNh?3?_X(^t(KulWp+p3VDQ)4d+3b76 zDHcOBc3%f&ZT-LNk@m9HI{lggeKZuR;f>qz_L(B}usq|eFflEyxkLHg9jwaRd%IjD zUpslVPrLE)EujN+Vram0C#x-qzMv(BKJYOkKaDu6sk>SaD6O`&jl>dZ273;$7h+C8 zKa9H|ZJhmGLkw~?`vs{k{^5SRd00R5%@tS?`_$We^+rJRmG)sZ2~ zL?dYX_U&QlDWNxnK>s)+Bco%XuUL#2ZNN{!t0oE-?C1zX)QyrDKY5lGDh-riMBOEU zBw+h4h$s+G7-lKeg%QA{I{80Weg2IfKwY4Jc2#;T>@`C4ftOlZM&AF5sn^@zzfnaM zg9am$kv3zVXk8E;qAgVz!Z_|R_Ot-sO zTZy_^Qa);{X#)MYWv*n#zJ2jRjd*~1K>CRu zn+!Mv1Oz14U9xi>oSZO9XwLI(Ze%Gu#GAOQs;UM8-D)?)xRe=L+VBwryaFBABs!fg z0@`*WTm4kHY{r8+NU$Dm1K295oWbK`R_57XPVWEIsjX@f;^5$6>>oTh14DXS<%-cO zC9h)m&*Ir$xyvTYSoP~2WCfX8=uvI@)Wd+|;c<@T+sv}#)ulAx7a5+w6y)(4mUu|f z@N)R>Yvotvl4Q6>-k^UL^_@}V^!1PYHS(7)X9j6njRO~GOn^ckl;RbiM`Ot~OE~St zBiZ%XrRHwyRm|noG%l|jiD z(z7SaI)NZ_UChctJT)>Gkm_@Q2;soHsIK~b>{vs*d6P^+SlGEK3Jl;RL+?q%4A;rK z-G)C8;45Uxmq3F`APg>`GCy1eD07+%g52rd0QK7T$T3Q>Nt1PkCFv^TKYr|6^TEn| zhnc=&e9SQe&xA|sbld^eMb+(<9SOOmJJo~72j;cc)mage8kRZ$V|ZkIe{$Pf{b=-LysFbSd?CtC1Uf&|S9Fl0G}{ZUa-fx*BqU7!p;aYEiIVS9I*18xZSCxh_36-&T)PaV zr;QeZcK6`y;^&aCeQk3gxKC;OM5xP?A3&1AzDdB11%*2P%4=*JdQ; z9>8IgHytU=K~ai|iVBQ@NP1woRix8n5I84#hjK&`2tQ1fmW?rqCL~v)mi1l8c>Vf% z+*#i~{&j4&iwqtzsX_HD?ky?+C~Dt(0j literal 0 HcmV?d00001 diff --git a/epoch_loss_class_train.txt b/epoch_loss_class_train.txt new file mode 100644 index 0000000..515fdfc --- /dev/null +++ b/epoch_loss_class_train.txt @@ -0,0 +1,90 @@ +3.7712276314242685 +3.0636168965986927 +2.801343619823456 +2.6331870392384205 +2.5135307890929712 +2.4291697614202716 +2.3555361779060604 +2.2984420201291202 +2.2522135552288742 +2.197252894587989 +2.097349743214381 +1.9542330193516024 +1.7952129174223848 +1.6156651931379018 +1.4707137380617539 +1.3638808261168214 +1.2811493240666694 +1.2145853494477985 +1.1464360410734533 +1.0886454173869087 +1.0415393155084143 +1.0067537130545055 +0.9712598563652918 +0.9459512437851009 +0.9096426171339108 +0.883714666565097 +0.8630674364239435 +0.8406433062055244 +0.820840925855913 +0.8079694209025758 +0.7841788579339185 +0.776262023339371 +0.75239310020289 +0.7407916599995307 +0.7314519268123639 +0.7205715374548911 +0.7119348809259342 +0.6838741163310618 +0.675083714229081 +0.6665253263740542 +0.40875943099878626 +0.33014846760375033 +0.301763350591267 +0.27762719021237026 +0.26088881696264316 +0.2430464035457363 +0.2282850744710909 +0.2212755358291129 +0.203683311874431 +0.1944100352668947 +0.18915298141420142 +0.18149443655029596 +0.17873722235375974 +0.171926496180016 +0.164108280483452 +0.1580632201443505 +0.1529976990462558 +0.14594485707745689 +0.1456607071607638 +0.14255504573110114 +0.11923092203425435 +0.10815973842863215 +0.1090422415932918 +0.10648260271793217 +0.09845620263738979 +0.100937618059459 +0.10059720761977203 +0.09929326052758904 +0.09756028030852115 +0.09850751669961134 +0.09668084698679134 +0.0974692665948768 +0.09672294503566119 +0.09527862614642296 +0.09309141301707675 +0.09378876050791575 +0.094307657253778 +0.09193214094837932 +0.09031117155016503 +0.08870918162732286 +0.08976529900305628 +0.08807121031712889 +0.08787393360298594 +0.09003039282064657 +0.08867947479501018 +0.08774756830658956 +0.08815590344746861 +0.08792258952912853 +0.09093503984912055 +0.08763042784054735 diff --git a/feeder/__init__.py b/feeder/__init__.py index 8b13789..d3f5a12 100644 --- a/feeder/__init__.py +++ b/feeder/__init__.py @@ -1 +1 @@ - + diff --git a/feeder/feeder.py b/feeder/feeder.py index dba96f3..59a6333 100644 
--- a/feeder/feeder.py +++ b/feeder/feeder.py @@ -1,90 +1,90 @@ -import os -import sys -import numpy as np -import random -import pickle -import time -import copy - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from torchvision import datasets, transforms - -from . import tools - -class Feeder(torch.utils.data.Dataset): - - def __init__(self, - data_path, label_path, - repeat_pad=False, - random_choose=False, - random_move=False, - window_size=-1, - debug=False, - down_sample = False, - mmap=True): - self.debug = debug - self.data_path = data_path - self.label_path = label_path - self.repeat_pad = repeat_pad - self.random_choose = random_choose - self.random_move = random_move - self.window_size = window_size - self.down_sample = down_sample - - self.load_data(mmap) - - def load_data(self, mmap): - - with open(self.label_path, 'rb') as f: - self.sample_name, self.label = pickle.load(f) - - if mmap: - self.data = np.load(self.data_path, mmap_mode='r') - else: - self.data = np.load(self.data_path) - - if self.debug: - self.label = self.label[0:100] - self.data = self.data[0:100] - self.sample_name = self.sample_name[0:100] - - self.N, self.C, self.T, self.V, self.M = self.data.shape - - def __len__(self): - return len(self.label) - - def __getitem__(self, index): - - data_numpy = np.array(self.data[index]).astype(np.float32) - label = self.label[index] - - valid_frame = (data_numpy!=0).sum(axis=3).sum(axis=2).sum(axis=0)>0 - begin, end = valid_frame.argmax(), len(valid_frame)-valid_frame[::-1].argmax() - length = end-begin - - if self.repeat_pad: - data_numpy = tools.repeat_pading(data_numpy) - if self.random_choose: - data_numpy = tools.random_choose(data_numpy, self.window_size) - elif self.window_size > 0: - data_numpy = tools.auto_pading(data_numpy, self.window_size) - if self.random_move: - data_numpy = tools.random_move(data_numpy) - - data_last = copy.copy(data_numpy[:,-11:-10,:,:]) - target_data = copy.copy(data_numpy[:,-10:,:,:]) - input_data = copy.copy(data_numpy[:,:-10,:,:]) - - if self.down_sample: - if length<=60: - input_data_dnsp = input_data[:,:50,:,:] - else: - rs = int(np.random.uniform(low=0, high=np.ceil((length-10)/50))) - input_data_dnsp = [input_data[:,int(i)+rs,:,:] for i in [np.floor(j*((length-10)/50)) for j in range(50)]] - input_data_dnsp = np.array(input_data_dnsp).astype(np.float32) - input_data_dnsp = np.transpose(input_data_dnsp, axes=(1,0,2,3)) - +import os +import sys +import numpy as np +import random +import pickle +import time +import copy + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torchvision import datasets, transforms + +from . 
import tools + +class Feeder(torch.utils.data.Dataset): + + def __init__(self, + data_path, label_path, + repeat_pad=False, + random_choose=False, + random_move=False, + window_size=-1, + debug=False, + down_sample = False, + mmap=True): + self.debug = debug + self.data_path = data_path + self.label_path = label_path + self.repeat_pad = repeat_pad + self.random_choose = random_choose + self.random_move = random_move + self.window_size = window_size + self.down_sample = down_sample + + self.load_data(mmap) + + def load_data(self, mmap): + + with open(self.label_path, 'rb') as f: + self.sample_name, self.label = pickle.load(f) + + if mmap: + self.data = np.load(self.data_path, mmap_mode='r') + else: + self.data = np.load(self.data_path) + + if self.debug: + self.label = self.label[0:100] + self.data = self.data[0:100] + self.sample_name = self.sample_name[0:100] + + self.N, self.C, self.T, self.V, self.M = self.data.shape # (40091, 3, 300, 25, 2) + + def __len__(self): + return len(self.label) + + def __getitem__(self, index): + + data_numpy = np.array(self.data[index]).astype(np.float32) + label = self.label[index] + + valid_frame = (data_numpy!=0).sum(axis=3).sum(axis=2).sum(axis=0)>0 + begin, end = valid_frame.argmax(), len(valid_frame)-valid_frame[::-1].argmax() + length = end-begin + + if self.repeat_pad: + data_numpy = tools.repeat_pading(data_numpy) + if self.random_choose: + data_numpy = tools.random_choose(data_numpy, self.window_size) + elif self.window_size > 0: + data_numpy = tools.auto_pading(data_numpy, self.window_size) + if self.random_move: + data_numpy = tools.random_move(data_numpy) + + data_last = copy.copy(data_numpy[:,-11:-10,:,:]) + target_data = copy.copy(data_numpy[:,-10:,:,:]) + input_data = copy.copy(data_numpy[:,:-10,:,:]) + + if self.down_sample: + if length<=60: + input_data_dnsp = input_data[:,:50,:,:] + else: + rs = int(np.random.uniform(low=0, high=np.ceil((length-10)/50))) + input_data_dnsp = [input_data[:,int(i)+rs,:,:] for i in [np.floor(j*((length-10)/50)) for j in range(50)]] + input_data_dnsp = np.array(input_data_dnsp).astype(np.float32) + input_data_dnsp = np.transpose(input_data_dnsp, axes=(1,0,2,3)) + return input_data, input_data_dnsp, target_data, data_last, label \ No newline at end of file diff --git a/feeder/tools.py b/feeder/tools.py index 0233fc7..942cfd2 100644 --- a/feeder/tools.py +++ b/feeder/tools.py @@ -1,244 +1,244 @@ -import numpy as np -import random - - -def downsample(data_numpy, step, random_sample=True): - # input: C,T,V,M - begin = np.random.randint(step) if random_sample else 0 - return data_numpy[:, begin::step, :, :] - - -def temporal_slice(data_numpy, step): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - return data_numpy.reshape(C, T / step, step, V, M).transpose( - (0, 1, 3, 2, 4)).reshape(C, T / step, V, step * M) - - -def mean_subtractor(data_numpy, mean): - # input: C,T,V,M - # naive version - if mean == 0: - return - C, T, V, M = data_numpy.shape - valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 - begin = valid_frame.argmax() - end = len(valid_frame) - valid_frame[::-1].argmax() - data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean - return data_numpy - - -def auto_pading(data_numpy, size, random_pad=False): - C, T, V, M = data_numpy.shape - if T < size: - begin = random.randint(0, size - T) if random_pad else 0 - data_numpy_paded = np.zeros((C, size, V, M)) - data_numpy_paded[:, begin:begin + T, :, :] = data_numpy - return data_numpy_paded - else: - return data_numpy - - -def 
random_choose(data_numpy, size, auto_pad=True): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - if T == size: - return data_numpy - elif T < size: - if auto_pad: - return auto_pading(data_numpy, size, random_pad=True) - else: - return data_numpy - else: - begin = random.randint(0, T - size) - return data_numpy[:, begin:begin + size, :, :] - - -def random_move(data_numpy, - angle_candidate=[-10., -5., 0., 5., 10.], - scale_candidate=[0.9, 1.0, 1.1], - transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2], - move_time_candidate=[1]): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - move_time = random.choice(move_time_candidate) - node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) - node = np.append(node, T) - num_node = len(node) - - A = np.random.choice(angle_candidate, num_node) - S = np.random.choice(scale_candidate, num_node) - T_x = np.random.choice(transform_candidate, num_node) - T_y = np.random.choice(transform_candidate, num_node) - - a = np.zeros(T) - s = np.zeros(T) - t_x = np.zeros(T) - t_y = np.zeros(T) - - # linspace - for i in range(num_node - 1): - a[node[i]:node[i + 1]] = np.linspace( - A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 - s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], - node[i + 1] - node[i]) - t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], - node[i + 1] - node[i]) - t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], - node[i + 1] - node[i]) - - theta = np.array([[np.cos(a) * s, -np.sin(a) * s], - [np.sin(a) * s, np.cos(a) * s]]) - - # perform transformation - for i_frame in range(T): - xy = data_numpy[0:2, i_frame, :, :] - new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) - new_xy[0] += t_x[i_frame] - new_xy[1] += t_y[i_frame] - # print(new_xy.shape, data_numpy.shape) - # data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) - new_xy = new_xy.reshape(2, V, M) - data_numpy[0:2, i_frame, :, :] = new_xy - - return data_numpy - - -def rand_rotate(data_numpy,rand_rotate): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - - R = np.eye(3) - for i in range(3): - theta = (np.random.rand()*2 -1)*rand_rotate * np.pi - Ri = np.eye(3) - Ri[C - 1, C - 1] = 1 - Ri[0, 0] = np.cos(theta) - Ri[0, 1] = np.sin(theta) - Ri[1, 0] = -np.sin(theta) - Ri[1, 1] = np.cos(theta) - R = R * Ri - - data_numpy = np.matmul(R,data_numpy.reshape(C,T*V*M)).reshape(C,T,V,M).astype('float32') - return data_numpy - - -def random_shift(data_numpy): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - data_shift = np.zeros(data_numpy.shape) - valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 - begin = valid_frame.argmax() - end = len(valid_frame) - valid_frame[::-1].argmax() - - size = end - begin - bias = random.randint(0, T - size) - data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :] - - return data_shift - - -def openpose_match(data_numpy): - C, T, V, M = data_numpy.shape - assert (C == 3) - score = data_numpy[2, :, :, :].sum(axis=1) - # the rank of body confidence in each frame (shape: T-1, M) - rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M) - - # data of frame 1 - xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1) - # data of frame 2 - xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M) - # square of distance between frame 1&2 (shape: T-1, M, M) - distance = ((xy2 - xy1)**2).sum(axis=2).sum(axis=0) - - # match pose - forward_map = np.zeros((T, M), dtype=int) - 1 - forward_map[0] = range(M) - for m in range(M): - choose = (rank == m) - forward = 
distance[choose].argmin(axis=1) - for t in range(T - 1): - distance[t, :, forward[t]] = np.inf - forward_map[1:][choose] = forward - assert (np.all(forward_map >= 0)) - - # string data - for t in range(T - 1): - forward_map[t + 1] = forward_map[t + 1][forward_map[t]] - - # generate data - new_data_numpy = np.zeros(data_numpy.shape) - for t in range(T): - new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[ - t]].transpose(1, 2, 0) - data_numpy = new_data_numpy - - # score sort - trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0) - rank = (-trace_score).argsort() - data_numpy = data_numpy[:, :, :, rank] - - return data_numpy - - -def top_k_by_category(label, score, top_k): - instance_num, class_num = score.shape - rank = score.argsort() - hit_top_k = [[] for i in range(class_num)] - for i in range(instance_num): - l = label[i] - hit_top_k[l].append(l in rank[i, -top_k:]) - - accuracy_list = [] - for hit_per_category in hit_top_k: - if hit_per_category: - accuracy_list.append(sum(hit_per_category) * 1.0 / len(hit_per_category)) - else: - accuracy_list.append(0.0) - return accuracy_list - - -def calculate_recall_precision(label, score): - instance_num, class_num = score.shape - rank = score.argsort() - confusion_matrix = np.zeros([class_num, class_num]) - - for i in range(instance_num): - true_l = label[i] - pred_l = rank[i, -1] - confusion_matrix[true_l][pred_l] += 1 - - precision = [] - recall = [] - - for i in range(class_num): - true_p = confusion_matrix[i][i] - false_n = sum(confusion_matrix[i, :]) - true_p - false_p = sum(confusion_matrix[:, i]) - true_p - precision.append(true_p * 1.0 / (true_p + false_p)) - recall.append(true_p * 1.0 / (true_p + false_n)) - - return precision, recall - - -def repeat_pading(data_numpy): - data_tmp = np.transpose(data_numpy, [3,1,2,0]) # [2,300,25,3] - for i_p, person in enumerate(data_tmp): - if person.sum()==0: - continue - if person[0].sum()==0: - index = (person.sum(-1).sum(-1)!=0) - tmp = person[index].copy() - person*=0 - person[:len(tmp)] = tmp - for i_f, frame in enumerate(person): - if frame.sum()==0: - if person[i_f:].sum()==0: - rest = len(person)-i_f - num = int(np.ceil(rest/i_f)) - pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] - data_tmp[i_p,i_f:] = pad - break - data_numpy = np.transpose(data_tmp, [3,1,2,0]) +import numpy as np +import random + + +def downsample(data_numpy, step, random_sample=True): + # input: C,T,V,M + begin = np.random.randint(step) if random_sample else 0 + return data_numpy[:, begin::step, :, :] + + +def temporal_slice(data_numpy, step): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + return data_numpy.reshape(C, T / step, step, V, M).transpose( + (0, 1, 3, 2, 4)).reshape(C, T / step, V, step * M) + + +def mean_subtractor(data_numpy, mean): + # input: C,T,V,M + # naive version + if mean == 0: + return + C, T, V, M = data_numpy.shape + valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 + begin = valid_frame.argmax() + end = len(valid_frame) - valid_frame[::-1].argmax() + data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean + return data_numpy + + +def auto_pading(data_numpy, size, random_pad=False): + C, T, V, M = data_numpy.shape + if T < size: + begin = random.randint(0, size - T) if random_pad else 0 + data_numpy_paded = np.zeros((C, size, V, M)) + data_numpy_paded[:, begin:begin + T, :, :] = data_numpy + return data_numpy_paded + else: + return data_numpy + + +def random_choose(data_numpy, size, auto_pad=True): + # input: C,T,V,M + C, T, V, 
M = data_numpy.shape + if T == size: + return data_numpy + elif T < size: + if auto_pad: + return auto_pading(data_numpy, size, random_pad=True) + else: + return data_numpy + else: + begin = random.randint(0, T - size) + return data_numpy[:, begin:begin + size, :, :] + + +def random_move(data_numpy, + angle_candidate=[-10., -5., 0., 5., 10.], + scale_candidate=[0.9, 1.0, 1.1], + transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2], + move_time_candidate=[1]): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + move_time = random.choice(move_time_candidate) + node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) + node = np.append(node, T) + num_node = len(node) + + A = np.random.choice(angle_candidate, num_node) + S = np.random.choice(scale_candidate, num_node) + T_x = np.random.choice(transform_candidate, num_node) + T_y = np.random.choice(transform_candidate, num_node) + + a = np.zeros(T) + s = np.zeros(T) + t_x = np.zeros(T) + t_y = np.zeros(T) + + # linspace + for i in range(num_node - 1): + a[node[i]:node[i + 1]] = np.linspace( + A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 + s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], + node[i + 1] - node[i]) + t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], + node[i + 1] - node[i]) + t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], + node[i + 1] - node[i]) + + theta = np.array([[np.cos(a) * s, -np.sin(a) * s], + [np.sin(a) * s, np.cos(a) * s]]) + + # perform transformation + for i_frame in range(T): + xy = data_numpy[0:2, i_frame, :, :] + new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) + new_xy[0] += t_x[i_frame] + new_xy[1] += t_y[i_frame] + # print(new_xy.shape, data_numpy.shape) + # data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) + new_xy = new_xy.reshape(2, V, M) + data_numpy[0:2, i_frame, :, :] = new_xy + + return data_numpy + + +def rand_rotate(data_numpy,rand_rotate): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + + R = np.eye(3) + for i in range(3): + theta = (np.random.rand()*2 -1)*rand_rotate * np.pi + Ri = np.eye(3) + Ri[C - 1, C - 1] = 1 + Ri[0, 0] = np.cos(theta) + Ri[0, 1] = np.sin(theta) + Ri[1, 0] = -np.sin(theta) + Ri[1, 1] = np.cos(theta) + R = R * Ri + + data_numpy = np.matmul(R,data_numpy.reshape(C,T*V*M)).reshape(C,T,V,M).astype('float32') + return data_numpy + + +def random_shift(data_numpy): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + data_shift = np.zeros(data_numpy.shape) + valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 + begin = valid_frame.argmax() + end = len(valid_frame) - valid_frame[::-1].argmax() + + size = end - begin + bias = random.randint(0, T - size) + data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :] + + return data_shift + + +def openpose_match(data_numpy): + C, T, V, M = data_numpy.shape + assert (C == 3) + score = data_numpy[2, :, :, :].sum(axis=1) + # the rank of body confidence in each frame (shape: T-1, M) + rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M) + + # data of frame 1 + xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1) + # data of frame 2 + xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M) + # square of distance between frame 1&2 (shape: T-1, M, M) + distance = ((xy2 - xy1)**2).sum(axis=2).sum(axis=0) + + # match pose + forward_map = np.zeros((T, M), dtype=int) - 1 + forward_map[0] = range(M) + for m in range(M): + choose = (rank == m) + forward = distance[choose].argmin(axis=1) + for t in range(T - 1): + distance[t, :, forward[t]] = np.inf + 
forward_map[1:][choose] = forward + assert (np.all(forward_map >= 0)) + + # string data + for t in range(T - 1): + forward_map[t + 1] = forward_map[t + 1][forward_map[t]] + + # generate data + new_data_numpy = np.zeros(data_numpy.shape) + for t in range(T): + new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[ + t]].transpose(1, 2, 0) + data_numpy = new_data_numpy + + # score sort + trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0) + rank = (-trace_score).argsort() + data_numpy = data_numpy[:, :, :, rank] + + return data_numpy + + +def top_k_by_category(label, score, top_k): + instance_num, class_num = score.shape + rank = score.argsort() + hit_top_k = [[] for i in range(class_num)] + for i in range(instance_num): + l = label[i] + hit_top_k[l].append(l in rank[i, -top_k:]) + + accuracy_list = [] + for hit_per_category in hit_top_k: + if hit_per_category: + accuracy_list.append(sum(hit_per_category) * 1.0 / len(hit_per_category)) + else: + accuracy_list.append(0.0) + return accuracy_list + + +def calculate_recall_precision(label, score): + instance_num, class_num = score.shape + rank = score.argsort() + confusion_matrix = np.zeros([class_num, class_num]) + + for i in range(instance_num): + true_l = label[i] + pred_l = rank[i, -1] + confusion_matrix[true_l][pred_l] += 1 + + precision = [] + recall = [] + + for i in range(class_num): + true_p = confusion_matrix[i][i] + false_n = sum(confusion_matrix[i, :]) - true_p + false_p = sum(confusion_matrix[:, i]) - true_p + precision.append(true_p * 1.0 / (true_p + false_p)) + recall.append(true_p * 1.0 / (true_p + false_n)) + + return precision, recall + + +def repeat_pading(data_numpy): + data_tmp = np.transpose(data_numpy, [3,1,2,0]) # [2,300,25,3] + for i_p, person in enumerate(data_tmp): + if person.sum()==0: + continue + if person[0].sum()==0: + index = (person.sum(-1).sum(-1)!=0) + tmp = person[index].copy() + person*=0 + person[:len(tmp)] = tmp + for i_f, frame in enumerate(person): + if frame.sum()==0: + if person[i_f:].sum()==0: + rest = len(person)-i_f + num = int(np.ceil(rest/i_f)) + pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] + data_tmp[i_p,i_f:] = pad + break + data_numpy = np.transpose(data_tmp, [3,1,2,0]) return data_numpy \ No newline at end of file diff --git a/img/readme.md b/img/readme.md index 568dd30..d09b9d7 100644 --- a/img/readme.md +++ b/img/readme.md @@ -1 +1 @@ -Image used in the repository. +Image used in the repository. 
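Note on the Feeder diff above: __getitem__ first locates the valid (non-zero) frame span, splits off the last 10 frames as the prediction target plus frame -11 as data_last, and, when down_sample is set, re-samples the remaining frames onto a 50-frame grid with a random phase offset. A minimal self-contained sketch of that re-sampling step, assuming the same (C, T, V, M) array layout; the helper name uniform_sample_50 is ours, not part of the repo:

import numpy as np

def uniform_sample_50(input_data, length, num_frames=50):
    # input_data: (C, T, V, M); `length` counts the valid frames of the clip.
    # Short clips are taken as-is (first 50 frames, matching the diff's
    # `length <= 60` branch); longer ones are read on a uniform grid of
    # stride (length - 10) / 50 with a random phase offset `rs`.
    if length <= 60:
        return input_data[:, :num_frames, :, :]
    stride = (length - 10) / num_frames
    rs = int(np.random.uniform(low=0, high=np.ceil(stride)))
    # Clamped at the time boundary for safety; the diff indexes without a clamp.
    idx = [min(int(np.floor(j * stride)) + rs, input_data.shape[1] - 1)
           for j in range(num_frames)]
    # Fancy indexing keeps channels first, so no transpose is needed here,
    # unlike the per-frame list comprehension in the diff.
    return input_data[:, idx, :, :].astype(np.float32)

In the patched __getitem__ the equivalent list-comprehension result is transposed back to channel-first before being returned alongside input_data, target_data, data_last, and label.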
diff --git a/log/best_performance/epoch_loss_class_eval.png b/log/best_performance/epoch_loss_class_eval.png
new file mode 100644
index 0000000000000000000000000000000000000000..246c771b3465f15b5e60cde15008fc8eaa862b84
GIT binary patch
literal 23988
[binary PNG data omitted: per-epoch evaluation classification-loss curve]
literal 0
HcmV?d00001

diff --git a/log/best_performance/epoch_loss_class_train.png b/log/best_performance/epoch_loss_class_train.png
new file mode 100644
index 0000000000000000000000000000000000000000..51e019edc4c558c878d4b2abbe1f87c89aab39e5
GIT binary patch
literal 23976
[binary PNG data omitted: per-epoch training classification-loss curve]
literal 0
HcmV?d00001
z<5IkmjC(T1N6SBWp)z-Tt*WZB9(-&1T*iY_hm-2>pI=z%dJ!-MO#=he0@o|B5q^^r z1fq6hYpVqk`}EwLP5r@E)k20Eyi2ri&B_TsfBw7&A7dcv^#_8On3%or zIrkaDfTQi9josadhzLBz{U1~Sr~WMVVSWAjb#LW;&onH*V}<@#@1{if`H43-H-{_T z&dDn%Y<`tnXW`;%zwwe;+8r6IbZdXisy4m69F2uXI{o`+CoCtdZ{NO^+-xO`VUQra zfB*hPUENj)7~wm!pHh!q&B`6aKYS49;NoJ3Jn!o2s{dv{qJo5osAXy3%_d|$Rt^sA z%)lM`prdW$u&}W4KOO}^CCFR*c_%~$TSq-?$iknbzP>&(I=W@O;e-^xfF~>!NOclq z6LtIZslM|s?A>=(Ow~0sZniumarb=tH~99@JM&FI2{S*RS++S9*F6jgLC3r?08bd< zzg|HMXSc#Sx&vtwt%~qZm&fa3uo4zMW#xCrJvzF(FZlXOOF9pSKoLcd8|LPlnJ0&M zurR`}6sp54mEQi&2}A1Z>pR^oHKp$C%nu0(0ygYi&>jX%t<=$YbwK+w3qJx_)O50* z+<;0qe;t}RIw=Ijr{-`Xk$^q-8`Yi!`Ij@>@-u@T zLCs+SaP>X^a{I~A)X9YS+;FK)Y0#11c;M~?H#c`9tY7}P@0PkRL~&rw>7-m40TAtN z&WU&}89Gna?!U@Z+g3}o(|aFS1IhM-nE)XX(ahQ!DPmG~#6PJe9g0_jH`iU7?pB7! zmUEU*NJ!Jnj6qdZHLK4-y}i&_&S&itpwiek#>LGoEoz>OWg=34Yp#CZADb+Jkk@^5 zJABipJS1YP+g^qXfCS}zsr&t6Mehl_$NTkp_TQXmEAs|aPjj{ZaNBFOy8dB>|L&TZ z*vIN>N$rQ%tw|dAUMsC&aImql;u6y#ea5T+LJ6p;C3;4u7Z!%N>2q>(p@58YXL@nH zo$=@UT9yK)Y>wKEGY%WL}T|_qYE)qXIPO`OmnaP9X$7ByUQVQ z^K<={iH|TAC)Ii}D?H z7)%$oLZTM+%>DmdyM4QGnnPW;vmB1_>&j4=lL#Sc{C*+)X5jfuJRNl(9Y8THEiH$Q z{`U3@jy*Z&bCf=p#eDmA1(qm$fu)`Id;aVQD-E30IYRMb-!6bzm|}n^I5-#)roFH3 z>A}p*Jl`N^$VAkLJN@g57ZZy=5>=phF)<9HcJJ@sL0J+!iX*MOE2RtToF<4QQmo{P zUG>EY4r1ncs-4jvSzxhY+US#2VNYNo_zeq@Prv3qO{D^j#aqHmoG#`@+`cYd0a;9# z3WFl&9Dm1=GP}j`DzWj22?&O`=g;XNR;mVfV7kLNv<567{9HfvL${!Qb6B0sN?)IR zGGI%0E7w!;=+D&{W(5Y#I7SutwrDKx)hjd4BT1*hw-&$CItAKNEWc%YU9p~@uGlWS z41=BX_a+t2bVmZaz+}Vm-(`~uXZoQ1pU=+Hgusg4*w`r8mA}{O7f`k?uNrvcwAV=( zD2-yBZV+XNMZu}Iwl+Bxl}Aejnr#6SLeVL;X^%@?yIwq`mjB(f)^T>bU0irS4fi4X zwvW5dgtW(>i+J;fuisJPmClO?ZVg8?Gcz-gaNz7m)O_UU&c$yp59ONIc+*L{vv{wL zdbRXGY6&%ZBhm06G}Os?VIM9SZbvPg;+rG^BW6mQ`U1DjN|*(z7Y+_iWPE(P=D9$$ zD49h!F2^#;P$qI%L*6N`=v&_IUEZGGh4tLj)HLAAUM!r_d&^=#mpj6e!Fi1W;+$<$ znk0r~81p5Yzz#H?VWv#07gdiTb|U9M80d|Rdk90((pD0hK6BY z%Y)VcyoSr{h!Cr?P5#vx4t<6%Jk6>B(TT!Pm*dX?<5J++~> zmM0?|%!a^7JW*w3?Han~NsN9#^zSNVVk;QVR+e0Lv+72y8k!QCLo^vVb10Js@Fx3w zT6Ngv-Tb-!3+ap8r|$z#_tD7*(}VYPxM|?3!@54z=ZZ^TxRERU~D>=+&xAqMQ&&iMKb8p{r{zv+LaG6l!3BZ_EyW zUCISiZt}r_0?7j4dQn@udCaXL7P8o&fBjb#W+Wt+mT)O>3LR1}J1T{cU^{V!Qs_9J zJbA*x$@!{H@?duzg}2M>yY{~? 
zMvP&3IoU3HX0+i^R8$l60am5XfPbNXd#IiF+A9(F|cC zX@EfPqHT#F({y(ie)Q;32@Legcoo(43TF)c0gq=-pN2z6f`y=^Bk$5PHc3euwDsIs zpNfCVXAEpry+cV%pFo}apQo(7vI7=YD32|Ho63y{nO8%CWhW zAt*H|Dk|Ews-~)10;sVoO@?Z{ZkHJo3#-)jb0Y2w%y|Y)JRC+Q=UboBVXRVwwd%ch zu?Ak~z0CCxbG#{rZr2B1iE)s^mR!FaLlE_~Ud{C!|BZjO-`X4#P7vcWPX~ zL{}s?tYZ31d{sCvrxbCv4Pe9n&oWH#+MkPJX@5^>a9f-}554eCqz{Hvnf^KB4}ZKh z3gU@8e?Hx~d*=+(X&Q{_Wp2Ezbadvdl&g3{rh>nuyu5rbbWShx^NDl@-OmsbMDTrM?s!Lv4BzM1ci?g}889}B3 zcbSeip{>88zq`JC#7D!Dbr09C(&E6~(1c}uM zEfXC}=dQ}?56N(I*4!Pc%vQJjS7b)$vlT)@@yOcd;+5inOG}gSz}=nn7ec0gd`@y0 zc>T1)fkb=x<%8T%OpT>wv`Yop_ zf=Zi2gCnDg|Fm|FP-NpwZ}ljVrR(7#(m{vP8UAZBKqO%{oE*#Dw6VFM;p7e&#-#nV z^>lZ;46zQ3UaY1A#&e*>f!9JH@5vwUR5r}SN=-H%tT)61uMr(H$gOdBhiQ7}$&E$6 z*i(`}p^n^-wj?)~9f)`g1f?oIMNwfO78Vxs87kLls8^!e|Au8*j@{DNPCTuNwp?K} zA%TDYs6hARx4RY#Lq7|d6B|JXUPVO(MvX9BwfRyo5iL_FjPgoMFY7cc%jTvgjH!Sv z5?Emm=etv}5J|Vm7{FjrF&gmbw(q4$x|j-z$2u6~-i&!hFT@qo1Q8bVL z-iXhPiHqd+MM8Is5?8<-Gs${30(YEkINDR|>qK_TYV|(xnWx5O2*8pYT&@p@yw4h= zMhxBkw{Q7Tj;f4*uIq&u2pVm`L-bX9IVEf|bZ^8RoA|R}UB5g(J!K=SS~yL@ipHUP zm<5(Jw3Z)SCm4W!l@8bw2DafrObjt}yH-Gt!)w0I$e!GAH=5Co!od6f`*g9horT-# z=;(a9?pB8*RIC)CP*+#CIiKprCL&^BRo0ewg;$tkoD3;>(Sr*hAP1N817i4QfSZ^1 zbw%yRk93f2F)5?OYHbk2sxyj_PPbMl0?ov6y_g$r!3SqE0)A^TLPuZ5ua<0g_`6f+ zxu_kXMl8d5QM24-rg=P&ib%R;qP&h3N0Sw8ENkkVUn zE<~$qIC;jA66!3DP%6HQ<9;)3GXr-km@#RU4}m(ZrmanAYpMY~H7s6r&Fb^p3i0}- z^^~N9z^EVT+=qnHKP*j_McsvkkDw&cg()ET`^ue+>0);VsODc+JZeE(4oSU5@{L#% zmG+myoJld6dh;v-#p<4CAOn`E7+puS6OZ&B%AZ6e?KSj}T9177$jJo&lR#27irJeH z8|6rf?Co5n5&ZQ9Cm+q4*V$hp{l7e=0tUpTvJ`%eHdFbO|5gn*XnS3^@SjP5BmQh` zgyiSj7ZnE`{T;X+eCN)c#mPFk@$qq#1rf1nW08>9emI7fkwlJb{^^Q=fA#9k>h6_U zawF4Qmfs8vEm%(x*$5?TW%}B1X|BcYv**tvZ-0N44efBj+qa*90}{6FAWxI^A`}!9 zlp2?mkQkVy*|^o6GT>A(=5^WMGHkzk_J*s-AXC>Ts^BwWFcRuM+7~e(Ri5}wOu<6n zpDYIA2f8Y0&&6I{f2zR1Kmk+A!8W9p91()M4+AwJ2MeFpdO`yI&_vsF}h zBJM+;F9^aHKx({}9|GCi0(Bg%m1ut_H<{1h39S8vOvS)nfI4PlIM~=$t&gbD?&a)1=Cy;47Zly)dL+v1a!-izCsj2u;o^19qMz06ZLT|6MZXb+1v;fVa0 zTN+mEx!^+z=vkUy>86s?&^&}j>AO-cqv9y*>c6w1qOY%?Ucm$`FVG~nQ9$M~R_mX9TzG(HYfhdfCF3C9@dz>wV_o+78)uI&zgo2G;HL{t!XBMge z5ny$OW&**Vyu1D9`(+zw^MHWSOpjJV7DwK50f~5h*{|G;1Q2Mnu?>$H-R9r6t0sRv zR`((HBVEHcFgH@rq=vFkfci)*EiEB!;K3}R3^YuySBs0CY1^M)-%(QM5GXq!Wvhya z(f(RlxzWXCte3LqMq5P9_Y&U&)0gwZyYf-++@AJf<2V7juC9VW@r$8^A2f}v&}E~z zk(f!g|B(A5${*Xb#WcZf&6k5z3ncCO?kk)HMrbyfEN{U5a=aiAqQ^B zLi1_*{uV1d0z3{3TO*`Y5$GJ{4z7@>CvqJ6>u56W9TCReqOls76qlcT=79yM%V2__ zj=&a(xW%!`;)3`F8#X~UbD()n%z_yio?@PTwX1SnsZwMA2D;3fQkAhSv3!g6I=j1D zfebFS?_maV=2498cAFQOMngrnzJ+!3haR$Apsi$2F`bAikus3B0O_)TFJGK$P=dIP zfNC?Bpvf$9d0rHYkeyTHXPYH4y+r#v| z$VZPTnF_U1f-cj2;f?VPUU2jMb&%w+y8OtZC{Vo}29=GE{R7ekvmQ#O7zwBx<>Siu zb4yCdeWn7XKru3&dw;*sJCS2%KYLZ*$XdN{GLo5fGKU?_I@FYYv>U)slRzLqU&yAU zbPiBfYZMh%*7skiZj#WK!mpLM@q4P7Vs}?hpZE1V&RgJ(e1zfb7X)>UOLnG(Ju2EVbRs?DH*ug2wC=o3fjVE+?%Jr zQR~}R%Ma^G%>FAXgyC?$XHrnB+iYT=n0|TXW>1y;;n7i@bf*18{4hHphll4sAiz5F zKtF`a4Coq$3JjS@#lV2#c(>ssD(2SJhw6KC%RD0Zy^b4WKPks|k1f(_O}J%g;N}E3 zskabL4vuNaV~_3VVP5FO9Ud=UtbAE$uJ-s1U%O2v=Zvort9;rulLOz&VR&Z2u5jDG z{rz93NK}R;LY_W-3RZ=SdU~j=D*`=}fx*hd;u_pcr9LbwBW88SI8q?w{>=%}OsI^9 zhxfsGgi<5dz!1CU$2)_W_w8%Vzzuo2%=`oL1T23T8fh&9C_Lu7RDy!)JLtUI&2!7R zkH6mFZ7z(KMQc}Nfe1l5?k&*Ar2*SPGuBsTcj@EDk0x)G0`^L>Y;?`_(dv%;ycHe7QD@~F7}(F&*M&?(g5rIqneOZ-s@G- z4pVOTd>31=^m#d)(?`gPQmY=t!Q$w*a{hT3$OS8&j)4K?`37C7P-#( zgQ1b+tmO0H<6rriXg9tNs&<9jRNUjogh+4A#_IB*S==I13s!4x z*qIP07#Ilf&L9<=zP(y#2t5xqjo=M#mKTAkG6e#XOg$>ktjHGmbbAY){O3ic6pCqi0$IihF66p(ooL^7U2eh|2Z&FrsKxc%@XfKZ4VsLQ<o*cziXOUO~F9w6`$~&g1P~&Y#b{fTPE)n z0s}>Eb65zW@*oK~X27^(QXjCN9>A395$rk8>vovUZ%C&rbz2vPOgFsO 
z35jQtdnjmLjUaky($=N#1TSY_gpT0D?eD$d956I8x;a^Uk51GUUoTx2hnDskL#5WG zOD&NU>>Hn`E?ae{NDSBb2))uve*{p&dS$q@@+jT;v%+fN%-m9B?M(&Bt3o#VA2SzG zG-ybM`u6%*m&SmyngL)7kNIbLsJKm1sCEMr6~aK>-zwv|`U;JH6n6>)#{I7h#q#-@ z@_T`EJkAPU1SbrG%q?;_k?4eYwy(?_28yXhgt&1?N$plt{ujy|7`+7&Ej)c>Ygx8h z1RIcW)YR0KMbdna?N+Z5t$R8noba<$5Y8q8rTgy<167#=2X@ssX{3v+)-CZ7XmZ=i! z>9ifo{;u^|w*Z@g>4&O`gkNB1LNzwjnB+RT^S60^)PB%t<5O#^%A0`{%ciCsF(r$8 z;S|v0jqlt+TZm=ra<#_F@5_JwCqBe`@Zf>#|H*5)tus`DEJ_LzTz7Go&cSo$yl)Q8 z6K?CDk}{c^XJRn>?m$*rBsF!xFrh8uw*lGMq#;Nl=y=~Al|?;ML8k$6-h_Xw+KUgA zvFs~_j^i3dMlBN)6SHkxqf+SXaHiSz~WCvTM{Uo zxQMO2J*fXJwSu7arb@ar0}9@p^j}BqgqM=}XBb1uKk|Jd5~xi}NU(h5!t+eH`jW>f zv9c&hz=IeUnO!SBsCoKeGwHHoVa40vlVj+j-Ro|ZUIX63(3T~6O*-OSoP{mw9qbQ- zGq`k9IA(2rk;X1vxrT6NRK+xZBXmf2sy9l6A%2_Mo2m?U#XB?2IN)Rjd_H2I3C#%~ z_N$!H58EDU{=rg)bapF8Me|jZFpwR82?UUHev7QIg0PeaK6RnKjHXaL5uddiASx@{ ziob4s(8pAIRCP?(#1#q+r-Uh_4Jt0Z>3212m`K3(fHCTCL;VB*!Uui3nWKD+BWcZ~ zJg853jWqyhB{?~mz*-`xOA6J|k>2`1|7gvY1h3hvsFIh$33&!!IZ)bhkX6{h3<8S* z3xQ%^I+%awv2TN=VB1wEQtRd~p{Gr6p9{Gea=f5)(u6w0FtBi~fQJJ212*D3T+$>j zt~1y^(uZ+fZ{7+3fi3lui2y7J=Q3W^SNE%K!Sr=>c80?G&d<-kw-3WkU~g}K|MQ~8 zVL8M2z!xICZ@)`7hcC#F6qo-?A9$+%NavR?SJ^_4gZseEBFwLrO$c zlpHCUS-jAsg;w5o!25s|sqIvU0#OYz3Ley%NNA>?pIg@$yse&(a8pz2+5Xx%_<8em zT@;YQIT2*%G@xi;ASnBDczF0YBO?y@S{o?hY`J~9zE#pWm!Fd%(-FC~HmPtQ9ttxg zIXM|MvLYp3%SMMs>%kGK*~&nkC1hj>iZV<*P>Q(!Uiv&S>O*DwfDi+j0l+Y#t4rq) z7Bre#5f&*aTJW*L4Fu>YXCmGS$ZuHf(g>%V%9Rq-s_U3&%ex6(qczah9~UzGfds_L zdMzc)%d2fGUX8QHKs|VQRX`RO^3Fg42N82kCoDf;h1uBIv-J(VE;7selEWPwrt0fD z9HjET|AMk@U}Dn2d*3v;g+z{bQJs> zCZQ%KRenmUc~YqsV~D@dF)6Xa0Kwm^!Z`}&q7iU!aor9B0LcyqL%jIFl&}|;%&mi0 zW&Q<=00hg|R4FpBskWxN)km>w#!YX`wDK~K-aLNP%Klqxi&2+NG&yJ>MKVxD5E?Mst>%L3#6>vh8wE@yAvyvd?P zi7q^tQRmSQTCi4gRhSM-?p|gHQVb-}NFvTtr91@#j;zpvpqT}1ny1djGiO);@S|4J zgp?G=APr#&l8Nj0RF_{i4}2qhX}hd`O39xk_oljd;R4P>V)_SUihBs;Lu!hdxJ|Yz zmkz%E1d)vzL=|u?LyyzCRoTdD^;A0?GJ|rWw4no^U-FPN0B; zBUj=eNkI>;_;-;VvObHfECWzOG|N+fGgLk=djHeo(%UCi12;`F$^Wep+<$2&3(yJ@ z0_fl5?Od<>_b28i4;2=ELyLQwAZ_ztCCPcOkneByBe-50EuB-DpR!}#e7x~e6AdC1 z4R|uK4uPjvgHg_#7(A2%gHMx_XQ3_%ztF)<>`N#7fi_TqeAw8>P3u?@U~qk5=7d0 zWuvm-5(E~>2J$dwqCp@X^#~)-iD_SXnm!*}qb=+gj64bd1PA6GN=5v3+ftCdy#HKx3x;iD<-CL73_5T*#iTB%e|YIlv80iD>hx>x#tpq=lVE8g zOD=Qe%4Cbv&SXTPW)_4mKz?|G8{0CUn8GP-{<{oYYjJq|dgc z#|xRk^*;lF0-AxWB-s!j8;gUWPQc87-vq$sK#W*hxCzq8hZR8-BQVlNIF&Sy)P%?eC-wq{OcL;`!m-^5V+9Y5?YD;gAHH_M~NoT ziNNIn`YS0Oh=LLn^1(!F7&pLn|g{~1et$g9+on?cCzNjNNuryVyE}M&7T6>-+_v8nquLC9r#NBo}V)_@uZCg zh03bVN+4lUpjBW6yJyPiH^*a;@N6J-(Dq5tvL0kKi3b)%b$wi2c(}O)?;B)rlcG*Y z0scGkL&i-^q(n0v_>+_yf{uDC-7LYUI$PS#tOm}p4u}3mXg|;HdZ~Jlamn}R;=D`6 zL{TWi<6cZ*(3io+yEhYY4jAhYpb7&{j&{)&4~*bkq6`ubXePf7?;j#DYu{=|Ii_Ja zPUWPp$nQQ`{fd>Rdfg_$2eI79~pW+bq; z`2>owTWHr+?}ZbH;gz?Fg^os|yoqGwf*T(E z9>nV{Ta_WK(?Ee(?QP8H7aGF$hZd5tI0lK=30D%zeY5<9^%l;|%+=kFJX+#5JE+bn z%1m7}L9damb=1YnRO5GC!;@%22vqPzpQWa^oEK%|Ut^YAo5 zEr^4@sTqM#MT50C)rZPn)wv22;NNuECcv#E0Ecg$ORgK=~7$J!$03M z^{TD7V-7GH@)B7RH~O!gR3zBZpIg^bb2X@%wuK7}sySeMgzZ46j~-~QP}S@ZXzM_2 zB&Vf~pclKpRc3T<A}9en zD*#G5RW&sbyvb^}2d}$(dVWH!stgQf`GF+o6dA#<)bMi$YLX^+HcK4(*bwCiQl`gG zo`h9S1)YTp6n(c(+Fl}`jsuHg8`r^Co0&mFM)}Nmis;?xLX0Aj?X^LDxebcDK=TGe zjKdbKXh8Bv2-@6SynI;`dTP|q&Tm@wQbo!C|H8^ixaNX@`u!Kbh3yegr19GXO@l9v zgtN@W=wdkeCxd88Et#pAVqw(qL8JlGi^K^d@Y>)HV}%|nBO`+(2+#hMnmgoA-rwKP zE-RzFckdoPet}}djT<)@EdsX(aJw&%aCLmAf5Yk1PS#$wIU1k8q25rvEehpkgb*$( zvHAPA6zq9Dl(9X6ycKXI%8<& zVFG!)`DuxRI7o^DsuMxrc`kM+N@F%+I97M=h(LWrH$-JQraXh)7VrltV5m=*zLQgR z^n~Xs%4PtU`SRtGZ3%ciwlD65%W|{xDdE^ly9Qdt8ytteo@rl>v{yqlWatD!Ot{ zA+(smgVWH)pgiZus(GWj*EZltpn4=TJDQS{WawyS;tR2wdwDMM$udB0Nq5KJBP}5s zIR>#Q(ueBn&Ff#zZm)4%eF6&f-7}q 
z-vpabZKV<~VumY-Q&1!5;X} zFVe+oM;RpCr=zCFM>+)g)xPb&+1n0KtVru<3N+~Mb}2ryoRDGpXan(aFw{0kfYH+^ z`S&pW`at7Ib8~Q?p+u(B!^F^CCmaO${j%9lOqvBsy`DkhvfYW(dxDnM?qg??CMT)FQZq7J;W*8D6P zqPMU({mvje>jP%P!qe&9TFL-_j?UVkyTwM&m3T1`E78w;t9M0IQ>#?>+shxqiZ};! z;mG3xnv7wI16Lp+p#EZjrD1k4923P)pvI(B9hvgKT6lJIqoxLQS~mG!rt}miv_p~% za>aL<$MhkIcpMe!Obt8Q&;Iy18;P+RAd;0{d~alc66Qmf#2J2HTagYt_FA3(gQq%NXg(9`&cFq4Jw{ zcv!F}N%MEC7V;mz_@&QN^sUb9r^0|4hAtPM_`_;_a*eK3Y+JNm=gTMK2vLe3zrspx zy^D05iK53E!`Ie^pBM;*p97dMTwoy!_}x?E&V9r?sdce+(!==i3(JRHhoVmwV`)Ho z&GXOl*3T&~CThkSzZ2{!e9lty)M8HAC$K#DjVe=Vk_xtzfFC1fN|5%tjVWRBoR)^K^c9LCjd9BJT&AOs=vgKM+oFv-7%uV&7eTOp}NW z0XuUy!2+c&{}SyWlm?;h-wS#Rt7shYa0sF#cAA9GipI#mdV!=^3kwYcV zP&v8jmlcQ`_~=La<_{DQL{sPtemA)0g3AgqH;#6vG(n^lXCSEt!xl&gL!;hc%zMmH zIJyk-S?u%Y9XsZHNRk*)*gOkcjuI4;J@74^Y9ruJa|a*%pVhHb@3e+S8%gCQUEK@b z-WB@;;4L~^!Ie{8WrV!4T?Ur%%vsosgAD6q?7`N-Y3PGnP(Ll~d$)q01wcKM;E;Fy zm|k2g1U`grp}Z?27o(z9)E?(k)tWOJSUsdD!5%x4&=7zq_@i2Hv9B zG4v!L7#9bN&Vuh9HBX@~G}!i}4)7erG}u&Be6sc8l`NCLSfp!f7HJTeGEX;3fWz=X z?3s|E8z=o3(0&C{ssO(11smJe0myQ;w{FdX)BasK5(B#w-xSxI1>cWa`J|>`hSmV} zc62<4VQWi#JlPX%yYk6;Fb_J`J&^kL-u`l$f`7%}HfmI`w$9#v`r^gCu@}HZM5>oY zVfOa*p}QHtIR2pZH^7cRO9RWvyB;tV@FxA$7XK0dE~r!)ipKQIXQ7r7EVJ07mV$gS+92m|DOzB(T5D`lFQ+HHhc5EIP!7@BpFQo zMrg>@+!Di?hTknfn;8^+kBp2g4wtgSF3B)(Az~sxNeRqji_;N5mvEv1Hyt>-!y+T` zQDYu3QsmpzkQ&RG{rAb{XA=`A|=gB{n5fqIE$B?RV>a{{qz~ig?gjEU?cHbPsOZ{0&d3)8&NJk_o!<1{dqto{6hLH zHHCrAO)?e=0)tz^H^=b~?&N(6na|MkEwgk&Wy}@Wzz}s2pG))nEfN{3r252g^+9mR_Sqd2hLbb0bFMr>(i8F~p9ge5B z=Wm%#7b&HM!%%@WEA#~1834ghTlIbL5-h)?c#;@daNCA#nrXU$C6XZ{`&(EyS6L}& zEU;ljd%caIoK4t*)d;K3U*WqV7|7>>Z>}6~FCA-CwyF(>cSXo6c zMhEYZHg$9;BiDRrE>R9&x@ep4dwdIKb0pGA$UZ=rmErtf?G>445p@f{OBS4&gL?Ry zK&~GjDsLW_Kyd@CP9WASxy>=pe(uRofV$r<{*CghX}fmV>lC0j47hGbJ5(=p7D zA?R-hWW0aY2eN`kYS-ZvaMDPpUXOuoJbS)w3;{QoM5M>hsj(gQvVS7_8h1FwPB9`# ziZO-Z%}OPqORYgwM)U=eV?@~1XC(qZDG57|S6B~S&D=Z*b`a#yPw`x#%mpxLO)tN? zUtGg5DH!Ljf0sEh=hdSHHz-31OjmfQ9cETzfK#L!>^le7BR$NB7vOqdx~#(tvRlfS z>&5!Wo*=EtVT(k>jaK48U0TVHCT?th87$y`X8=}b>I zh8zJV#8jU2jmY2gvM6Qd4aYRIpUl=#kBx_E?jw=^2BG+C@? 
z2;o>oP2v|N5yFfZro4Xljz@xS6FP{)pepRIOG+otY8=Ms-|nBs%J|9)oURH`_t<{Y ztjcdxLIw!}NB7(xI6=Kc_nR##7e4Oo(v7|4>In6`+@x~j{zW8&zq8jkML(xF>|HJ8 z2#y_~-H2sN!p1jw)bl460s_i+Z{2UL*Q*DI?bcb8Y^#*t`}n%A>sl-sB|!JH68j4lVFY!QTz)@)kgvb>t$N_9BH~LBO$vicEr(@Su(|p|UbRus@)#kj;UU>UJ)Yd}a!mS{B z&>K89@eE01Xn*lxAl1LMJ9|6~XJZd!Q>Kg=bZb*zbm-k-crMpC))944JV~p9?zns4 zX>PxjdY+cXi>An&--Ks=@qwqEz$X^#;h~694u2=z*0G}aZtrx)PS7Ib%L~i5-qkw2 zSiw96^K6da-Zkv3kD9khCHWQN!`>EmajIcYc{~KgxNN;l6!+KY&aByI$jen0Fh*CV zhil}80gC+^N9u0CH#55!qD;mG?)RJ*sntOZbQBWY=#PeZ7eE$}RQgW4OlH~G*QJ{F z(IoOpChm5Z*Zu4XB(3%iZeQez{q0$5zX z-=@>$G)%Dt2wZTD(Bar&th8$I@Xxc~gzU-AN<_I@4>`2-d67I-jGyO}DmOjEDse(c zG=<8XHD z1(riXI!f)_+kFU6ILRqcOb5HL?#R~Qn_u@$8M`MVH#?%@0<|jJ&@w7f=}}G2Z*53( zX5)rA#o{-&qSDNi#2P+X9akB;ed=3Y3p(br*UYOz#0^t6TAN0rdxJfRL}|UAa?-D_ z9~KQ(?0ORlS4zCF?b@5YSmc5S@~~!7PWSq{Wjz^nTS{?SHTZNak1S*M1OGIjJEsUV zBVIS}^qd%Nf2a=(ep$5|_7u;(>LM8o!xh$R)Xk*YFi2s0AFY}em5L_<_7A8)NNCaL z$*paUEF)z9LVPXUJ1c@}P=2yy-23Zan~pA_&?~I4S1rr%;{39rLYkrPQN7-L#huhE z8z2yJAcR^EI57TdQ$6bB%QoWK7_mt{P`GD+=Mzf^`%X;$SQ!>IsFmrV3lATf5@Aaz z2p)`~^v?3Hyntg?1yP-MB{1peq_v7Nea@kkAxdQ6&C(abWZyUiC}Z??Co{>*-y<9#L{xu3|1b(&`Gv!*s{4DHx1uN3*T+jdm5qLfBD<3U{MFeBP{ z6zt+tul(w`(vlm@W^3T=G7J$uaI#FNTP5)cgw;QY>e>rbxekIqn1~c`S^U&xqzi{0 zw~Dgz+?KgOfWq@V%ZT%Gczac96l%nXyWfgWL~&7Zu5L>4uv|3sJL(H&;3;;q z=`A_2;pS;G#6-i`AM+W2}XwWf}`}_LTwX@qt zSow1}Ufxds5R32q)(8YR z9q2@7r{J0;C1#swpI+MjmNtfaoht4Q))-;E#%nq`j`FUOEGV;`8~X zC}~pb6k33;RyT%JCQT=mW<(Hh8@sy8Rb54uM{q_+KF3>(?^83)n>~CIMXWx6p`GEH z<5yKJEsv#>x-@~kNnT`|p0+cNu;i=(kh1OOmu54<9>MnmLgLU@dAyH#%-zlHwDEhy zB)2+uAuF`u&#|-dPsOImDatYvTkT&>*WQn~oY#V!)WUaxW0PcG(1D@IQ`=2F2yBIL z^>F*k8+MFRx22kCkXunWHH9}xX?a@TwGObSU2WL9;L?b=h)=D0S9Dp7HXx{(nHuQ% z9*Wf%h1y!i>UEn;fxq+}tvgh7DBS?0>##AV&f{SrH+c_@ix{HVd)iC6`Z z__T`mTpk)1@keWS?jn}f(*x!E%-hUKh#rzm6jKKO#*RQH$lWr1;fx(R>`>$)WFys-41T1nXP&say5KQ0Q zsc|eBUteD$?E@c%fZ5@|g>hEG^GL2C5gmypj?w)2MAJ2LA~USQO~0_UTTR{VLO#?J z706aeA>b0B6;Sb*C60dE)}{fbFe-MlID9l+2ZZSIMSgAj;f!Fm0NUc9!|}4MjqRTT z@s&-0zki}oC=?0H;{@=iB@U61aZ6pPJ-hgwF}EwD5n%;1TdI1{Npf|p5>^-Wh{Ynz z-%2p2$zxQE)XEpt-rZ2g#gihy;C<319qaIzR-8pK&R8dxFmjvsJ*?CdU@}jIXLO(kOWGl(jK64o<25? 
zj=Eb5>;oYX%VgJ=U@R1aA>!cB-grAS)St!|R`)t`TV;z^r7tSTd8Oy+0Oo=x>Bx$+ zXZAy-rE3B&RUT8!J#3^RlQZ=fC65uy-ERXM9?a5GVd&tJR)&gdfqe|bYT{ml3}gvg z!PDN)ShA0qEan>ALi?K_V{k{A0P-Pi03tLiNh?3?_X(^t(KulWp+p3VDQ)4d+3b76 zDHcOBc3%f&ZT-LNk@m9HI{lggeKZuR;f>qz_L(B}usq|eFflEyxkLHg9jwaRd%IjD zUpslVPrLE)EujN+Vram0C#x-qzMv(BKJYOkKaDu6sk>SaD6O`&jl>dZ273;$7h+C8 zKa9H|ZJhmGLkw~?`vs{k{^5SRd00R5%@tS?`_$We^+rJRmG)sZ2~ zL?dYX_U&QlDWNxnK>s)+Bco%XuUL#2ZNN{!t0oE-?C1zX)QyrDKY5lGDh-riMBOEU zBw+h4h$s+G7-lKeg%QA{I{80Weg2IfKwY4Jc2#;T>@`C4ftOlZM&AF5sn^@zzfnaM zg9am$kv3zVXk8E;qAgVz!Z_|R_Ot-sO zTZy_^Qa);{X#)MYWv*n#zJ2jRjd*~1K>CRu zn+!Mv1Oz14U9xi>oSZO9XwLI(Ze%Gu#GAOQs;UM8-D)?)xRe=L+VBwryaFBABs!fg z0@`*WTm4kHY{r8+NU$Dm1K295oWbK`R_57XPVWEIsjX@f;^5$6>>oTh14DXS<%-cO zC9h)m&*Ir$xyvTYSoP~2WCfX8=uvI@)Wd+|;c<@T+sv}#)ulAx7a5+w6y)(4mUu|f z@N)R>Yvotvl4Q6>-k^UL^_@}V^!1PYHS(7)X9j6njRO~GOn^ckl;RbiM`Ot~OE~St zBiZ%XrRHwyRm|noG%l|jiD z(z7SaI)NZ_UChctJT)>Gkm_@Q2;soHsIK~b>{vs*d6P^+SlGEK3Jl;RL+?q%4A;rK z-G)C8;45Uxmq3F`APg>`GCy1eD07+%g52rd0QK7T$T3Q>Nt1PkCFv^TKYr|6^TEn| zhnc=&e9SQe&xA|sbld^eMb+(<9SOOmJJo~72j;cc)mage8kRZ$V|ZkIe{$Pf{b=-LysFbSd?CtC1Uf&|S9Fl0G}{ZUa-fx*BqU7!p;aYEiIVS9I*18xZSCxh_36-&T)PaV zr;QeZcK6`y;^&aCeQk3gxKC;OM5xP?A3&1AzDdB11%*2P%4=*JdQ; z9>8IgHytU=K~ai|iVBQ@NP1woRix8n5I84#hjK&`2tQ1fmW?rqCL~v)mi1l8c>Vf% z+*#i~{&j4&iwqtzsX_HD?ky?+C~Dt(0j literal 0 HcmV?d00001 diff --git a/log/best_performance/epoch_loss_class_train.txt b/log/best_performance/epoch_loss_class_train.txt new file mode 100644 index 0000000..515fdfc --- /dev/null +++ b/log/best_performance/epoch_loss_class_train.txt @@ -0,0 +1,90 @@ +3.7712276314242685 +3.0636168965986927 +2.801343619823456 +2.6331870392384205 +2.5135307890929712 +2.4291697614202716 +2.3555361779060604 +2.2984420201291202 +2.2522135552288742 +2.197252894587989 +2.097349743214381 +1.9542330193516024 +1.7952129174223848 +1.6156651931379018 +1.4707137380617539 +1.3638808261168214 +1.2811493240666694 +1.2145853494477985 +1.1464360410734533 +1.0886454173869087 +1.0415393155084143 +1.0067537130545055 +0.9712598563652918 +0.9459512437851009 +0.9096426171339108 +0.883714666565097 +0.8630674364239435 +0.8406433062055244 +0.820840925855913 +0.8079694209025758 +0.7841788579339185 +0.776262023339371 +0.75239310020289 +0.7407916599995307 +0.7314519268123639 +0.7205715374548911 +0.7119348809259342 +0.6838741163310618 +0.675083714229081 +0.6665253263740542 +0.40875943099878626 +0.33014846760375033 +0.301763350591267 +0.27762719021237026 +0.26088881696264316 +0.2430464035457363 +0.2282850744710909 +0.2212755358291129 +0.203683311874431 +0.1944100352668947 +0.18915298141420142 +0.18149443655029596 +0.17873722235375974 +0.171926496180016 +0.164108280483452 +0.1580632201443505 +0.1529976990462558 +0.14594485707745689 +0.1456607071607638 +0.14255504573110114 +0.11923092203425435 +0.10815973842863215 +0.1090422415932918 +0.10648260271793217 +0.09845620263738979 +0.100937618059459 +0.10059720761977203 +0.09929326052758904 +0.09756028030852115 +0.09850751669961134 +0.09668084698679134 +0.0974692665948768 +0.09672294503566119 +0.09527862614642296 +0.09309141301707675 +0.09378876050791575 +0.094307657253778 +0.09193214094837932 +0.09031117155016503 +0.08870918162732286 +0.08976529900305628 +0.08807121031712889 +0.08787393360298594 +0.09003039282064657 +0.08867947479501018 +0.08774756830658956 +0.08815590344746861 +0.08792258952912853 +0.09093503984912055 +0.08763042784054735 diff --git a/log/data_tree.log b/log/data_tree.log index 3bfbe88..f98e62c 100644 --- a/log/data_tree.log +++ b/log/data_tree.log @@ -1,16 +1,16 @@ -data -|-- 
NTU-RGB+D -| `-- samples_with_missing_skeletons.txt -`-- nturgb_d - |-- xsub - | |-- train_data_joint_pad.npy - | |-- train_label.pkl - | |-- val_data_joint_pad.npy - | `-- val_label.pkl - `-- xview - |-- train_data_joint_pad.npy - |-- train_label.pkl - |-- val_data_joint_pad.npy - `-- val_label.pkl - -4 directories, 9 files +data +|-- NTU-RGB+D +| `-- samples_with_missing_skeletons.txt +`-- nturgb_d + |-- xsub + | |-- train_data_joint_pad.npy + | |-- train_label.pkl + | |-- val_data_joint_pad.npy + | `-- val_label.pkl + `-- xview + |-- train_data_joint_pad.npy + |-- train_label.pkl + |-- val_data_joint_pad.npy + `-- val_label.pkl + +4 directories, 9 files diff --git a/log/train_aim.log b/log/train_aim.log index 1f84ef7..30a1033 100644 --- a/log/train_aim.log +++ b/log/train_aim.log @@ -1,21 +1,21 @@ -$python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 - -/root/AS-GCN/processor/io.py:34: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details. - default_arg = yaml.load(f) -/root/AS-GCN/net/utils/adj_learn.py:18: UserWarning: This overload of nonzero is deprecated: - nonzero() -Consider using one of the following signatures instead: - nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) - offdiag_indices = (ones - eye).nonzero().t() -/root/anaconda3/envs/stgcn/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported. - warnings.warn("Setting attributes on ParameterList is not supported.") -/root/AS-GCN/net/utils/adj_learn.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. - soft_max_1d = F.softmax(trans_input) - -[05.08.21|19:25:22] Parameters: -{'work_dir': './work_dir/recognition/ntu-xsub/AS_GCN', 'config': 'config/as_gcn/ntu-xsub/train_aim.yaml', 'phase': 'train', 'save_result': False, 'start_epoch': 0, 'num_epoch': 10, 'use_gpu': True, 'device': [0, 1, 2], 'log_interval': 100, 'save_interval': 1, 'eval_interval': 5, 'save_log': True, 'print_log': True, 'pavi_log': False, 'feeder': 'feeder.feeder.Feeder', 'num_worker': 4, 'train_feeder_args': {'data_path': './data/nturgb_d/xsub/train_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/train_label.pkl', 'random_move': True, 'repeat_pad': True, 'down_sample': True, 'debug': False}, 'test_feeder_args': {'data_path': './data/nturgb_d/xsub/val_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/val_label.pkl', 'random_move': False, 'repeat_pad': True, 'down_sample': True}, 'batch_size': 32, 'test_batch_size': 32, 'debug': False, 'model1': 'net.as_gcn.Model', 'model2': 'net.utils.adj_learn.AdjacencyLearn', 'model1_args': {'in_channels': 3, 'num_class': 60, 'dropout': 0.5, 'edge_importance_weighting': True, 'graph_args': {'layout': 'ntu-rgb+d', 'strategy': 'spatial', 'max_hop': 4}}, 'model2_args': {'n_in_enc': 150, 'n_hid_enc': 128, 'edge_types': 3, 'n_in_dec': 3, 'n_hid_dec': 128, 'node_num': 25}, 'weights1': None, 'weights2': None, 'ignore_weights': [], 'show_topk': [1, 5], 'base_lr1': 0.1, 'base_lr2': 0.0005, 'step': [50, 70, 90], 'optimizer': 'SGD', 'nesterov': True, 'weight_decay': 0.0001, 'max_hop_dir': 'max_hop_4', 'lamda_act': 0.5, 'lamda_act_dir': 'lamda_05'} - -[05.08.21|19:25:22] Training epoch: 0 -[05.08.21|19:25:29] Iter 0 Done. 
| loss2: 876.8732 | loss_nll: 832.9621 | loss_kl: 43.9111 | lr: 0.000500 -[05.08.21|19:26:56] Iter 100 Done. | loss2: 118.9051 | loss_nll: 110.3876 | loss_kl: 8.5176 | lr: 0.000500 -[05.08.21|19:28:14] Iter 200 Done. | loss2: 76.1775 | loss_nll: 71.7404 | loss_kl: 4.4371 | lr: 0.000500 +$python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 + +/root/AS-GCN/processor/io.py:34: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details. + default_arg = yaml.load(f) +/root/AS-GCN/net/utils/adj_learn.py:18: UserWarning: This overload of nonzero is deprecated: + nonzero() +Consider using one of the following signatures instead: + nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) + offdiag_indices = (ones - eye).nonzero().t() +/root/anaconda3/envs/stgcn/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported. + warnings.warn("Setting attributes on ParameterList is not supported.") +/root/AS-GCN/net/utils/adj_learn.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. + soft_max_1d = F.softmax(trans_input) + +[05.08.21|19:25:22] Parameters: +{'work_dir': './work_dir/recognition/ntu-xsub/AS_GCN', 'config': 'config/as_gcn/ntu-xsub/train_aim.yaml', 'phase': 'train', 'save_result': False, 'start_epoch': 0, 'num_epoch': 10, 'use_gpu': True, 'device': [0, 1, 2], 'log_interval': 100, 'save_interval': 1, 'eval_interval': 5, 'save_log': True, 'print_log': True, 'pavi_log': False, 'feeder': 'feeder.feeder.Feeder', 'num_worker': 4, 'train_feeder_args': {'data_path': './data/nturgb_d/xsub/train_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/train_label.pkl', 'random_move': True, 'repeat_pad': True, 'down_sample': True, 'debug': False}, 'test_feeder_args': {'data_path': './data/nturgb_d/xsub/val_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/val_label.pkl', 'random_move': False, 'repeat_pad': True, 'down_sample': True}, 'batch_size': 32, 'test_batch_size': 32, 'debug': False, 'model1': 'net.as_gcn.Model', 'model2': 'net.utils.adj_learn.AdjacencyLearn', 'model1_args': {'in_channels': 3, 'num_class': 60, 'dropout': 0.5, 'edge_importance_weighting': True, 'graph_args': {'layout': 'ntu-rgb+d', 'strategy': 'spatial', 'max_hop': 4}}, 'model2_args': {'n_in_enc': 150, 'n_hid_enc': 128, 'edge_types': 3, 'n_in_dec': 3, 'n_hid_dec': 128, 'node_num': 25}, 'weights1': None, 'weights2': None, 'ignore_weights': [], 'show_topk': [1, 5], 'base_lr1': 0.1, 'base_lr2': 0.0005, 'step': [50, 70, 90], 'optimizer': 'SGD', 'nesterov': True, 'weight_decay': 0.0001, 'max_hop_dir': 'max_hop_4', 'lamda_act': 0.5, 'lamda_act_dir': 'lamda_05'} + +[05.08.21|19:25:22] Training epoch: 0 +[05.08.21|19:25:29] Iter 0 Done. | loss2: 876.8732 | loss_nll: 832.9621 | loss_kl: 43.9111 | lr: 0.000500 +[05.08.21|19:26:56] Iter 100 Done. | loss2: 118.9051 | loss_nll: 110.3876 | loss_kl: 8.5176 | lr: 0.000500 +[05.08.21|19:28:14] Iter 200 Done. 
| loss2: 76.1775 | loss_nll: 71.7404 | loss_kl: 4.4371 | lr: 0.000500 diff --git a/main.py b/main.py index bd402cb..c81bb7b 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,24 @@ -import argparse -import sys -import torchlight -from torchlight.io import import_class - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='Processor collection') - - processors = dict() - processors['recognition'] = import_class('processor.recognition.REC_Processor') - #processors['demo'] = import_class('processor.demo.Demo') - - subparsers = parser.add_subparsers(dest='processor') - for k, p in processors.items(): - subparsers.add_parser(k, parents=[p.get_parser()]) - - arg = parser.parse_args() - - # start - Processor = processors[arg.processor] - p = Processor(sys.argv[2:]) - p.start() +import argparse +import sys +import torchlight +from torchlight.io import import_class + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='Processor collection') + + processors = dict() + processors['recognition'] = import_class('processor.recognition.REC_Processor') + #processors['demo'] = import_class('processor.demo.Demo') + + subparsers = parser.add_subparsers(dest='processor') + for k, p in processors.items(): + subparsers.add_parser(k, parents=[p.get_parser()]) + + arg = parser.parse_args() + + # start + Processor = processors[arg.processor] + p = Processor(sys.argv[2:]) + p.start() diff --git a/net/__init__.py b/net/__init__.py index 8b13789..d3f5a12 100644 --- a/net/__init__.py +++ b/net/__init__.py @@ -1 +1 @@ - + diff --git a/net/as_gcn.py b/net/as_gcn.py index 7468be4..f49b08d 100644 --- a/net/as_gcn.py +++ b/net/as_gcn.py @@ -1,308 +1,307 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable - -from net.utils.graph import Graph - - -class Model(nn.Module): - - def __init__(self, in_channels, num_class, graph_args, - edge_importance_weighting, **kwargs): - super().__init__() - - self.graph = Graph(**graph_args) - A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) - self.register_buffer('A', A) - self.edge_type = 2 - - temporal_kernel_size = 9 - spatial_kernel_size = A.size(0) + self.edge_type - st_kernel_size = (temporal_kernel_size, spatial_kernel_size) - - self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) - - self.class_layer_0 = StgcnBlock(in_channels, 64, st_kernel_size, self.edge_type, stride=1, residual=False, **kwargs) - self.class_layer_1 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_2 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_3 = StgcnBlock(64, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.class_layer_4 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_5 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_6 = StgcnBlock(128, 256, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.class_layer_7 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_8 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) - - self.recon_layer_0 = StgcnBlock(256, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.recon_layer_1 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.recon_layer_2 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.recon_layer_3 
= StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.recon_layer_4 = StgcnBlock(128, 128, (3, spatial_kernel_size), self.edge_type, stride=2, **kwargs) - self.recon_layer_5 = StgcnBlock(128, 128, (5, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, **kwargs) - self.recon_layer_6 = StgcnReconBlock(128+3, 30, (1, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, activation=None, **kwargs) - - - if edge_importance_weighting: - self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)]) - self.edge_importance_recon = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)]) - else: - self.edge_importance = [1] * (len(self.st_gcn_networks)+len(self.st_gcn_recon)) - self.fcn = nn.Conv2d(256, num_class, kernel_size=1) - - def forward(self, x, x_target, x_last, A_act, lamda_act): - - N, C, T, V, M = x.size() - x_recon = x[:,:,:,:,0] # [2N, 3, 300, 25] - x = x.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 300] - x = x.view(N * M, V * C, T) # [2N, 75, 300] - - x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) - - x_bn = self.data_bn(x) - x_bn = x_bn.view(N, M, V, C, T) - x_bn = x_bn.permute(0, 1, 3, 4, 2).contiguous() - x_bn = x_bn.view(N * M, C, T, V) - - h0, _ = self.class_layer_0(x_bn, self.A * self.edge_importance[0], A_act, lamda_act) # [N, 64, 300, 25] - h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] - h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] - h2, _ = self.class_layer_2(h1, self.A * self.edge_importance[2], A_act, lamda_act) # [N, 64, 300, 25] - h3, _ = self.class_layer_3(h2, self.A * self.edge_importance[3], A_act, lamda_act) # [N, 128, 150, 25] - h4, _ = self.class_layer_4(h3, self.A * self.edge_importance[4], A_act, lamda_act) # [N, 128, 150, 25] - h5, _ = self.class_layer_5(h4, self.A * self.edge_importance[5], A_act, lamda_act) # [N, 128, 150, 25] - h6, _ = self.class_layer_6(h5, self.A * self.edge_importance[6], A_act, lamda_act) # [N, 256, 75, 25] - h7, _ = self.class_layer_7(h6, self.A * self.edge_importance[7], A_act, lamda_act) # [N, 256, 75, 25] - h8, _ = self.class_layer_8(h7, self.A * self.edge_importance[8], A_act, lamda_act) # [N, 256, 75, 25] - - x_class = F.avg_pool2d(h8, h8.size()[2:]) - x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) - x_class = self.fcn(x_class) - x_class = x_class.view(x_class.size(0), -1) - - r0, _ = self.recon_layer_0(h8, self.A*self.edge_importance_recon[0], A_act, lamda_act) # [N, 128, 75, 25] - r1, _ = self.recon_layer_1(r0, self.A*self.edge_importance_recon[1], A_act, lamda_act) # [N, 128, 38, 25] - r2, _ = self.recon_layer_2(r1, self.A*self.edge_importance_recon[2], A_act, lamda_act) # [N, 128, 19, 25] - r3, _ = self.recon_layer_3(r2, self.A*self.edge_importance_recon[3], A_act, lamda_act) # [N, 128, 10, 25] - r4, _ = self.recon_layer_4(r3, self.A*self.edge_importance_recon[4], A_act, lamda_act) # [N, 128, 5, 25] - r5, _ = self.recon_layer_5(r4, self.A*self.edge_importance_recon[5], A_act, lamda_act) # [N, 128, 1, 25] - r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act) # [N, 64, 1, 25] - pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze() # [N, 3, 25] - - pred = pred.contiguous().view(-1, 3, 10, 25) - x_target = x_target.permute(0,4,1,2,3).contiguous().view(-1,3,10,25) - - return x_class, pred[::2], 
x_target[::2] - - def extract_feature(self, x): - - N, C, T, V, M = x.size() - x = x.permute(0, 4, 3, 1, 2).contiguous() - x = x.view(N * M, V * C, T) - x = self.data_bn(x) - x = x.view(N, M, V, C, T) - x = x.permute(0, 1, 3, 4, 2).contiguous() - x = x.view(N * M, C, T, V) - - for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): - x, _ = gcn(x, self.A * importance) - - _, c, t, v = x.size() - feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) - - x = self.fcn(x) - output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1) - - return output, feature - - -class StgcnBlock(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - edge_type=2, - t_kernel_size=1, - stride=1, - padding=True, - dropout=0, - residual=True): - super().__init__() - - assert len(kernel_size) == 2 - assert kernel_size[0] % 2 == 1 - if padding == True: - padding = ((kernel_size[0] - 1) // 2, 0) - else: - padding = (0,0) - - self.gcn = SpatialGcn(in_channels=in_channels, - out_channels=out_channels, - k_num=kernel_size[1], - edge_type=edge_type, - t_kernel_size=t_kernel_size) - self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, - out_channels, - (kernel_size[0], 1), - (stride, 1), - padding), - nn.BatchNorm2d(out_channels), - nn.Dropout(dropout, inplace=True)) - - if not residual: - self.residual = lambda x: 0 - elif (in_channels == out_channels) and (stride == 1): - self.residual = lambda x: x - else: - self.residual = nn.Sequential(nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=(stride, 1)), - nn.BatchNorm2d(out_channels)) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x, A, B, lamda_act): - - res = self.residual(x) - x, A = self.gcn(x, A, B, lamda_act) - x = self.tcn(x) + res - - return self.relu(x), A - - -class StgcnReconBlock(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - edge_type=2, - t_kernel_size=1, - stride=1, - padding=True, - dropout=0, - residual=True, - activation='relu'): - super().__init__() - - assert len(kernel_size) == 2 - assert kernel_size[0] % 2 == 1 - - if padding == True: - padding = ((kernel_size[0] - 1) // 2, 0) - else: - padding = (0,0) - - self.gcn_recon = SpatialGcnRecon(in_channels=in_channels, - out_channels=out_channels, - k_num=kernel_size[1], - edge_type=edge_type, - t_kernel_size=t_kernel_size) - self.tcn_recon = nn.Sequential(nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(in_channels=out_channels, - out_channels=out_channels, - kernel_size=(kernel_size[0], 1), - stride=(stride, 1), - padding=padding, - output_padding=(stride-1,0)), - nn.BatchNorm2d(out_channels), - nn.Dropout(dropout, inplace=True)) - - if not residual: - self.residual = lambda x: 0 - elif (in_channels == out_channels) and (stride == 1): - self.residual = lambda x: x - else: - self.residual = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=(stride, 1), - output_padding=(stride-1,0)), - nn.BatchNorm2d(out_channels)) - self.relu = nn.ReLU(inplace=True) - self.activation = activation - - def forward(self, x, A, B, lamda_act): - - res = self.residual(x) - x, A = self.gcn_recon(x, A, B, lamda_act) - x = self.tcn_recon(x) + res - if self.activation == 'relu': - x = self.relu(x) - else: - x = x - - return x, A - - -class SpatialGcn(nn.Module): - - def __init__(self, - in_channels, - out_channels, - k_num, - edge_type=2, - t_kernel_size=1, - t_stride=1, - t_padding=0, - 
t_dilation=1, - bias=True): - super().__init__() - - self.k_num = k_num - self.edge_type = edge_type - self.conv = nn.Conv2d(in_channels=in_channels, - out_channels=out_channels*k_num, - kernel_size=(t_kernel_size, 1), - padding=(t_padding, 0), - stride=(t_stride, 1), - dilation=(t_dilation, 1), - bias=bias) - - def forward(self, x, A, B, lamda_act): - - x = self.conv(x) - n, kc, t, v = x.size() - x = x.view(n, self.k_num, kc//self.k_num, t, v) - x1 = x[:,:self.k_num-self.edge_type,:,:,:] - x2 = x[:,-self.edge_type:,:,:,:] - x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) - x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) - x_sum = x1+x2*lamda_act - - return x_sum.contiguous(), A - - -class SpatialGcnRecon(nn.Module): - - def __init__(self, in_channels, out_channels, k_num, edge_type=3, - t_kernel_size=1, t_stride=1, t_padding=0, t_outpadding=0, t_dilation=1, - bias=True): - super().__init__() - - self.k_num = k_num - self.edge_type = edge_type - self.deconv = nn.ConvTranspose2d(in_channels=in_channels, - out_channels=out_channels*k_num, - kernel_size=(t_kernel_size, 1), - padding=(t_padding, 0), - output_padding=(t_outpadding, 0), - stride=(t_stride, 1), - dilation=(t_dilation, 1), - bias=bias) - - def forward(self, x, A, B, lamda_act): - - x = self.deconv(x) - n, kc, t, v = x.size() - x = x.view(n, self.k_num, kc//self.k_num, t, v) - x1 = x[:,:self.k_num-self.edge_type,:,:,:] - x2 = x[:,-self.edge_type:,:,:,:] - x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) - x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) - x_sum = x1+x2*lamda_act - - return x_sum.contiguous(), A +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +from net.utils.graph import Graph + + +class Model(nn.Module): + + def __init__(self, in_channels, num_class, graph_args, + edge_importance_weighting, **kwargs): + super().__init__() + + self.graph = Graph(**graph_args) + A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) + self.register_buffer('A', A) + self.edge_type = 2 + + temporal_kernel_size = 9 + spatial_kernel_size = A.size(0) + self.edge_type + st_kernel_size = (temporal_kernel_size, spatial_kernel_size) + + self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + + self.class_layer_0 = StgcnBlock(in_channels, 64, st_kernel_size, self.edge_type, stride=1, residual=False, **kwargs) + self.class_layer_1 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_2 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_3 = StgcnBlock(64, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.class_layer_4 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_5 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_6 = StgcnBlock(128, 256, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.class_layer_7 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_8 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) + + self.recon_layer_0 = StgcnBlock(256, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.recon_layer_1 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.recon_layer_2 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.recon_layer_3 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.recon_layer_4 = 
StgcnBlock(128, 128, (3, spatial_kernel_size), self.edge_type, stride=2, **kwargs)
+        self.recon_layer_5 = StgcnBlock(128, 128, (5, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, **kwargs)
+        self.recon_layer_6 = StgcnReconBlock(128+3, 30, (1, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, activation=None, **kwargs)
+
+        if edge_importance_weighting:
+            self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)])
+            self.edge_importance_recon = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)])
+        else:
+            # constant (unlearned) edge weights, one per classification/reconstruction block
+            self.edge_importance = [1] * 9
+            self.edge_importance_recon = [1] * 9
+        self.fcn = nn.Conv2d(256, num_class, kernel_size=1)
+
+    def forward(self, x, x_target, x_last, A_act, lamda_act):
+        N, C, T, V, M = x.size()
+        x_recon = x[:,:,:,:,0]                     # [2N, 3, 300, 25] wsx: x_recon (4,3,290,25), the first person's data?
+        x = x.permute(0, 4, 3, 1, 2).contiguous()  # [N, 2, 25, 3, 300] wsx: x (4,2,25,3,290)
+        x = x.view(N * M, V * C, T)                # [2N, 75, 300] wsx: x (8,75,290)
+
+        x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25)  # (2N,3,1,25)
+
+        x_bn = self.data_bn(x)
+        x_bn = x_bn.view(N, M, V, C, T)
+        x_bn = x_bn.permute(0, 1, 3, 4, 2).contiguous()
+        x_bn = x_bn.view(N * M, C, T, V)  # (2N,3,290,25)
+
+        h0, _ = self.class_layer_0(x_bn, self.A * self.edge_importance[0], A_act, lamda_act)  # [N, 64, 300, 25]
+        h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act)    # [N, 64, 300, 25]
+        h2, _ = self.class_layer_2(h1, self.A * self.edge_importance[2], A_act, lamda_act)    # [N, 64, 300, 25]
+        h3, _ = self.class_layer_3(h2, self.A * self.edge_importance[3], A_act, lamda_act)    # [N, 128, 150, 25]
+        h4, _ = self.class_layer_4(h3, self.A * self.edge_importance[4], A_act, lamda_act)    # [N, 128, 150, 25]
+        h5, _ = self.class_layer_5(h4, self.A * self.edge_importance[5], A_act, lamda_act)    # [N, 128, 150, 25]
+        h6, _ = self.class_layer_6(h5, self.A * self.edge_importance[6], A_act, lamda_act)    # [N, 256, 75, 25]
+        h7, _ = self.class_layer_7(h6, self.A * self.edge_importance[7], A_act, lamda_act)    # [N, 256, 75, 25]
+        h8, _ = self.class_layer_8(h7, self.A * self.edge_importance[8], A_act, lamda_act)    # [N, 256, 75, 25]
+
+        x_class = F.avg_pool2d(h8, h8.size()[2:])           # (8,256,1,1)
+        x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1)  # (4,256,1,1)
+        x_class = self.fcn(x_class)                         # (4,60,1,1) Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1))
+        x_class = x_class.view(x_class.size(0), -1)         # (4,60)
+
+        r0, _ = self.recon_layer_0(h8, self.A*self.edge_importance_recon[0], A_act, lamda_act)  # [N, 128, 75, 25]
+        r1, _ = self.recon_layer_1(r0, self.A*self.edge_importance_recon[1], A_act, lamda_act)  # [N, 128, 38, 25]
+        r2, _ = self.recon_layer_2(r1, self.A*self.edge_importance_recon[2], A_act, lamda_act)  # [N, 128, 19, 25]
+        r3, _ = self.recon_layer_3(r2, self.A*self.edge_importance_recon[3], A_act, lamda_act)  # [N, 128, 10, 25]
+        r4, _ = self.recon_layer_4(r3, self.A*self.edge_importance_recon[4], A_act, lamda_act)  # [N, 128, 5, 25]
+        r5, _ = self.recon_layer_5(r4, self.A*self.edge_importance_recon[5], A_act, lamda_act)  # [N, 128, 1, 25]
+        r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act)  # [N, 30, 1, 25] wsx: (8,30,1,25)
+        pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze()  # [N, 30, 25] wsx: (8,30,25)
+
+        pred
= pred.contiguous().view(-1, 3, 10, 25) + x_target = x_target.permute(0,4,1,2,3).contiguous().view(-1,3,10,25) + + return x_class, pred[::2], x_target[::2] + + def extract_feature(self, x): + + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): + x, _ = gcn(x, self.A * importance) + + _, c, t, v = x.size() + feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) + + x = self.fcn(x) + output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1) + + return output, feature + + +class StgcnBlock(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + edge_type=2, + t_kernel_size=1, + stride=1, + padding=True, + dropout=0, + residual=True): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + if padding == True: + padding = ((kernel_size[0] - 1) // 2, 0) + else: + padding = (0,0) + + self.gcn = SpatialGcn(in_channels=in_channels, + out_channels=out_channels, + k_num=kernel_size[1], + edge_type=edge_type, + t_kernel_size=t_kernel_size) + self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, + out_channels, + (kernel_size[0], 1), + (stride, 1), + padding), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True)) + + if not residual: + self.residual = lambda x: 0 + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + else: + self.residual = nn.Sequential(nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=(stride, 1)), + nn.BatchNorm2d(out_channels)) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, A, B, lamda_act): + + res = self.residual(x) + x, A = self.gcn(x, A, B, lamda_act) + x = self.tcn(x) + res + + return self.relu(x), A + + +class StgcnReconBlock(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + edge_type=2, + t_kernel_size=1, + stride=1, + padding=True, + dropout=0, + residual=True, + activation='relu'): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + + if padding == True: + padding = ((kernel_size[0] - 1) // 2, 0) + else: + padding = (0,0) + + self.gcn_recon = SpatialGcnRecon(in_channels=in_channels, + out_channels=out_channels, + k_num=kernel_size[1], + edge_type=edge_type, + t_kernel_size=t_kernel_size) + self.tcn_recon = nn.Sequential(nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(kernel_size[0], 1), + stride=(stride, 1), + padding=padding, + output_padding=(stride-1,0)), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True)) + + if not residual: + self.residual = lambda x: 0 + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + else: + self.residual = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=(stride, 1), + output_padding=(stride-1,0)), + nn.BatchNorm2d(out_channels)) + self.relu = nn.ReLU(inplace=True) + self.activation = activation + + def forward(self, x, A, B, lamda_act): + + res = self.residual(x) + x, A = self.gcn_recon(x, A, B, lamda_act) + x = self.tcn_recon(x) + res + if self.activation == 'relu': + x = self.relu(x) + else: + x = x + + return x, A + + +class 
SpatialGcn(nn.Module): + + def __init__(self, + in_channels, + out_channels, + k_num, + edge_type=2, + t_kernel_size=1, + t_stride=1, + t_padding=0, + t_dilation=1, + bias=True): + super().__init__() + + self.k_num = k_num + self.edge_type = edge_type + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels*k_num, + kernel_size=(t_kernel_size, 1), + padding=(t_padding, 0), + stride=(t_stride, 1), + dilation=(t_dilation, 1), + bias=bias) + + def forward(self, x, A, B, lamda_act): + + x = self.conv(x) + n, kc, t, v = x.size() + x = x.view(n, self.k_num, kc//self.k_num, t, v) + x1 = x[:,:self.k_num-self.edge_type,:,:,:] + x2 = x[:,-self.edge_type:,:,:,:] + x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) + x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) + x_sum = x1+x2*lamda_act + + return x_sum.contiguous(), A + + +class SpatialGcnRecon(nn.Module): + + def __init__(self, in_channels, out_channels, k_num, edge_type=3, + t_kernel_size=1, t_stride=1, t_padding=0, t_outpadding=0, t_dilation=1, + bias=True): + super().__init__() + + self.k_num = k_num + self.edge_type = edge_type + self.deconv = nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels*k_num, + kernel_size=(t_kernel_size, 1), + padding=(t_padding, 0), + output_padding=(t_outpadding, 0), + stride=(t_stride, 1), + dilation=(t_dilation, 1), + bias=bias) + + def forward(self, x, A, B, lamda_act): + + x = self.deconv(x) + n, kc, t, v = x.size() + x = x.view(n, self.k_num, kc//self.k_num, t, v) + x1 = x[:,:self.k_num-self.edge_type,:,:,:] + x2 = x[:,-self.edge_type:,:,:,:] + x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) + x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) + x_sum = x1+x2*lamda_act + + return x_sum.contiguous(), A diff --git a/net/model_poseformer.py b/net/model_poseformer.py new file mode 100644 index 0000000..d702be4 --- /dev/null +++ b/net/model_poseformer.py @@ -0,0 +1,223 @@ +## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +import math +import logging +from functools import partial +from collections import OrderedDict +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = 
nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class PoseTransformer(nn.Module): + def __init__(self, num_frame=9, num_joints=25, in_chans=3, embed_dim_ratio: object = 32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, + num_class=60 + ): + """ ##########hybrid_backbone=None, representation_size=None, + Args: + num_frame (int, tuple): input frame number + num_joints (int, tuple): joints number + in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) + embed_dim_ratio (int): embedding dimension ratio + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + num_class (int): the pose action class amount 30 + """ + super().__init__() + + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio + out_dim = num_joints * 3 #### output dimension is num_joints * 3 + + ### spatial patch embedding + self.Spatial_patch_to_embedding = nn.Linear(3, 32) + self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.Spatial_blocks = nn.ModuleList([ + Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in 
range(depth)]) + + self.Spatial_norm = norm_layer(embed_dim_ratio) + self.Temporal_norm = norm_layer(embed_dim) + + ####### A easy way to implement weighted mean + self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1) + + self.head = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim , out_dim), + ) + + # wsx aciton_class_head + self.action_class_head = nn.Conv2d(290, num_class, kernel_size=1) + + # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + self.data_bn = nn.BatchNorm1d(3 * 25) + + + + + + def Spatial_forward_features(self, x): + b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints + x = rearrange(x, 'b c f p -> (b f) p c', ) + + x = self.Spatial_patch_to_embedding(x) + x += self.Spatial_pos_embed + x = self.pos_drop(x) + + for blk in self.Spatial_blocks: + x = blk(x) + + x = self.Spatial_norm(x) + x = rearrange(x, '(b f) w c -> b f (w c)', f=f) + return x + + def forward_features(self, x): + b = x.shape[0] + x += self.Temporal_pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.Temporal_norm(x) + ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame + # x = self.weighted_mean(x) #wsx don't change all frame to one + # x = x.view(b, 1, -1) + return x + + def forward(self, x, x_target): + ''' + # x input shape [170, 81, 17, 2] + x = x.permute(0, 3, 1, 2) #[170, 2, 81, 17] + b, _, _, p = x.shape #[170, 2, 81, 17] b:batch_size p:joint_num + ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data + ''' + + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + x = self.Spatial_forward_features(x) + x = self.forward_features(x) # (2n, 290,800) + + # action_class_head + BatchN, FrameN, FutureN = x.size() + x = x.view(BatchN, FrameN, FutureN, 1) + x_class = F.avg_pool2d(x, x.size()[2:]) + x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) + x_class = self.action_class_head(x_class) + x_class = x_class.view(x_class.size(0), -1) + + + #action_class = x.permute(0,2,1) #[170, 544, 1] + #action_class = self.action_class_head(action_class) + #action_class = torch.squeeze(action_class) + #x = self.head(x) + #x = x.view(b, 1, p, -1) + + x_target = x_target.permute(0, 4, 1, 2, 3).contiguous().view(-1, 3, 10, 25) + + return x_class, x_target[::2] # [170,1,17,3] + diff --git a/net/utils/__init__.py b/net/utils/__init__.py index 8b13789..d3f5a12 100644 --- a/net/utils/__init__.py +++ b/net/utils/__init__.py @@ -1 +1 @@ - + diff --git a/net/utils/adj_learn.py b/net/utils/adj_learn.py index ea8e503..4580873 100644 --- a/net/utils/adj_learn.py +++ b/net/utils/adj_learn.py @@ -1,283 +1,284 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -import numpy as np -from torch.autograd import Variable - - -def my_softmax(input, axis=1): - trans_input = input.transpose(axis, 0).contiguous() - soft_max_1d = F.softmax(trans_input) - return soft_max_1d.transpose(axis, 0) - - -def get_offdiag_indices(num_nodes): - ones = torch.ones(num_nodes, num_nodes) - eye = torch.eye(num_nodes, num_nodes) - offdiag_indices = (ones - eye).nonzero().t() - offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1] - return offdiag_indices, offdiag_indices_ - - -def gumbel_softmax(logits, tau=1, 
hard=False, eps=1e-10): - y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) - if hard: - shape = logits.size() - _, k = y_soft.data.max(-1) - y_hard = torch.zeros(*shape) - if y_soft.is_cuda: - y_hard = y_hard.cuda() - y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) - y = Variable(y_hard - y_soft.data) + y_soft - else: - y = y_soft - return y - - -def gumbel_softmax_sample(logits, tau=1, eps=1e-10): - gumbel_noise = sample_gumbel(logits.size(), eps=eps) - if logits.is_cuda: - gumbel_noise = gumbel_noise.cuda() - y = logits + Variable(gumbel_noise) - return my_softmax(y / tau, axis=-1) - - -def sample_gumbel(shape, eps=1e-10): - uniform = torch.rand(shape).float() - return - torch.log(eps - torch.log(uniform + eps)) - - -def encode_onehot(labels): - classes = set(labels) - classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)} - labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32) - return labels_onehot - - -class MLP(nn.Module): - - def __init__(self, n_in, n_hid, n_out, do_prob=0.): - super().__init__() - - self.fc1 = nn.Linear(n_in, n_hid) - self.fc2 = nn.Linear(n_hid, n_out) - self.bn = nn.BatchNorm1d(n_out) - self.dropout = nn.Dropout(p=do_prob) - - self.init_weights() - - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight.data) - m.bias.data.fill_(0.1) - elif isinstance(m, nn.BatchNorm1d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def batch_norm(self, inputs): - x = inputs.view(inputs.size(0) * inputs.size(1), -1) - x = self.bn(x) - return x.view(inputs.size(0), inputs.size(1), -1) - - def forward(self, inputs): - x = F.elu(self.fc1(inputs)) - x = self.dropout(x) - x = F.elu(self.fc2(x)) - return self.batch_norm(x) - - -class InteractionNet(nn.Module): - - def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): - super().__init__() - - self.factor = factor - self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) - self.mlp2 = MLP(n_hid*2, n_hid, n_hid, do_prob) - self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) - self.mlp4 = MLP(n_hid*3, n_hid, n_hid, do_prob) if self.factor else MLP(n_hid*2, n_hid, n_hid, do_prob) - self.fc_out = nn.Linear(n_hid, n_out) - - self.init_weights() - - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight.data) - m.bias.data.fill_(0.1) - - def node2edge(self, x, rel_rec, rel_send): - receivers = torch.matmul(rel_rec, x) - senders = torch.matmul(rel_send, x) - edges = torch.cat([receivers, senders], dim=2) - return edges - - def edge2node(self, x, rel_rec, rel_send): - incoming = torch.matmul(rel_rec.t(), x) - nodes = incoming / incoming.size(1) - return nodes - - def forward(self, inputs, rel_rec, rel_send): # input: [N, v, t, c] = [N, 25, 50, 3] - x = inputs.contiguous() - x = x.view(inputs.size(0), inputs.size(1), -1) # [N, 25, 50, 3] -> [N, 25, 50*3=150] - x = self.mlp1(x) # [N, 25, 150] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] - x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] - x = self.mlp2(x) # [N, 600, 512] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] - x_skip = x - if self.factor: - x = self.edge2node(x, rel_rec, rel_send) # [N, 600, 256] -> [N, 25, 256] - x = self.mlp3(x) # [N, 25, 256] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] - x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] - x = torch.cat((x, x_skip), dim=2) # [N, 600, 512] 
-> [N, 600, 512]|[N, 600, 256]=[N, 600, 768] - x = self.mlp4(x) # [N, 600, 768] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] - else: - x = self.mlp3(x) - x = torch.cat((x, x_skip), dim=2) - x = self.mlp4(x) - return self.fc_out(x) # [N, 600, 256] -> [N, 600, 3] - - -class InteractionDecoderRecurrent(nn.Module): - - def __init__(self, n_in_node, edge_types, n_hid, do_prob=0., skip_first=True): - super().__init__() - - self.msg_fc1 = nn.ModuleList([nn.Linear(2 * n_hid, n_hid) for _ in range(edge_types)]) - self.msg_fc2 = nn.ModuleList([nn.Linear(n_hid, n_hid) for _ in range(edge_types)]) - self.msg_out_shape = n_hid - self.skip_first_edge_type = skip_first - - self.hidden_r = nn.Linear(n_hid, n_hid, bias=False) - self.hidden_i = nn.Linear(n_hid, n_hid, bias=False) - self.hidden_n = nn.Linear(n_hid, n_hid, bias=False) - - self.input_r = nn.Linear(n_in_node, n_hid, bias=True) # 3 x 256 - self.input_i = nn.Linear(n_in_node, n_hid, bias=True) - self.input_n = nn.Linear(n_in_node, n_hid, bias=True) - - self.out_fc1 = nn.Linear(n_hid, n_hid) - self.out_fc2 = nn.Linear(n_hid, n_hid) - self.out_fc3 = nn.Linear(n_hid, n_in_node) - - self.dropout1 = nn.Dropout(p=do_prob) - self.dropout2 = nn.Dropout(p=do_prob) - self.dropout3 = nn.Dropout(p=do_prob) - - def single_step_forward(self, inputs, rel_rec, rel_send, rel_type, hidden): - receivers = torch.matmul(rel_rec, hidden) - senders = torch.matmul(rel_send, hidden) - pre_msg = torch.cat([receivers, senders], dim=-1) - all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1), self.msg_out_shape) - gpu_id = rel_rec.get_device() - all_msgs = all_msgs.cuda(gpu_id) - if self.skip_first_edge_type: - start_idx = 1 - norm = float(len(self.msg_fc2)) - 1. - else: - start_idx = 0 - norm = float(len(self.msg_fc2)) - for k in range(start_idx, len(self.msg_fc2)): - msg = torch.tanh(self.msg_fc1[k](pre_msg)) - msg = self.dropout1(msg) - msg = torch.tanh(self.msg_fc2[k](msg)) - msg = msg * rel_type[:, :, k:k + 1] - all_msgs += msg / norm - agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) - agg_msgs = agg_msgs.contiguous()/inputs.size(2) - - r = torch.sigmoid(self.input_r(inputs) + self.hidden_r(agg_msgs)) - i = torch.sigmoid(self.input_i(inputs) + self.hidden_i(agg_msgs)) - n = torch.tanh(self.input_n(inputs) + r * self.hidden_n(agg_msgs)) - hidden = (1-i)*n + i*hidden - - pred = self.dropout2(F.relu(self.out_fc1(hidden))) - pred = self.dropout2(F.relu(self.out_fc2(pred))) - pred = self.out_fc3(pred) - pred = inputs + pred - - return pred, hidden - - def forward(self, data, rel_type, rel_rec, rel_send, pred_steps=1, - burn_in=False, burn_in_steps=1, dynamic_graph=False, - encoder=None, temp=None): - inputs = data.transpose(1, 2).contiguous() - time_steps = inputs.size(1) - hidden = torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape) - gpu_id = rel_rec.get_device() - hidden = hidden.cuda(gpu_id) - pred_all = [] - for step in range(0, inputs.size(1) - 1): - if not step % pred_steps: - ins = inputs[:, step, :, :] - else: - ins = pred_all[step - 1] - pred, hidden = self.single_step_forward(ins, rel_rec, rel_send, rel_type, hidden) - pred_all.append(pred) - preds = torch.stack(pred_all, dim=1) - return preds.transpose(1, 2).contiguous() - - -class AdjacencyLearn(nn.Module): - - def __init__(self, n_in_enc, n_hid_enc, edge_types, n_in_dec, n_hid_dec, node_num=25): - super().__init__() - - self.encoder = InteractionNet(n_in=n_in_enc, # 150 - n_hid=n_hid_enc, # 256 - n_out=edge_types, # 3 - do_prob=0.5, - factor=True) - self.decoder = 
InteractionDecoderRecurrent(n_in_node=n_in_dec, # 256
-                                                   edge_types=edge_types, # 3
-                                                   n_hid=n_hid_dec, # 256
-                                                   do_prob=0.5,
-                                                   skip_first=True)
-        self.offdiag_indices, _ = get_offdiag_indices(node_num)
-
-        off_diag = np.ones([node_num, node_num])-np.eye(node_num, node_num)
-        self.rel_rec = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32))
-        self.rel_send = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32))
-        self.dcy = 0.1
-
-        self.init_weights()
-
-    def init_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.BatchNorm1d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
-    def forward(self, inputs): # [N, 3, 50, 25, 2]
-
-        N, C, T, V, M = inputs.size()
-        x = inputs.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 50]
-        x = x.contiguous().view(N*M, V, C, T).permute(0,1,3,2) # [2N, 25, 50, 3]
-
-        gpu_id = x.get_device()
-        rel_rec = self.rel_rec.cuda(gpu_id)
-        rel_send = self.rel_send.cuda(gpu_id)
-
-        self.logits = self.encoder(x, rel_rec, rel_send)
-        self.N, self.v, self.c = self.logits.size()
-        self.edges = gumbel_softmax(self.logits, tau=0.5, hard=True)
-        self.prob = my_softmax(self.logits, -1)
-        self.outputs = self.decoder(x, self.edges, rel_rec, rel_send, burn_in=False, burn_in_steps=40)
-        self.offdiag_indices = self.offdiag_indices.cuda(gpu_id)
-
-        A_batch = []
-        for i in range(self.N):
-            A_types = []
-            for j in range(1, self.c):
-                A = torch.sparse.FloatTensor(self.offdiag_indices, self.edges[i,:,j], torch.Size([25, 25])).to_dense().cuda(gpu_id)
-                A = A + torch.eye(25, 25).cuda(gpu_id)
-                D = torch.sum(A, dim=0).squeeze().pow(-1)+1e-10
-                D = torch.diag(D)
-                A_ = torch.matmul(A, D)*self.dcy
-                A_types.append(A_)
-            A_types = torch.stack(A_types)
-            A_batch.append(A_types)
-        self.A_batch = torch.stack(A_batch).cuda(gpu_id) # [N, 2, 25, 25]
-
-        return self.A_batch, self.prob, self.outputs, x
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import numpy as np
+from torch.autograd import Variable
+
+
+def my_softmax(input, axis=1):
+    trans_input = input.transpose(axis, 0).contiguous()
+    # explicit dim=0 matches the old implicit choice for the 3D inputs used here
+    # and silences the softmax deprecation warning seen in log/train_aim.log
+    soft_max_1d = F.softmax(trans_input, dim=0)
+    return soft_max_1d.transpose(axis, 0)
+
+
+def get_offdiag_indices(num_nodes):
+    ones = torch.ones(num_nodes, num_nodes)
+    eye = torch.eye(num_nodes, num_nodes)
+    # as_tuple=False silences the nonzero() deprecation warning seen in log/train_aim.log
+    offdiag_indices = (ones - eye).nonzero(as_tuple=False).t()
+    offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1]
+    return offdiag_indices, offdiag_indices_
+
+
+def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
+    y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps)
+    if hard:
+        shape = logits.size()
+        _, k = y_soft.data.max(-1)
+        y_hard = torch.zeros(*shape)
+        if y_soft.is_cuda:
+            y_hard = y_hard.cuda()
+        y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
+        y = Variable(y_hard - y_soft.data) + y_soft
+    else:
+        y = y_soft
+    return y
+
+
+def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
+    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
+    if logits.is_cuda:
+        gumbel_noise = gumbel_noise.cuda()
+    y = logits + Variable(gumbel_noise)
+    return my_softmax(y / tau, axis=-1)
+
+
+def sample_gumbel(shape, eps=1e-10):
+    uniform = torch.rand(shape).float()
+    return - torch.log(eps - torch.log(uniform + eps))
+
+
+def encode_onehot(labels):
+    classes = set(labels)
+    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}
+    labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
+    return labels_onehot
+
+
+class
MLP(nn.Module): + + def __init__(self, n_in, n_hid, n_out, do_prob=0.): + super().__init__() + + self.fc1 = nn.Linear(n_in, n_hid) + self.fc2 = nn.Linear(n_hid, n_out) + self.bn = nn.BatchNorm1d(n_out) + self.dropout = nn.Dropout(p=do_prob) + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight.data) + m.bias.data.fill_(0.1) + elif isinstance(m, nn.BatchNorm1d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def batch_norm(self, inputs): + x = inputs.view(inputs.size(0) * inputs.size(1), -1) + x = self.bn(x) + return x.view(inputs.size(0), inputs.size(1), -1) + + def forward(self, inputs): + x = F.elu(self.fc1(inputs)) + x = self.dropout(x) + x = F.elu(self.fc2(x)) + return self.batch_norm(x) + + +class InteractionNet(nn.Module): + + def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): + super().__init__() + + self.factor = factor + self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) + self.mlp2 = MLP(n_hid*2, n_hid, n_hid, do_prob) + self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) + self.mlp4 = MLP(n_hid*3, n_hid, n_hid, do_prob) if self.factor else MLP(n_hid*2, n_hid, n_hid, do_prob) + self.fc_out = nn.Linear(n_hid, n_out) + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight.data) + m.bias.data.fill_(0.1) + + def node2edge(self, x, rel_rec, rel_send): + receivers = torch.matmul(rel_rec, x) + senders = torch.matmul(rel_send, x) + edges = torch.cat([receivers, senders], dim=2) + return edges + + def edge2node(self, x, rel_rec, rel_send): + incoming = torch.matmul(rel_rec.t(), x) + nodes = incoming / incoming.size(1) + return nodes + + def forward(self, inputs, rel_rec, rel_send): # input: [N, v, t, c] = [N, 25, 50, 3] + x = inputs.contiguous() + x = x.view(inputs.size(0), inputs.size(1), -1) # [N, 25, 50, 3] -> [N, 25, 50*3=150] + x = self.mlp1(x) # [N, 25, 150] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] + x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] + x = self.mlp2(x) # [N, 600, 512] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] + x_skip = x + if self.factor: + x = self.edge2node(x, rel_rec, rel_send) # [N, 600, 256] -> [N, 25, 256] + x = self.mlp3(x) # [N, 25, 256] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] + x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] + x = torch.cat((x, x_skip), dim=2) # [N, 600, 512] -> [N, 600, 512]|[N, 600, 256]=[N, 600, 768] + x = self.mlp4(x) # [N, 600, 768] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] + else: + x = self.mlp3(x) + x = torch.cat((x, x_skip), dim=2) + x = self.mlp4(x) + return self.fc_out(x) # [N, 600, 256] -> [N, 600, 3] + + +class InteractionDecoderRecurrent(nn.Module): + + def __init__(self, n_in_node, edge_types, n_hid, do_prob=0., skip_first=True): + super().__init__() + + self.msg_fc1 = nn.ModuleList([nn.Linear(2 * n_hid, n_hid) for _ in range(edge_types)]) + self.msg_fc2 = nn.ModuleList([nn.Linear(n_hid, n_hid) for _ in range(edge_types)]) + self.msg_out_shape = n_hid + self.skip_first_edge_type = skip_first + + self.hidden_r = nn.Linear(n_hid, n_hid, bias=False) + self.hidden_i = nn.Linear(n_hid, n_hid, bias=False) + self.hidden_n = nn.Linear(n_hid, n_hid, bias=False) + + self.input_r = nn.Linear(n_in_node, n_hid, bias=True) # 3 x 256 + self.input_i = nn.Linear(n_in_node, n_hid, bias=True) + self.input_n = nn.Linear(n_in_node, n_hid, bias=True) 
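+        # The {input,hidden}_{r,i,n} Linear pairs above act as a GRU-style cell
+        # (see single_step_forward below): r and i play the reset/update gates,
+        # n is the candidate state, and each combines a node's own input with
+        # its aggregated incoming edge messages (agg_msgs).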
+ + self.out_fc1 = nn.Linear(n_hid, n_hid) + self.out_fc2 = nn.Linear(n_hid, n_hid) + self.out_fc3 = nn.Linear(n_hid, n_in_node) + + self.dropout1 = nn.Dropout(p=do_prob) + self.dropout2 = nn.Dropout(p=do_prob) + self.dropout3 = nn.Dropout(p=do_prob) + + def single_step_forward(self, inputs, rel_rec, rel_send, rel_type, hidden): + receivers = torch.matmul(rel_rec, hidden) + senders = torch.matmul(rel_send, hidden) + pre_msg = torch.cat([receivers, senders], dim=-1) + all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1), self.msg_out_shape) + gpu_id = rel_rec.get_device() + all_msgs = all_msgs.cuda(gpu_id) + if self.skip_first_edge_type: + start_idx = 1 + norm = float(len(self.msg_fc2)) - 1. + else: + start_idx = 0 + norm = float(len(self.msg_fc2)) + for k in range(start_idx, len(self.msg_fc2)): + msg = torch.tanh(self.msg_fc1[k](pre_msg)) + msg = self.dropout1(msg) + msg = torch.tanh(self.msg_fc2[k](msg)) + msg = msg * rel_type[:, :, k:k + 1] + all_msgs += msg / norm + agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) + agg_msgs = agg_msgs.contiguous()/inputs.size(2) + + r = torch.sigmoid(self.input_r(inputs) + self.hidden_r(agg_msgs)) + i = torch.sigmoid(self.input_i(inputs) + self.hidden_i(agg_msgs)) + n = torch.tanh(self.input_n(inputs) + r * self.hidden_n(agg_msgs)) + hidden = (1-i)*n + i*hidden + + pred = self.dropout2(F.relu(self.out_fc1(hidden))) + pred = self.dropout2(F.relu(self.out_fc2(pred))) + pred = self.out_fc3(pred) + pred = inputs + pred + + return pred, hidden + + def forward(self, data, rel_type, rel_rec, rel_send, pred_steps=1, + burn_in=False, burn_in_steps=1, dynamic_graph=False, + encoder=None, temp=None): + inputs = data.transpose(1, 2).contiguous() + time_steps = inputs.size(1) + hidden = torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape) + gpu_id = rel_rec.get_device() + hidden = hidden.cuda(gpu_id) + pred_all = [] + for step in range(0, inputs.size(1) - 1): + if not step % pred_steps: + ins = inputs[:, step, :, :] + else: + ins = pred_all[step - 1] + pred, hidden = self.single_step_forward(ins, rel_rec, rel_send, rel_type, hidden) + pred_all.append(pred) + preds = torch.stack(pred_all, dim=1) + return preds.transpose(1, 2).contiguous() + + +class AdjacencyLearn(nn.Module): + + def __init__(self, n_in_enc, n_hid_enc, edge_types, n_in_dec, n_hid_dec, node_num=25): + super().__init__() + + self.encoder = InteractionNet(n_in=n_in_enc, # 150 + n_hid=n_hid_enc, # 256 + n_out=edge_types, # 3 + do_prob=0.5, + factor=True) + self.decoder = InteractionDecoderRecurrent(n_in_node=n_in_dec, # 256 + edge_types=edge_types, # 3 + n_hid=n_hid_dec, # 256 + do_prob=0.5, + skip_first=True) + self.offdiag_indices, _ = get_offdiag_indices(node_num) + + off_diag = np.ones([node_num, node_num])-np.eye(node_num, node_num) + self.rel_rec = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)) + self.rel_send = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)) + self.dcy = 0.1 + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, inputs): # [N, 3, 50, 25, 2] + print("enter AdjacencyLearn") + + N, C, T, V, M = inputs.size() + x = inputs.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 50] + x = x.contiguous().view(N*M, V, C, T).permute(0,1,3,2) # [2N, 25, 50, 3] + + gpu_id = x.get_device() + rel_rec = self.rel_rec.cuda(gpu_id) + rel_send = 
self.rel_send.cuda(gpu_id) + + self.logits = self.encoder(x, rel_rec, rel_send) + self.N, self.v, self.c = self.logits.size() + self.edges = gumbel_softmax(self.logits, tau=0.5, hard=True) + self.prob = my_softmax(self.logits, -1) + self.outputs = self.decoder(x, self.edges, rel_rec, rel_send, burn_in=False, burn_in_steps=40) + self.offdiag_indices = self.offdiag_indices.cuda(gpu_id) + + A_batch = [] + for i in range(self.N): + A_types = [] + for j in range(1, self.c): + A = torch.sparse.FloatTensor(self.offdiag_indices, self.edges[i,:,j], torch.Size([25, 25])).to_dense().cuda(gpu_id) + A = A + torch.eye(25, 25).cuda(gpu_id) + D = torch.sum(A, dim=0).squeeze().pow(-1)+1e-10 + D = torch.diag(D) + A_ = torch.matmul(A, D)*self.dcy + A_types.append(A_) + A_types = torch.stack(A_types) + A_batch.append(A_types) + self.A_batch = torch.stack(A_batch).cuda(gpu_id) # [N, 2, 25, 25] + + return self.A_batch, self.prob, self.outputs, x diff --git a/net/utils/graph.py b/net/utils/graph.py index 52bb13e..708f0b8 100644 --- a/net/utils/graph.py +++ b/net/utils/graph.py @@ -1,129 +1,129 @@ -import numpy as np - -class Graph(): - - def __init__(self, - layout='openpose', - strategy='uniform', - max_hop=2, - dilation=1): - self.max_hop = max_hop - self.dilation = dilation - - self.get_edge(layout) - self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop) - self.get_adjacency(strategy) - - def __str__(self): - return self.A - - def get_edge(self, layout): - if layout == 'openpose': - self.num_node = 18 - self_link = [(i, i) for i in range(self.num_node)] - neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), - (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), - (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] - self.edge = self_link + neighbor_link - self.center = 1 - elif layout == 'ntu-rgb+d': - self.num_node = 25 - self_link = [(i, i) for i in range(self.num_node)] - neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), - (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), - (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), - (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), - (22, 23), (23, 8), (24, 25), (25, 12)] - neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] - self.edge = self_link + neighbor_link - self.center = 21 - 1 - elif layout == 'ntu_edge': - self.num_node = 24 - self_link = [(i, i) for i in range(self.num_node)] - neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), - (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), - (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), - (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), - (23, 24), (24, 12)] - neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] - self.edge = self_link + neighbor_link - self.center = 2 - else: - raise ValueError("Do Not Exist This Layout.") - - def get_adjacency(self, strategy): - valid_hop = range(0, self.max_hop + 1, self.dilation) - adjacency = np.zeros((self.num_node, self.num_node)) - for hop in valid_hop: - adjacency[self.hop_dis == hop] = 1 - normalize_adjacency = normalize_digraph(adjacency) - - if strategy == 'uniform': - A = np.zeros((1, self.num_node, self.num_node)) - A[0] = normalize_adjacency - self.A = A - elif strategy == 'distance': - A = np.zeros((len(valid_hop), self.num_node, self.num_node)) - for i, hop in enumerate(valid_hop): - A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] - self.A = A - elif strategy == 'spatial': - A = [] - for hop in valid_hop: - a_root = np.zeros((self.num_node, self.num_node)) - a_close = 
np.zeros((self.num_node, self.num_node)) - a_further = np.zeros((self.num_node, self.num_node)) - for i in range(self.num_node): - for j in range(self.num_node): - if self.hop_dis[j, i] == hop: - if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]: - a_root[j, i] = normalize_adjacency[j, i] - elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]: - a_close[j, i] = normalize_adjacency[j, i] - else: - a_further[j, i] = normalize_adjacency[j, i] - if hop == 0: - A.append(a_root) - else: - A.append(a_root + a_close) - A.append(a_further) - A = np.stack(A) - self.A = A - else: - raise ValueError("Do Not Exist This Strategy") - - -def get_hop_distance(num_node, edge, max_hop=1): - A = np.zeros((num_node, num_node)) - for i, j in edge: - A[j, i] = 1 - A[i, j] = 1 - - hop_dis = np.zeros((num_node, num_node)) + np.inf - transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] - arrive_mat = (np.stack(transfer_mat) > 0) - for d in range(max_hop, -1, -1): - hop_dis[arrive_mat[d]] = d - return hop_dis - - -def normalize_digraph(A): - Dl = np.sum(A, 0) - num_node = A.shape[0] - Dn = np.zeros((num_node, num_node)) - for i in range(num_node): - if Dl[i] > 0: - Dn[i, i] = Dl[i]**(-1) - AD = np.dot(A, Dn) - return AD - - -def normalize_undigraph(A): - Dl = np.sum(A, 0) - num_node = A.shape[0] - Dn = np.zeros((num_node, num_node)) - for i in range(num_node): - if Dl[i] > 0: - Dn[i, i] = Dl[i]**(-0.5) - DAD = np.dot(np.dot(Dn, A), Dn) +import numpy as np + +class Graph(): + + def __init__(self, + layout='openpose', + strategy='uniform', + max_hop=2, + dilation=1): + self.max_hop = max_hop + self.dilation = dilation + + self.get_edge(layout) + self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop) + self.get_adjacency(strategy) + + def __str__(self): + return self.A + + def get_edge(self, layout): + if layout == 'openpose': + self.num_node = 18 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), + (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), + (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] + self.edge = self_link + neighbor_link + self.center = 1 + elif layout == 'ntu-rgb+d': + self.num_node = 25 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), + (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), + (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), + (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), + (22, 23), (23, 8), (24, 25), (25, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 21 - 1 + elif layout == 'ntu_edge': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), + (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), + (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), + (23, 24), (24, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 2 + else: + raise ValueError("Do Not Exist This Layout.") + + def get_adjacency(self, strategy): + valid_hop = range(0, self.max_hop + 1, self.dilation) + adjacency = np.zeros((self.num_node, self.num_node)) + for hop in valid_hop: + adjacency[self.hop_dis == hop] = 1 + normalize_adjacency = normalize_digraph(adjacency) + + if strategy == 'uniform': + A = np.zeros((1, self.num_node, self.num_node)) + A[0] = 
normalize_adjacency + self.A = A + elif strategy == 'distance': + A = np.zeros((len(valid_hop), self.num_node, self.num_node)) + for i, hop in enumerate(valid_hop): + A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] + self.A = A + elif strategy == 'spatial': + A = [] + for hop in valid_hop: + a_root = np.zeros((self.num_node, self.num_node)) + a_close = np.zeros((self.num_node, self.num_node)) + a_further = np.zeros((self.num_node, self.num_node)) + for i in range(self.num_node): + for j in range(self.num_node): + if self.hop_dis[j, i] == hop: + if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]: + a_root[j, i] = normalize_adjacency[j, i] + elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]: + a_close[j, i] = normalize_adjacency[j, i] + else: + a_further[j, i] = normalize_adjacency[j, i] + if hop == 0: + A.append(a_root) + else: + A.append(a_root + a_close) + A.append(a_further) + A = np.stack(A) + self.A = A + else: + raise ValueError("Do Not Exist This Strategy") + + +def get_hop_distance(num_node, edge, max_hop=1): + A = np.zeros((num_node, num_node)) + for i, j in edge: + A[j, i] = 1 + A[i, j] = 1 + + hop_dis = np.zeros((num_node, num_node)) + np.inf + transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] + arrive_mat = (np.stack(transfer_mat) > 0) + for d in range(max_hop, -1, -1): + hop_dis[arrive_mat[d]] = d + return hop_dis + + +def normalize_digraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-1) + AD = np.dot(A, Dn) + return AD + + +def normalize_undigraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-0.5) + DAD = np.dot(np.dot(Dn, A), Dn) return DAD \ No newline at end of file diff --git a/net/utils/utils_adj.py b/net/utils/utils_adj.py index c6a0218..034b64c 100644 --- a/net/utils/utils_adj.py +++ b/net/utils/utils_adj.py @@ -1,48 +1,48 @@ -import os -import numpy as np -import torch -import torch.utils.data -import torch.nn.functional as F -from torch.autograd import Variable - - -def my_softmax(input, axis=1): - trans_input = input.transpose(axis, 0).contiguous() - soft_max_1d = F.softmax(trans_input) - return soft_max_1d.transpose(axis, 0) - - -def get_offdiag_indices(num_nodes): - ones = torch.ones(num_nodes, num_nodes) - eye = torch.eye(num_nodes, num_nodes) - offdiag_indices = (ones - eye).nonzero().t() - offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1] - return offdiag_indices, offdiag_indices_ - - -def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): - y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) - if hard: - shape = logits.size() - _, k = y_soft.data.max(-1) - y_hard = torch.zeros(*shape) - if y_soft.is_cuda: - y_hard = y_hard.cuda() - y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) - y = Variable(y_hard - y_soft.data) + y_soft - else: - y = y_soft - return y - - -def gumbel_softmax_sample(logits, tau=1, eps=1e-10): - gumbel_noise = sample_gumbel(logits.size(), eps=eps) - if logits.is_cuda: - gumbel_noise = gumbel_noise.cuda() - y = logits + Variable(gumbel_noise) - return my_softmax(y / tau, axis=-1) - - -def sample_gumbel(shape, eps=1e-10): - uniform = torch.rand(shape).float() +import os +import numpy as np +import torch +import torch.utils.data +import torch.nn.functional as F +from torch.autograd import Variable + + +def 
my_softmax(input, axis=1):
+    trans_input = input.transpose(axis, 0).contiguous()
+    soft_max_1d = F.softmax(trans_input, dim=0)
+    return soft_max_1d.transpose(axis, 0)
+
+
+def get_offdiag_indices(num_nodes):
+    ones = torch.ones(num_nodes, num_nodes)
+    eye = torch.eye(num_nodes, num_nodes)
+    offdiag_indices = (ones - eye).nonzero().t()
+    offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1]
+    return offdiag_indices, offdiag_indices_
+
+
+def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
+    y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps)
+    if hard:
+        shape = logits.size()
+        _, k = y_soft.data.max(-1)
+        y_hard = torch.zeros(*shape)
+        if y_soft.is_cuda:
+            y_hard = y_hard.cuda()
+        y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0)
+        y = Variable(y_hard - y_soft.data) + y_soft
+    else:
+        y = y_soft
+    return y
+
+
+def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
+    gumbel_noise = sample_gumbel(logits.size(), eps=eps)
+    if logits.is_cuda:
+        gumbel_noise = gumbel_noise.cuda()
+    y = logits + Variable(gumbel_noise)
+    return my_softmax(y / tau, axis=-1)
+
+
+def sample_gumbel(shape, eps=1e-10):
+    uniform = torch.rand(shape).float()
     return - torch.log(eps - torch.log(uniform + eps))
\ No newline at end of file
diff --git a/pip_req.txt b/pip_req.txt
new file mode 100644
index 0000000..aae02c0
--- /dev/null
+++ b/pip_req.txt
@@ -0,0 +1,15 @@
+argparse==1.4.0
+cached-property==1.5.2
+dataclasses==0.8
+h5py==3.1.0
+imageio==2.9.0
+numpy==1.19.5
+opencv-python==4.5.1.48
+pyyaml==5.4.1
+scikit-video==1.1.11
+scipy==1.5.4
+torch==1.7.1
+torchvision==0.9.0
+tqdm==4.60.0
+typing-extensions==3.7.4.3
+
diff --git a/processor/__init__.py b/processor/__init__.py
index 8b13789..d3f5a12 100644
--- a/processor/__init__.py
+++ b/processor/__init__.py
@@ -1 +1 @@
-
+
diff --git a/processor/gpu.py b/processor/gpu.py
index 306c391..e086d4c 100644
--- a/processor/gpu.py
+++ b/processor/gpu.py
@@ -1,35 +1,35 @@
-import os
-import torch
-
-
-def visible_gpu(gpus):
-    """
-    set visible gpu.
-
-    can be a single id, or a list
-
-    return a list of new gpus ids
-    """
-    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
-    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus)))
-    return list(range(len(gpus)))
-
-
-def ngpu(gpus):
-    """
-    count how many gpus used.
-    """
-    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
-    return len(gpus)
-
-
-def occupy_gpu(gpus=None):
-    """
-    make program appear on nvidia-smi.
-    """
-    if gpus is None:
-        torch.zeros(1).cuda()
-    else:
-        gpus = [gpus] if isinstance(gpus, int) else list(gpus)
-        for g in gpus:
-            torch.zeros(1).cuda(g)
+import os
+import torch
+
+
+def visible_gpu(gpus):
+    """
+    set visible gpu.
+
+    can be a single id, or a list
+
+    return a list of new gpus ids
+    """
+    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
+    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus)))
+    return list(range(len(gpus)))
+
+
+def ngpu(gpus):
+    """
+    count how many gpus used.
+    """
+    gpus = [gpus] if isinstance(gpus, int) else list(gpus)
+    return len(gpus)
+
+
+def occupy_gpu(gpus=None):
+    """
+    make program appear on nvidia-smi.
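+
+    allocates a one-element tensor on each requested device (or the default
+    device when gpus is None) so the process registers with the CUDA driver.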
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/processor/io.py b/processor/io.py index b8d11ca..7e7e57c 100644 --- a/processor/io.py +++ b/processor/io.py @@ -1,116 +1,118 @@ -import sys -import os -import argparse -import yaml -import numpy as np - -import torch -import torch.nn as nn - -import torchlight -from torchlight.io import str2bool -from torchlight.io import DictAction -from torchlight.io import import_class - - -class IO(): - - def __init__(self, argv=None): - - self.load_arg(argv) - self.init_environment() - self.load_model() - self.load_weights() - self.gpu() - - def load_arg(self, argv=None): - parser = self.get_parser() - - # load arg form config file - p = parser.parse_args(argv) - if p.config is not None: - # load config file - with open(p.config, 'r') as f: - default_arg = yaml.load(f) - - # update parser from config file - key = vars(p).keys() - for k in default_arg.keys(): - if k not in key: - print('Unknown Arguments: {}'.format(k)) - assert k in key - - parser.set_defaults(**default_arg) - - self.arg = parser.parse_args(argv) - - def init_environment(self): - self.save_dir = os.path.join(self.arg.work_dir, - self.arg.max_hop_dir, - self.arg.lamda_act_dir) - self.io = torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) - self.io.save_arg(self.arg) - - # gpu - if self.arg.use_gpu: - gpus = torchlight.gpu.visible_gpu(self.arg.device) - #torchlight.occupy_gpu(gpus) - self.gpus = gpus - self.dev = "cuda:0" - else: - self.dev = "cpu" - - def load_model(self): - self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) - self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) - - def load_weights(self): - if self.arg.weights1: - self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) - self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) - - def gpu(self): - # move modules to gpu - self.model1 = self.model1.to(self.dev) - self.model2 = self.model2.to(self.dev) - for name, value in vars(self).items(): - cls_name = str(value.__class__) - if cls_name.find('torch.nn.modules') != -1: - setattr(self, name, value.to(self.dev)) - - # model parallel - if self.arg.use_gpu and len(self.gpus) > 1: - self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) - self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) - - def start(self): - self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) - - @staticmethod - def get_parser(add_help=False): - - #region arguments yapf: disable - # parameter priority: command line > config > default - parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') - - parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') - parser.add_argument('-c', '--config', default=None, help='path to the configuration file') - - # processor - parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') - parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') - - # visulize and debug - parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') - parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') - - # model - parser.add_argument('--model1', 
default=None, help='the model will be used') - parser.add_argument('--model2', default=None, help='the model will be used') - parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--weights', default=None, help='the weights for network initialization') - parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') - #endregion yapf: enable - - return parser +import sys +import os +import argparse +import yaml +import numpy as np + +import torch +import torch.nn as nn + +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + + +class IO(): + + def __init__(self, argv=None): + + self.load_arg(argv) + self.init_environment() + self.load_model() + self.load_weights() + self.gpu() + + def load_arg(self, argv=None): + parser = self.get_parser() + + # load arg form config file + p = parser.parse_args(argv) + if p.config is not None: + # load config file + with open(p.config, 'r') as f: + default_arg = yaml.load(f) + + # update parser from config file + key = vars(p).keys() + for k in default_arg.keys(): + if k not in key: + print('Unknown Arguments: {}'.format(k)) + assert k in key + + parser.set_defaults(**default_arg) + + self.arg = parser.parse_args(argv) + + def init_environment(self): + self.save_dir = os.path.join(self.arg.work_dir, + self.arg.max_hop_dir, + self.arg.lamda_act_dir) + self.io = torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) + self.io.save_arg(self.arg) + + # gpu + if self.arg.use_gpu: + gpus = torchlight.gpu.visible_gpu(self.arg.device) + #torchlight.occupy_gpu(gpus) + self.gpus = gpus + self.dev = "cuda:0" + else: + self.dev = "cpu" + + def load_model(self): + self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) + self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) + + def load_weights(self): + if self.arg.weights1: + self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) + self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) + #self.model3 = self.io.load_weights(self.model3, self.arg.weights3, self.arg.ignore_weights) + + def gpu(self): + # move modules to gpu + self.model1 = self.model1.to(self.dev) + self.model2 = self.model2.to(self.dev) + self.model3 = self.model3.to(self.dev) + for name, value in vars(self).items(): + cls_name = str(value.__class__) + if cls_name.find('torch.nn.modules') != -1: + setattr(self, name, value.to(self.dev)) + + # model parallel + if self.arg.use_gpu and len(self.gpus) > 1: + self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) + self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) + + def start(self): + self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) + + @staticmethod + def get_parser(add_help=False): + + #region arguments yapf: disable + # parameter priority: command line > config > default + parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') + + parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') + parser.add_argument('-c', '--config', default=None, help='path to the configuration file') + + # processor + 
parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') + parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') + + # visulize and debug + parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') + parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') + + # model + parser.add_argument('--model1', default=None, help='the model will be used') + parser.add_argument('--model2', default=None, help='the model will be used') + parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--weights', default=None, help='the weights for network initialization') + parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') + #endregion yapf: enable + + return parser diff --git a/processor/processor.py b/processor/processor.py index 690d846..045c764 100644 --- a/processor/processor.py +++ b/processor/processor.py @@ -1,186 +1,212 @@ -import sys -import argparse -import yaml -import numpy as np - -import torch -import torch.nn as nn -import torch.optim as optim - -import torchlight -from torchlight.io import str2bool -from torchlight.io import DictAction -from torchlight.io import import_class - -from .io import IO - - -class Processor(IO): - - def __init__(self, argv=None): - - self.load_arg(argv) - self.init_environment() - self.load_model() - self.load_weights() - self.gpu() - self.load_data() - - def init_environment(self): - - super().init_environment() - self.result = dict() - self.iter_info = dict() - self.epoch_info = dict() - self.meta_info = dict(epoch=0, iter=0) - - - def load_data(self): - Feeder = import_class(self.arg.feeder) - if 'debug' not in self.arg.train_feeder_args: - self.arg.train_feeder_args['debug'] = self.arg.debug - self.data_loader = dict() - if self.arg.phase == 'train': - self.data_loader['train'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.train_feeder_args), - batch_size=self.arg.batch_size, - shuffle=True, - num_workers=self.arg.num_worker, - drop_last=True) - if self.arg.test_feeder_args: - self.data_loader['test'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.test_feeder_args), - batch_size=self.arg.test_batch_size, - shuffle=False, - num_workers=self.arg.num_worker) - - def show_epoch_info(self): - for k, v in self.epoch_info.items(): - self.io.print_log('\t{}: {}'.format(k, v)) - if self.arg.pavi_log: - self.io.log('train', self.meta_info['iter'], self.epoch_info) - - def show_iter_info(self): - if self.meta_info['iter'] % self.arg.log_interval == 0: - info ='\tIter {} Done.'.format(self.meta_info['iter']) - for k, v in self.iter_info.items(): - if isinstance(v, float): - info = info + ' | {}: {:.4f}'.format(k, v) - else: - info = info + ' | {}: {}'.format(k, v) - - self.io.print_log(info) - - if self.arg.pavi_log: - self.io.log('train', self.meta_info['iter'], self.iter_info) - - def train(self): - for _ in range(100): - self.iter_info['loss'] = 0 - self.iter_info['loss_class'] = 0 - self.iter_info['loss_recon'] = 0 - self.show_iter_info() - self.meta_info['iter'] += 1 - self.epoch_info['mean_loss'] = 0 - self.epoch_info['mean_loss_class'] = 0 - self.epoch_info['mean_loss_recon'] = 0 - self.show_epoch_info() - - def test(self): - 
for _ in range(100): - self.iter_info['loss'] = 1 - self.iter_info['loss_class'] = 1 - self.iter_info['loss_recon'] = 1 - self.show_iter_info() - self.epoch_info['mean_loss'] = 1 - self.epoch_info['mean_loss_class'] = 1 - self.epoch_info['mean_loss_recon'] = 1 - self.show_epoch_info() - - def start(self): - self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) - - if self.arg.phase == 'train': - for epoch in range(self.arg.start_epoch, self.arg.num_epoch): - self.meta_info['epoch'] = epoch - - if epoch < 10: - self.io.print_log('Training epoch: {}'.format(epoch)) - self.train(training_A=True) - self.io.print_log('Done.') - else: - self.io.print_log('Training epoch: {}'.format(epoch)) - self.train(training_A=False) - self.io.print_log('Done.') - - # save model - if ((epoch + 1) % self.arg.save_interval == 0) or (epoch + 1 == self.arg.num_epoch): - filename1 = 'epoch{}_model1.pt'.format(epoch) - self.io.save_model(self.model1, filename1) - filename2 = 'epoch{}_model2.pt'.format(epoch) - self.io.save_model(self.model2, filename2) - - # evaluation - if ((epoch + 1) % self.arg.eval_interval == 0) or (epoch + 1 == self.arg.num_epoch): - self.io.print_log('Eval epoch: {}'.format(epoch)) - if epoch <= 10: - self.test(testing_A=True) - else: - self.test(testing_A=False) - self.io.print_log('Done.') - - - elif self.arg.phase == 'test': - if self.arg.weights2 is None: - raise ValueError('Please appoint --weights.') - self.io.print_log('Model: {}.'.format(self.arg.model2)) - self.io.print_log('Weights: {}.'.format(self.arg.weights2)) - - self.io.print_log('Evaluation Start:') - self.test(testing_A=False, save_feature=True) - self.io.print_log('Done.\n') - - if self.arg.save_result: - result_dict = dict( - zip(self.data_loader['test'].dataset.sample_name, - self.result)) - self.io.save_pkl(result_dict, 'test_result.pkl') - - - @staticmethod - def get_parser(add_help=False): - - parser = argparse.ArgumentParser( add_help=add_help, description='Base Processor') - - parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') - parser.add_argument('-c', '--config', default=None, help='path to the configuration file') - - parser.add_argument('--phase', default='train', help='must be train or test') - parser.add_argument('--save_result', type=str2bool, default=False, help='if ture, the output of the model will be stored') - parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch') - parser.add_argument('--num_epoch', type=int, default=80, help='stop training in which epoch') - parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') - parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') - - parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)') - parser.add_argument('--save_interval', type=int, default=1, help='the interval for storing models (#iteration)') - parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#iteration)') - parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') - parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') - parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not') - - parser.add_argument('--feeder', default='feeder.feeder', help='data loader will be used') - 
parser.add_argument('--num_worker', type=int, default=4, help='the number of worker per gpu for data loader')
-        parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for training')
-        parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for test')
-        parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
-        parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size')
-        parser.add_argument('--debug', action="store_true", help='less data, faster loading')
-
-        parser.add_argument('--model1', default=None, help='the model will be used')
-        parser.add_argument('--model2', default=None, help='the model will be used')
-        parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model')
-        parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model')
-        parser.add_argument('--weights1', default=None, help='the weights for network initialization')
-        parser.add_argument('--weights2', default=None, help='the weights for network initialization')
-        parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization')
-
-        return parser
+import sys
+import argparse
+import yaml
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import matplotlib.pyplot as plt
+import pickle
+
+import torchlight
+from torchlight.io import str2bool
+from torchlight.io import DictAction
+from torchlight.io import import_class
+
+from .io import IO
+
+
+class Processor(IO):
+
+    def __init__(self, argv=None):
+
+        self.load_arg(argv)
+        self.init_environment()
+        self.load_model()
+        #self.load_weights()
+        self.gpu()
+        self.load_data()
+
+    def init_environment(self):
+
+        super().init_environment()
+        self.result = dict()
+        self.iter_info = dict()
+        self.epoch_info = dict()
+        self.meta_info = dict(epoch=0, iter=0)
+        self.epoch_loss_class_train = []
+        self.epoch_loss_class_test = []
+
+
+    def load_data(self):
+        Feeder = import_class(self.arg.feeder)
+        if 'debug' not in self.arg.train_feeder_args:
+            self.arg.train_feeder_args['debug'] = self.arg.debug
+        self.data_loader = dict()
+        if self.arg.phase == 'train':
+            self.data_loader['train'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.train_feeder_args),
+                                                                    batch_size=self.arg.batch_size,
+                                                                    shuffle=True,
+                                                                    num_workers=self.arg.num_worker,
+                                                                    drop_last=True)
+        if self.arg.test_feeder_args:
+            self.data_loader['test'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.test_feeder_args),
+                                                                   batch_size=self.arg.test_batch_size,
+                                                                   shuffle=False,
+                                                                   num_workers=self.arg.num_worker)
+
+    def show_epoch_info(self):
+        for k, v in self.epoch_info.items():
+            self.io.print_log('\t{}: {}'.format(k, v))
+        if self.arg.pavi_log:
+            self.io.log('train', self.meta_info['iter'], self.epoch_info)
+
+
+    def show_epoch_curve(self, epoch, loss_class_value):
+        # plot the per-epoch classification loss; the caller passes the
+        # accumulated loss list explicitly (it is not a module global)
+        plt.figure()
+        epoch_x = np.arange(3, len(loss_class_value)) + 1
+        plt.plot(epoch_x, loss_class_value[3:], '--', color='C0')
+        plt.legend(['action_class_train'])
+        plt.ylabel('action CrossEntropyLoss')
+        plt.xlabel('Epoch')
+        plt.xlim((3, epoch))
+        plt.savefig('loss_action_class_task.png')
+        plt.close()
+
+    def show_iter_info(self):
+        if self.meta_info['iter'] % self.arg.log_interval == 0:
+            info = '\tIter {} Done.'.format(self.meta_info['iter'])
+            for k, v in self.iter_info.items():
+                if
isinstance(v, float): + info = info + ' | {}: {:.4f}'.format(k, v) + else: + info = info + ' | {}: {}'.format(k, v) + + self.io.print_log(info) + + if self.arg.pavi_log: + self.io.log('train', self.meta_info['iter'], self.iter_info) + + def train(self): + for _ in range(300): + self.iter_info['loss'] = 0 + self.iter_info['loss_class'] = 0 + self.iter_info['loss_recon'] = 0 + self.show_iter_info() + self.meta_info['iter'] += 1 + self.epoch_info['mean_loss'] = 0 + self.epoch_info['mean_loss_class'] = 0 + self.epoch_info['mean_loss_recon'] = 0 + self.show_epoch_info() + + def test(self): + for _ in range(100): + self.iter_info['loss'] = 1 + self.iter_info['loss_class'] = 1 + self.iter_info['loss_recon'] = 1 + self.show_iter_info() + self.epoch_info['mean_loss'] = 1 + self.epoch_info['mean_loss_class'] = 1 + self.epoch_info['mean_loss_recon'] = 1 + self.show_epoch_info() + + def start(self): + self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) + + if self.arg.phase == 'train': + for epoch in range(self.arg.start_epoch, self.arg.num_epoch): + self.meta_info['epoch'] = epoch + + if epoch < 10: + self.io.print_log('Training epoch: {}'.format(epoch)) + self.train(training_A=True) + self.io.print_log('Done.') + else: + self.io.print_log('Training epoch: {}'.format(epoch)) + self.train(training_A=False) + self.io.print_log('Done.') + + # save model + if ((epoch + 1) % self.arg.save_interval == 0) or (epoch + 1 == self.arg.num_epoch): + """ + filename1 = 'epoch{}_model1.pt'.format(epoch) + self.io.save_model(self.model1, filename1) + filename2 = 'epoch{}_model2.pt'.format(epoch) + self.io.save_model(self.model2, filename2) + """ + filename3 = 'epoch{}_model3.pt'.format(epoch) + self.io.save_model(self.model3, filename3) + + with open("epoch_loss_class_train.txt", "w") as outfile: + for item in self.epoch_loss_class_train: + outfile.write("{}: {}\n".format(self.epoch_info, item)) + + # evaluation + if ((epoch + 1) % self.arg.eval_interval == 0) or (epoch + 1 == self.arg.num_epoch): + self.io.print_log('Eval epoch: {}'.format(epoch)) + self.test(testing_A=False) + self.io.print_log('Done.') + + with open("epoch_loss_class_test.txt", "w") as outfile: + for item in self.epoch_loss_class_test: + outfile.write("{}: {}\n".format(self.epoch_info, item)) + + + + elif self.arg.phase == 'test': + if self.arg.weights2 is None: + raise ValueError('Please appoint --weights.') + self.io.print_log('Model: {}.'.format(self.arg.model2)) + self.io.print_log('Weights: {}.'.format(self.arg.weights2)) + + self.io.print_log('Evaluation Start:') + self.test(testing_A=False, save_feature=True) + self.io.print_log('Done.\n') + + if self.arg.save_result: + result_dict = dict( + zip(self.data_loader['test'].dataset.sample_name, + self.result)) + self.io.save_pkl(result_dict, 'test_result.pkl') + + + @staticmethod + def get_parser(add_help=False): + + parser = argparse.ArgumentParser( add_help=add_help, description='Base Processor') + + parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') + parser.add_argument('-c', '--config', default=None, help='path to the configuration file') + + parser.add_argument('--phase', default='train', help='must be train or test') + parser.add_argument('--save_result', type=str2bool, default=False, help='if ture, the output of the model will be stored') + parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch') + parser.add_argument('--num_epoch', type=int, default=80, help='stop 
training in which epoch') + parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') + parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') + + parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)') + parser.add_argument('--save_interval', type=int, default=1, help='the interval for storing models (#iteration)') + parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#iteration)') + parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') + parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') + parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not') + + parser.add_argument('--feeder', default='feeder.feeder', help='data loader will be used') + parser.add_argument('--num_worker', type=int, default=4, help='the number of worker per gpu for data loader') + parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for training') + parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for test') + parser.add_argument('--batch_size', type=int, default=256, help='training batch size') + parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size') + parser.add_argument('--debug', action="store_true", help='less data, faster loading') + + parser.add_argument('--model1', default=None, help='the model will be used') + parser.add_argument('--model2', default=None, help='the model will be used') + parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--weights1', default=None, help='the weights for network initialization') + parser.add_argument('--weights2', default=None, help='the weights for network initialization') + parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') + + return parser diff --git a/processor/recognition.py b/processor/recognition.py index 06fb642..a17fe90 100644 --- a/processor/recognition.py +++ b/processor/recognition.py @@ -1,315 +1,376 @@ -import sys -import os -import argparse -import yaml -import numpy as np - -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt - -import torch -import torch.nn as nn -import torch.optim as optim - -import torchlight -from torchlight.io import str2bool -from torchlight.io import DictAction -from torchlight.io import import_class - -from .processor import Processor - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv1d') != -1: - m.weight.data.normal_(0.0, 0.02) - if m.bias is not None: - m.bias.data.fill_(0) - elif classname.find('Conv2d') != -1: - m.weight.data.normal_(0.0, 0.02) - if m.bias is not None: - m.bias.data.fill_(0) - elif classname.find('BatchNorm') != -1: - m.weight.data.normal_(1.0, 0.02) - m.bias.data.fill_(0) - - -class REC_Processor(Processor): - - def load_model(self): - self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) - self.model1.apply(weights_init) - self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) - - 
self.loss_class = nn.CrossEntropyLoss() - self.loss_pred = nn.MSELoss() - self.w_pred = 0.01 - - prior = np.array([0.95, 0.05/2, 0.05/2]) - self.log_prior = torch.FloatTensor(np.log(prior)) - self.log_prior = torch.unsqueeze(torch.unsqueeze(self.log_prior, 0), 0) - - self.load_optimizer() - - def load_optimizer(self): - if self.arg.optimizer == 'SGD': - self.optimizer1 = optim.SGD(params=self.model1.parameters(), - lr=self.arg.base_lr1, - momentum=0.9, - nesterov=self.arg.nesterov, - weight_decay=self.arg.weight_decay) - elif self.arg.optimizer == 'Adam': - self.optimizer1 = optim.Adam(params=self.model1.parameters(), - lr=self.arg.base_lr1, - weight_decay=self.arg.weight_decay) - else: - raise ValueError() - self.optimizer2 = optim.Adam(params=self.model2.parameters(), - lr=self.arg.base_lr2) - - def adjust_lr(self): - if self.arg.optimizer == 'SGD' and self.arg.step: - lr = self.arg.base_lr1 * (0.1**np.sum(self.meta_info['epoch']>= np.array(self.arg.step))) - for param_group in self.optimizer1.param_groups: - param_group['lr'] = lr - self.lr = lr - else: - self.lr = self.arg.base_lr1 - self.lr2 = self.arg.base_lr2 - - def nll_gaussian(self, preds, target, variance, add_const=False): - neg_log_p = ((preds-target)**2/(2*variance)) - if add_const: - const = 0.5*np.log(2*np.pi*variance) - neg_log_p += const - return neg_log_p.sum() / (target.size(0) * target.size(1)) - - def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): - kl_div = preds*(torch.log(preds+eps)-log_prior) - return kl_div.sum()/(num_node*preds.size(0)) - - - def train(self, training_A=False): - - self.model1.train() - self.model2.train() - self.adjust_lr() - loader = self.data_loader['train'] - loss1_value = [] - loss_class_value = [] - loss_recon_value = [] - loss2_value = [] - loss_nll_value = [] - loss_kl_value = [] - - if training_A: - for param1 in self.model1.parameters(): - param1.requires_grad = False - for param2 in self.model2.parameters(): - param2.requires_grad = True - self.iter_info.clear() - self.epoch_info.clear() - - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - label = label.long().to(self.dev) - - gpu_id = data.get_device() - self.log_prior = self.log_prior.cuda(gpu_id) - A_batch, prob, outputs, data_target = self.model2(data_downsample) - loss_nll = self.nll_gaussian(outputs, data_target[:,:,1:,:], variance=5e-4) - loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) - loss2 = loss_nll + loss_kl - - self.optimizer2.zero_grad() - loss2.backward() - self.optimizer2.step() - - self.iter_info['loss2'] = loss2.data.item() - self.iter_info['loss_nll'] = loss_nll.data.item() - self.iter_info['loss_kl'] = loss_kl.data.item() - self.iter_info['lr'] = '{:.6f}'.format(self.lr2) - - loss2_value.append(self.iter_info['loss2']) - loss_nll_value.append(self.iter_info['loss_nll']) - loss_kl_value.append(self.iter_info['loss_kl']) - self.show_iter_info() - self.meta_info['iter'] += 1 - self.epoch_info['mean_loss2'] = np.mean(loss2_value) - self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) - self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) - self.show_epoch_info() - self.io.print_timer() - - else: - for param1 in self.model1.parameters(): - param1.requires_grad = True - for param2 in self.model2.parameters(): - param2.requires_grad = True - self.iter_info.clear() - self.epoch_info.clear() - for data, data_downsample, target_data, data_last, label in loader: - data = 
data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - target_data = target_data.float().to(self.dev) - data_last = data_last.float().to(self.dev) - label = label.long().to(self.dev) - - A_batch, prob, outputs, _ = self.model2(data_downsample) - x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) - loss_class = self.loss_class(x_class, label) - loss_recon = self.loss_pred(pred, target) - loss1 = loss_class + self.w_pred*loss_recon - - self.optimizer1.zero_grad() - loss1.backward() - self.optimizer1.step() - - self.iter_info['loss1'] = loss1.data.item() - self.iter_info['loss_class'] = loss_class.data.item() - self.iter_info['loss_recon'] = loss_recon.data.item()*self.w_pred - self.iter_info['lr'] = '{:.6f}'.format(self.lr) - - loss1_value.append(self.iter_info['loss1']) - loss_class_value.append(self.iter_info['loss_class']) - loss_recon_value.append(self.iter_info['loss_recon']) - self.show_iter_info() - self.meta_info['iter'] += 1 - - self.epoch_info['mean_loss1']= np.mean(loss1_value) - self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) - self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) - self.show_epoch_info() - self.io.print_timer() - - - def test(self, evaluation=True, testing_A=False, save=False, save_feature=False): - - self.model1.eval() - self.model2.eval() - loader = self.data_loader['test'] - loss1_value = [] - loss_class_value = [] - loss_recon_value = [] - loss2_value = [] - loss_nll_value = [] - loss_kl_value = [] - result_frag = [] - label_frag = [] - - if testing_A: - A_all = [] - self.epoch_info.clear() - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - label = label.long().to(self.dev) - - with torch.no_grad(): - A_batch, prob, outputs, data_bn = self.model2(data_downsample) - - if save: - n = A_batch.size(0) - a = A_batch[:int(n/2),:,:,:].cpu().numpy() - A_all.extend(a) - - if evaluation: - gpu_id = data.get_device() - self.log_prior = self.log_prior.cuda(gpu_id) - loss_nll = self.nll_gaussian(outputs, data_bn[:,:,1:,:], variance=5e-4) - loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) - loss2 = loss_nll + loss_kl - - loss2_value.append(loss2.item()) - loss_nll_value.append(loss_nll.item()) - loss_kl_value.append(loss_kl.item()) - - if save: - A_all = np.array(A_all) - np.save(os.path.join(self.arg.work_dir, 'test_adj.npy'), A_all) - - if evaluation: - self.epoch_info['mean_loss2'] = np.mean(loss2_value) - self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) - self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) - self.show_epoch_info() - - else: - recon_data = [] - feature_map = [] - self.epoch_info.clear() - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - target_data = target_data.float().to(self.dev) - data_last = data_last.float().to(self.dev) - label = label.long().to(self.dev) - - with torch.no_grad(): - A_batch, prob, outputs, _ = self.model2(data_downsample) - x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) - result_frag.append(x_class.data.cpu().numpy()) - - if save: - n = pred.size(0) - p = pred[::2,:,:,:].cpu().numpy() - recon_data.extend(p) - - if evaluation: - loss_class = self.loss_class(x_class, label) - loss_recon = self.loss_pred(pred, target) - loss1 = 
loss_class + self.w_pred*loss_recon - - loss1_value.append(loss1.item()) - loss_class_value.append(loss_class.item()) - loss_recon_value.append(loss_recon.item()) - label_frag.append(label.data.cpu().numpy()) - - if save: - recon_data = np.array(recon_data) - np.save(os.path.join(self.arg.work_dir, 'recon_data.npy'), recon_data) - - - self.result = np.concatenate(result_frag) - if evaluation: - self.label = np.concatenate(label_frag) - self.epoch_info['mean_loss1'] = np.mean(loss1_value) - self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) - self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) - self.show_epoch_info() - - for k in self.arg.show_topk: - hit_top_k = [] - rank = self.result.argsort() - for i,l in enumerate(self.label): - hit_top_k.append(l in rank[i, -k:]) - self.io.print_log('\n') - accuracy = sum(hit_top_k)*1.0/len(hit_top_k) - self.io.print_log('\tTop{}: {:.2f}%'.format(k, 100 * accuracy)) - - - - @staticmethod - def get_parser(add_help=False): - - parent_parser = Processor.get_parser(add_help=False) - parser = argparse.ArgumentParser( - add_help=add_help, - parents=[parent_parser], - description='Spatial Temporal Graph Convolution Network') - - parser.add_argument('--show_topk', type=int, default=[1, 5], nargs='+', help='which Top K accuracy will be shown') - parser.add_argument('--base_lr1', type=float, default=0.1, help='initial learning rate') - parser.add_argument('--base_lr2', type=float, default=0.0005, help='initial learning rate') - parser.add_argument('--step', type=int, default=[], nargs='+', help='the epoch where optimizer reduce the learning rate') - parser.add_argument('--optimizer', default='SGD', help='type of optimizer') - parser.add_argument('--nesterov', type=str2bool, default=True, help='use nesterov or not') - parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay for optimizer') - - parser.add_argument('--max_hop_dir', type=str, default='max_hop_4') - parser.add_argument('--lamda_act', type=float, default=0.5) - parser.add_argument('--lamda_act_dir', type=str, default='lamda_05') - - return parser +import sys +import os +import argparse +import yaml +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim + +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + +from .processor import Processor + +from net.model_poseformer import * + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv1d') != -1: + m.weight.data.normal_(0.0, 0.02) + if m.bias is not None: + m.bias.data.fill_(0) + elif classname.find('Conv2d') != -1: + m.weight.data.normal_(0.0, 0.02) + if m.bias is not None: + m.bias.data.fill_(0) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +class REC_Processor(Processor): + + def load_model(self): + self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) + self.model1.apply(weights_init) + self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) + self.model3 = PoseTransformer(num_frame=290, num_joints=25, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0) + + self.loss_class = nn.CrossEntropyLoss() + self.loss_pred = nn.MSELoss() + self.w_pred = 0.01 + + prior = np.array([0.95, 0.05/2, 0.05/2]) + self.log_prior = 
torch.FloatTensor(np.log(prior)) + self.log_prior = torch.unsqueeze(torch.unsqueeze(self.log_prior, 0), 0) + + self.load_optimizer() + + def load_optimizer(self): + if self.arg.optimizer == 'SGD': + self.optimizer1 = optim.SGD(params=self.model3.parameters(), + lr=self.arg.base_lr1, + momentum=0.9, + nesterov=self.arg.nesterov, + weight_decay=self.arg.weight_decay) + elif self.arg.optimizer == 'Adam': + self.optimizer1 = optim.Adam(params=self.model3.parameters(), + lr=self.arg.base_lr1, + weight_decay=self.arg.weight_decay) + + def adjust_lr(self): + if self.arg.optimizer == 'SGD' and self.arg.step: + lr = self.arg.base_lr1 * (0.1**np.sum(self.meta_info['epoch']>= np.array(self.arg.step))) + for param_group in self.optimizer1.param_groups: + param_group['lr'] = lr + self.lr = lr + else: + self.lr = self.arg.base_lr1 + self.lr2 = self.arg.base_lr2 + + def nll_gaussian(self, preds, target, variance, add_const=False): + neg_log_p = ((preds-target)**2/(2*variance)) + if add_const: + const = 0.5*np.log(2*np.pi*variance) + neg_log_p += const + return neg_log_p.sum() / (target.size(0) * target.size(1)) + + def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): + kl_div = preds*(torch.log(preds+eps)-log_prior) + return kl_div.sum()/(num_node*preds.size(0)) + + + def train(self, training_A=False): + self.model3.train() + self.adjust_lr() + loader = self.data_loader['train'] + loss_class_value = [] + loss2_value = [] + loss_nll_value = [] + loss_kl_value = [] + + if training_A: + for param1 in self.model1.parameters(): + param1.requires_grad = False + for param2 in self.model2.parameters(): + param2.requires_grad = True + self.iter_info.clear() + self.epoch_info.clear() + + for data, data_downsample, target_data, data_last, label in loader: + # data: (32,3,290,25,2) data_downsample:(32,3,50,25,2) target_data:(32,3,10,25,2) data_last:(32,3,1,25,2) label:(32) + data = data.float().to(self.dev) + data_downsample = data_downsample.float().to(self.dev) + label = label.long().to(self.dev) + + gpu_id = data.get_device() + self.log_prior = self.log_prior.cuda(gpu_id) + A_batch, prob, outputs, data_target = self.model2(data_downsample) + loss_nll = self.nll_gaussian(outputs, data_target[:,:,1:,:], variance=5e-4) + loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) + loss2 = loss_nll + loss_kl + + self.optimizer2.zero_grad() + loss2.backward() + self.optimizer2.step() + + self.iter_info['loss2'] = loss2.data.item() + self.iter_info['loss_nll'] = loss_nll.data.item() + self.iter_info['loss_kl'] = loss_kl.data.item() + self.iter_info['lr'] = '{:.6f}'.format(self.lr2) + + loss2_value.append(self.iter_info['loss2']) + loss_nll_value.append(self.iter_info['loss_nll']) + loss_kl_value.append(self.iter_info['loss_kl']) + self.show_iter_info() + self.meta_info['iter'] += 1 + break + self.epoch_info['mean_loss2'] = np.mean(loss2_value) + self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) + self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) + self.show_epoch_info() + self.io.print_timer() + + else: + ''' + for param1 in self.model1.parameters(): + param1.requires_grad = True + for param2 in self.model2.parameters(): + param2.requires_grad = True + ''' + for param3 in self.model3.parameters(): + param3.requires_grad = True + + self.iter_info.clear() + self.epoch_info.clear() + for data, data_downsample, target_data, data_last, label in loader: + data = data.float().to(self.dev) + #data_downsample = data_downsample.float().to(self.dev) + target_data = 
target_data.float().to(self.dev) + #data_last = data_last.float().to(self.dev) + label = label.long().to(self.dev) + + #A_batch, prob, outputs, _ = self.model2(data_downsample) + + # wsx model2 viz + ''' + import tensorwatch as tw + import torchvision.models + alexnet_model = torchvision.models.alexnet() + img = tw.draw_model(alexnet_model, [1, 3, 1024, 1024]) + img.save(r'img.jpg') + + from torchviz import make_dot + make_dot((A_batch), params=dict(list(self.model2.named_parameters()))).render("modle2", format="png") + + import hiddenlayer as h + vis_graph = h.build_graph(self.model2, torch.zeros([4,3,50,25,2])) + vis_graph.theme = h.graph.THEMES["blue"].copy() + vis_graph.save("./hl_model2.png") + + viz = self.model2.to(self.dev) + import torch + torch.save(viz, './model2.pth') + ''' + + # workingwsx + #x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) + x_class, target = self.model3(data, target_data) + loss_class = self.loss_class(x_class, label) + #loss_recon = self.loss_pred(pred, target) + #loss1 = loss_class + + self.optimizer1.zero_grad() + loss_class.backward() + self.optimizer1.step() + + #self.iter_info['loss1'] = loss1.data.item() + self.iter_info['loss_class'] = loss_class.data.item() + #self.iter_info['loss_recon'] = loss_recon.data.item()*self.w_pred + self.iter_info['lr'] = '{:.6f}'.format(self.lr) + + #loss1_value.append(self.iter_info['loss1']) + loss_class_value.append(self.iter_info['loss_class']) + #loss_recon_value.append(self.iter_info['loss_recon']) + self.show_iter_info() + self.meta_info['iter'] += 1 + #break # breakwsx + + #self.epoch_info['mean_loss1']= np.mean(loss1_value) + self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) + self.epoch_loss_class_train.append(np.mean(loss_class_value)) + #self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) + self.show_epoch_info() + self.io.print_timer() + + # show train curve + plt.figure() + epoch_x = np.arange(0, len(self.epoch_loss_class_train)) + 1 + plt.plot(epoch_x, self.epoch_loss_class_train[0:], '--', color='C0') + plt.legend(['epoch_loss_class_train']) + plt.ylabel('action CrossEntropyLoss') + plt.xlabel('Epoch') + plt.xlim(0, self.meta_info['epoch']) + plt.savefig('epoch_loss_class_train.png') + plt.close() + + + + + + def test(self, evaluation=True, testing_A=False, save=False, save_feature=False): + + #self.model1.eval() + #self.model2.eval() + self.model3.eval() + loader = self.data_loader['test'] + loss1_value = [] + loss_class_value = [] + loss_recon_value = [] + loss2_value = [] + loss_nll_value = [] + loss_kl_value = [] + result_frag = [] + label_frag = [] + + if testing_A: + A_all = [] + self.epoch_info.clear() + for data, data_downsample, target_data, data_last, label in loader: + data = data.float().to(self.dev) + data_downsample = data_downsample.float().to(self.dev) + label = label.long().to(self.dev) + + with torch.no_grad(): + A_batch, prob, outputs, data_bn = self.model2(data_downsample) + + if save: + n = A_batch.size(0) + a = A_batch[:int(n/2),:,:,:].cpu().numpy() + A_all.extend(a) + + if evaluation: + gpu_id = data.get_device() + self.log_prior = self.log_prior.cuda(gpu_id) + loss_nll = self.nll_gaussian(outputs, data_bn[:,:,1:,:], variance=5e-4) + loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) + loss2 = loss_nll + loss_kl + + loss2_value.append(loss2.item()) + loss_nll_value.append(loss_nll.item()) + loss_kl_value.append(loss_kl.item()) + + break + + if save: + A_all = np.array(A_all) + 
np.save(os.path.join(self.arg.work_dir, 'test_adj.npy'), A_all) + + if evaluation: + self.epoch_info['mean_loss2'] = np.mean(loss2_value) + self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) + self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) + self.show_epoch_info() + + else: + recon_data = [] + feature_map = [] + self.epoch_info.clear() + for data, data_downsample, target_data, data_last, label in loader: + data = data.float().to(self.dev) + #data_downsample = data_downsample.float().to(self.dev) + target_data = target_data.float().to(self.dev) + #data_last = data_last.float().to(self.dev) + label = label.long().to(self.dev) + + with torch.no_grad(): + #A_batch, prob, outputs, _ = self.model2(data_downsample) + #x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) + x_class, target = self.model3(data, target_data) + result_frag.append(x_class.data.cpu().numpy()) + + """ + if save: + n = pred.size(0) + p = pred[::2,:,:,:].cpu().numpy() + recon_data.extend(p) + """ + + if evaluation: + loss_class = self.loss_class(x_class, label) + #loss_recon = self.loss_pred(pred, target) + #loss1 = loss_class + self.w_pred*loss_recon + + #loss1_value.append(loss1.item()) + loss_class_value.append(loss_class.item()) + #loss_recon_value.append(loss_recon.item()) + label_frag.append(label.data.cpu().numpy()) + #break #breakwsx + """ + if save: + recon_data = np.array(recon_data) + np.save(os.path.join(self.arg.work_dir, 'recon_data.npy'), recon_data) + """ + + self.result = np.concatenate(result_frag) + + if evaluation: + self.label = np.concatenate(label_frag) + #self.epoch_info['mean_loss1'] = np.mean(loss1_value) + self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) + self.epoch_loss_class_test.append(np.mean(loss_class_value)) + #self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) + self.show_epoch_info() + + for k in self.arg.show_topk: + hit_top_k = [] + rank = self.result.argsort() + for i,l in enumerate(self.label): + hit_top_k.append(l in rank[i, -k:]) + self.io.print_log('\n') + accuracy = sum(hit_top_k)*1.0/len(hit_top_k) + self.io.print_log('\tTop{}: {:.2f}%'.format(k, 100 * accuracy)) + + # wsx test curve + plt.figure() + epoch_x = np.arange(0, len(self.epoch_loss_class_test)) + 1 + plt.plot(epoch_x, self.epoch_loss_class_test[0:], '--', color='C1') + plt.legend(['epoch_loss_class_eval']) + plt.ylabel('action CrossEntropyLoss') + plt.xlabel('Epoch') + plt.xlim(0, len(self.epoch_loss_class_test)) + plt.savefig('epoch_loss_class_eval.png') + plt.close() + + + @staticmethod + def get_parser(add_help=False): + + parent_parser = Processor.get_parser(add_help=False) + parser = argparse.ArgumentParser( + add_help=add_help, + parents=[parent_parser], + description='Spatial Temporal Graph Convolution Network') + + parser.add_argument('--show_topk', type=int, default=[1, 5], nargs='+', help='which Top K accuracy will be shown') + parser.add_argument('--base_lr1', type=float, default=0.1, help='initial learning rate') + parser.add_argument('--base_lr2', type=float, default=0.0005, help='initial learning rate') + parser.add_argument('--step', type=int, default=[], nargs='+', help='the epoch where optimizer reduce the learning rate') + parser.add_argument('--optimizer', default='SGD', help='type of optimizer') + parser.add_argument('--nesterov', type=str2bool, default=True, help='use nesterov or not') + parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay for optimizer') + + 
parser.add_argument('--max_hop_dir', type=str, default='max_hop_4') + parser.add_argument('--lamda_act', type=float, default=0.5) + parser.add_argument('--lamda_act_dir', type=str, default='lamda_05') + + return parser diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..4dd2d1d --- /dev/null +++ b/requirement.txt @@ -0,0 +1,327 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_ipyw_jlab_nb_ext_conf=0.1.0=py38_0 +_libgcc_mutex=0.1=main +alabaster=0.7.12=py_0 +anaconda=2020.11=py38_0 +anaconda-client=1.7.2=py38_0 +anaconda-navigator=1.10.0=py38_0 +anaconda-project=0.8.4=py_0 +argh=0.26.2=py38_0 +argon2-cffi=20.1.0=py38h7b6447c_1 +asn1crypto=1.4.0=py_0 +astroid=2.4.2=py38_0 +astropy=4.0.2=py38h7b6447c_0 +async_generator=1.10=py_0 +atomicwrites=1.4.0=py_0 +attrs=20.3.0=pyhd3eb1b0_0 +autopep8=1.5.4=py_0 +babel=2.8.1=pyhd3eb1b0_0 +backcall=0.2.0=py_0 +backports=1.0=py_2 +backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 +backports.shutil_get_terminal_size=1.0.0=py38_2 +backports.tempfile=1.0=pyhd3eb1b0_1 +backports.weakref=1.0.post1=py_1 +beautifulsoup4=4.9.3=pyhb0f4dca_0 +bitarray=1.6.1=py38h27cfd23_0 +bkcharts=0.2=py38_0 +blas=1.0=mkl +bleach=3.2.1=py_0 +blosc=1.20.1=hd408876_0 +bokeh=2.2.3=py38_0 +boto=2.49.0=py38_0 +bottleneck=1.3.2=py38heb32a55_1 +brotlipy=0.7.0=py38h7b6447c_1000 +bzip2=1.0.8=h7b6447c_0 +ca-certificates=2020.10.14=0 +cairo=1.14.12=h8948797_3 +certifi=2020.6.20=pyhd3eb1b0_3 +cffi=1.14.3=py38he30daa8_0 +chardet=3.0.4=py38_1003 +click=7.1.2=py_0 +cloudpickle=1.6.0=py_0 +clyent=1.2.2=py38_1 +colorama=0.4.4=py_0 +conda=4.10.1=py38h06a4308_1 +conda-build=3.20.5=py38_1 +conda-env=2.6.0=1 +conda-package-handling=1.7.3=py38h27cfd23_1 +conda-verify=3.4.2=py_1 +contextlib2=0.6.0.post1=py_0 +cryptography=3.1.1=py38h1ba5d50_0 +curl=7.71.1=hbc83047_1 +cycler=0.10.0=py38_0 +cython=0.29.21=py38he6710b0_0 +cytoolz=0.11.0=py38h7b6447c_0 +dask=2.30.0=py_0 +dask-core=2.30.0=py_0 +dbus=1.13.18=hb2f20db_0 +decorator=4.4.2=py_0 +defusedxml=0.6.0=py_0 +diff-match-patch=20200713=py_0 +distributed=2.30.1=py38h06a4308_0 +docutils=0.16=py38_1 +entrypoints=0.3=py38_0 +et_xmlfile=1.0.1=py_1001 +expat=2.2.10=he6710b0_2 +fastcache=1.1.0=py38h7b6447c_0 +filelock=3.0.12=py_0 +flake8=3.8.4=py_0 +flask=1.1.2=py_0 +fontconfig=2.13.0=h9420a91_0 +freetype=2.10.4=h5ab3b9f_0 +fribidi=1.0.10=h7b6447c_0 +fsspec=0.8.3=py_0 +future=0.18.2=py38_1 +get_terminal_size=1.0.0=haa9412d_0 +gevent=20.9.0=py38h7b6447c_0 +glib=2.66.1=h92f7085_0 +glob2=0.7=py_0 +gmp=6.1.2=h6c8ec71_1 +gmpy2=2.0.8=py38hd5f6e3b_3 +graphite2=1.3.14=h23475e2_0 +greenlet=0.4.17=py38h7b6447c_0 +gst-plugins-base=1.14.0=hbbd80ab_1 +gstreamer=1.14.0=hb31296c_0 +h5py=2.10.0=py38h7918eee_0 +harfbuzz=2.4.0=hca77d97_1 +hdf5=1.10.4=hb1b8bf9_0 +heapdict=1.0.1=py_0 +html5lib=1.1=py_0 +icu=58.2=he6710b0_3 +idna=2.10=py_0 +imageio=2.9.0=py_0 +imagesize=1.2.0=py_0 +importlib-metadata=2.0.0=py_1 +importlib_metadata=2.0.0=1 +iniconfig=1.1.1=py_0 +intel-openmp=2020.2=254 +intervaltree=3.1.0=py_0 +ipykernel=5.3.4=py38h5ca1d4c_0 +ipython=7.19.0=py38hb070fc8_0 +ipython_genutils=0.2.0=py38_0 +ipywidgets=7.5.1=py_1 +isort=5.6.4=py_0 +itsdangerous=1.1.0=py_0 +jbig=2.1=hdba287a_0 +jdcal=1.4.1=py_0 +jedi=0.17.1=py38_0 +jeepney=0.5.0=pyhd3eb1b0_0 +jinja2=2.11.2=py_0 +joblib=0.17.0=py_0 +jpeg=9b=h024ee3a_2 +json5=0.9.5=py_0 +jsonschema=3.2.0=py_2 +jupyter=1.0.0=py38_7 +jupyter_client=6.1.7=py_0 +jupyter_console=6.2.0=py_0 +jupyter_core=4.6.3=py38_0 +jupyterlab=2.2.6=py_0 
+jupyterlab_pygments=0.1.2=py_0 +jupyterlab_server=1.2.0=py_0 +keyring=21.4.0=py38_1 +kiwisolver=1.3.0=py38h2531618_0 +krb5=1.18.2=h173b8e3_0 +lazy-object-proxy=1.4.3=py38h7b6447c_0 +lcms2=2.11=h396b838_0 +ld_impl_linux-64=2.33.1=h53a641e_7 +libarchive=3.4.2=h62408e4_0 +libcurl=7.71.1=h20c2e04_1 +libedit=3.1.20191231=h14c3975_1 +libffi=3.3=he6710b0_2 +libgcc-ng=9.1.0=hdf63c60_0 +libgfortran-ng=7.3.0=hdf63c60_0 +liblief=0.10.1=he6710b0_0 +libllvm10=10.0.1=hbcb73fb_5 +libpng=1.6.37=hbc83047_0 +libsodium=1.0.18=h7b6447c_0 +libspatialindex=1.9.3=he6710b0_0 +libssh2=1.9.0=h1ba5d50_1 +libstdcxx-ng=9.1.0=hdf63c60_0 +libtiff=4.1.0=h2733197_1 +libtool=2.4.6=h7b6447c_1005 +libuuid=1.0.3=h1bed415_2 +libxcb=1.14=h7b6447c_0 +libxml2=2.9.10=hb55368b_3 +libxslt=1.1.34=hc22bd24_0 +llvmlite=0.34.0=py38h269e1b5_4 +locket=0.2.0=py38_1 +lxml=4.6.1=py38hefd8a0e_0 +lz4-c=1.9.2=heb0550a_3 +lzo=2.10=h7b6447c_2 +markupsafe=1.1.1=py38h7b6447c_0 +matplotlib=3.3.2=0 +matplotlib-base=3.3.2=py38h817c723_0 +mccabe=0.6.1=py38_1 +mistune=0.8.4=py38h7b6447c_1000 +mkl=2020.2=256 +mkl-service=2.3.0=py38he904b0f_0 +mkl_fft=1.2.0=py38h23d657b_0 +mkl_random=1.1.1=py38h0573a6f_0 +mock=4.0.2=py_0 +more-itertools=8.6.0=pyhd3eb1b0_0 +mpc=1.1.0=h10f8cd9_1 +mpfr=4.0.2=hb69a4c5_1 +mpmath=1.1.0=py38_0 +msgpack-python=1.0.0=py38hfd86e86_1 +multipledispatch=0.6.0=py38_0 +navigator-updater=0.2.1=py38_0 +nbclient=0.5.1=py_0 +nbconvert=6.0.7=py38_0 +nbformat=5.0.8=py_0 +ncurses=6.2=he6710b0_1 +nest-asyncio=1.4.2=pyhd3eb1b0_0 +networkx=2.5=py_0 +nltk=3.5=py_0 +nose=1.3.7=py38_2 +notebook=6.1.4=py38_0 +numba=0.51.2=py38h0573a6f_1 +numexpr=2.7.1=py38h423224d_0 +numpy=1.19.2=py38h54aff64_0 +numpy-base=1.19.2=py38hfa32c7d_0 +numpydoc=1.1.0=pyhd3eb1b0_1 +olefile=0.46=py_0 +openpyxl=3.0.5=py_0 +openssl=1.1.1h=h7b6447c_0 +packaging=20.4=py_0 +pandas=1.1.3=py38he6710b0_0 +pandoc=2.11=hb0f4dca_0 +pandocfilters=1.4.3=py38h06a4308_1 +pango=1.45.3=hd140c19_0 +parso=0.7.0=py_0 +partd=1.1.0=py_0 +patchelf=0.12=he6710b0_0 +path=15.0.0=py38_0 +path.py=12.5.0=0 +pathlib2=2.3.5=py38_0 +pathtools=0.1.2=py_1 +patsy=0.5.1=py38_0 +pcre=8.44=he6710b0_0 +pep8=1.7.1=py38_0 +pexpect=4.8.0=py38_0 +pickleshare=0.7.5=py38_1000 +pillow=8.0.1=py38he98fc37_0 +pip=21.1=pypi_0 +pixman=0.40.0=h7b6447c_0 +pkginfo=1.6.1=py38h06a4308_0 +pluggy=0.13.1=py38_0 +ply=3.11=py38_0 +prometheus_client=0.8.0=py_0 +prompt-toolkit=3.0.8=py_0 +prompt_toolkit=3.0.8=0 +psutil=5.7.2=py38h7b6447c_0 +ptyprocess=0.6.0=py38_0 +py=1.9.0=py_0 +py-lief=0.10.1=py38h403a769_0 +pycodestyle=2.6.0=py_0 +pycosat=0.6.3=py38h7b6447c_1 +pycparser=2.20=py_2 +pycurl=7.43.0.6=py38h1ba5d50_0 +pydocstyle=5.1.1=py_0 +pyflakes=2.2.0=py_0 +pygments=2.7.2=pyhd3eb1b0_0 +pylint=2.6.0=py38_0 +pyodbc=4.0.30=py38he6710b0_0 +pyopenssl=19.1.0=py_1 +pyparsing=2.4.7=py_0 +pyqt=5.9.2=py38h05f1152_4 +pyrsistent=0.17.3=py38h7b6447c_0 +pysocks=1.7.1=py38_0 +pytables=3.6.1=py38h9fd0a39_0 +pytest=6.1.1=py38_0 +python=3.8.5=h7579374_1 +python-dateutil=2.8.1=py_0 +python-jsonrpc-server=0.4.0=py_0 +python-language-server=0.35.1=py_0 +python-libarchive-c=2.9=py_0 +pytz=2020.1=py_0 +pywavelets=1.1.1=py38h7b6447c_2 +pyxdg=0.27=pyhd3eb1b0_0 +pyyaml=5.3.1=py38h7b6447c_1 +pyzmq=19.0.2=py38he6710b0_1 +qdarkstyle=2.8.1=py_0 +qt=5.9.7=h5867ecd_1 +qtawesome=1.0.1=py_0 +qtconsole=4.7.7=py_0 +qtpy=1.9.0=py_0 +readline=8.0=h7b6447c_0 +regex=2020.10.15=py38h7b6447c_0 +requests=2.24.0=py_0 +ripgrep=12.1.1=0 +rope=0.18.0=py_0 +rtree=0.9.4=py38_1 +ruamel_yaml=0.15.87=py38h7b6447c_1 +scikit-image=0.17.2=py38hdf5156a_0 +scikit-learn=0.23.2=py38h0573a6f_0 
+scipy=1.5.2=py38h0b6359f_0 +seaborn=0.11.0=py_0 +secretstorage=3.1.2=py38_0 +send2trash=1.5.0=py38_0 +setuptools=50.3.1=py38h06a4308_1 +simplegeneric=0.8.1=py38_2 +singledispatch=3.4.0.3=py_1001 +sip=4.19.13=py38he6710b0_0 +six=1.15.0=py38h06a4308_0 +snowballstemmer=2.0.0=py_0 +sortedcollections=1.2.1=py_0 +sortedcontainers=2.2.2=py_0 +soupsieve=2.0.1=py_0 +sphinx=3.2.1=py_0 +sphinxcontrib=1.0=py38_1 +sphinxcontrib-applehelp=1.0.2=py_0 +sphinxcontrib-devhelp=1.0.2=py_0 +sphinxcontrib-htmlhelp=1.0.3=py_0 +sphinxcontrib-jsmath=1.0.1=py_0 +sphinxcontrib-qthelp=1.0.3=py_0 +sphinxcontrib-serializinghtml=1.1.4=py_0 +sphinxcontrib-websupport=1.2.4=py_0 +spyder=4.1.5=py38_0 +spyder-kernels=1.9.4=py38_0 +sqlalchemy=1.3.20=py38h7b6447c_0 +sqlite=3.33.0=h62c20be_0 +statsmodels=0.12.0=py38h7b6447c_0 +sympy=1.6.2=py38h06a4308_1 +tbb=2020.3=hfd86e86_0 +tblib=1.7.0=py_0 +terminado=0.9.1=py38_0 +testpath=0.4.4=py_0 +threadpoolctl=2.1.0=pyh5ca1d4c_0 +tifffile=2020.10.1=py38hdd07704_2 +tk=8.6.10=hbc83047_0 +toml=0.10.1=py_0 +toolz=0.11.1=py_0 +torch=1.8.1=pypi_0 +torchvision=0.9.1=pypi_0 +tornado=6.0.4=py38h7b6447c_1 +tqdm=4.50.2=py_0 +traitlets=5.0.5=py_0 +typing_extensions=3.7.4.3=py_0 +ujson=4.0.1=py38he6710b0_0 +unicodecsv=0.14.1=py38_0 +unixodbc=2.3.9=h7b6447c_0 +urllib3=1.25.11=py_0 +watchdog=0.10.3=py38_0 +wcwidth=0.2.5=py_0 +webencodings=0.5.1=py38_1 +werkzeug=1.0.1=py_0 +wheel=0.35.1=py_0 +widgetsnbextension=3.5.1=py38_0 +wrapt=1.11.2=py38h7b6447c_0 +wurlitzer=2.0.1=py38_0 +xlrd=1.2.0=py_0 +xlsxwriter=1.3.7=py_0 +xlwt=1.3.0=py38_0 +xmltodict=0.12.0=py_0 +xz=5.2.5=h7b6447c_0 +yaml=0.2.5=h7b6447c_0 +yapf=0.30.0=py_0 +zeromq=4.3.3=he6710b0_3 +zict=2.0.0=py_0 +zipp=3.4.0=pyhd3eb1b0_0 +zlib=1.2.11=h7b6447c_3 +zope=1.0=py38_1 +zope.event=4.5.0=py38_0 +zope.interface=5.1.2=py38h7b6447c_0 +zstd=1.4.5=h9ceee32_0 diff --git a/torchlight/__init__.py b/torchlight/__init__.py index 07e70f1..5e0e7b9 100644 --- a/torchlight/__init__.py +++ b/torchlight/__init__.py @@ -1,8 +1,8 @@ -from .io import IO -from .io import str2bool -from .io import str2dict -from .io import DictAction -from .io import import_class -from .gpu import visible_gpu -from .gpu import occupy_gpu -from .gpu import ngpu +from .io import IO +from .io import str2bool +from .io import str2dict +from .io import DictAction +from .io import import_class +from .gpu import visible_gpu +from .gpu import occupy_gpu +from .gpu import ngpu diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc index 9f75a6bb9658ea632019f7c35cb9c2b299d1129d..2765d62168f7da1034b01435ed8fb2548d0b46b3 100644 GIT binary patch delta 51 zcmeyt)XvOl%*)HA{dGgam5rQQj50?0Mfv$9`UREA8Hq)?@dZWsS*gh-`iaHq$$66< G7%c$7dJ)e6 delta 36 rcmZo?{=vj)%*)Hg6f`yA%tlTvMqXw8qWt_4eaB#3cW1xJ{)`p?wKfUm diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc index 4b8a85cc437e91d0559b5d5b476091e9e1f1731d..70b62d77b24d4b1eb1611948c0d411b705c30098 100644 GIT binary patch delta 666 zcmYMy&uSA<6bA5|nA9{OW7;O6u_a(ZYp4*VU<=Y};~!c<1xr=Lb}~%TblS;GdS;R| zbx{yCYAT5L1$+c`d;oWXD|fGf857P8GL>Iwjumo;{A)o^TpzFVOHSuWVEtSM<`sC>?q&BDn2e^tt zlpCnG0p`LpfW?dSDz(Iq^fv8?a^`V#QN>bx&s-hYQUyX=|7E>7JC z%fZj=7ExE+8~vJUBRPJs))Uuqle8|Hx!t`Dq!!o&Y4B1Zle%u&hUb}8+t&5}Ln#nH z$EN;lx~h7w9xtMZkrrhng*^M}#n-_ycnA^-8EYHtSWeC3d3@#*rfb{kPqAFb<1n_Z lAdfoDdZQ!7{!w@G64C+C0y_Y==TU_mSm+<-u9z6R^A{B;rz!vd delta 633 zcmX}oL1@!p6bJCWn9QyrwOzBiPB%fAvj>@D;AG5coLfOem_yiNM~!J#W7A~5WUDR6 
zpvahlr-vZ!;K8G(k%L#ogIB%g=Ed9YG7$Y<4<5cB@BQERy~Fo%_V`L>M$-!ho-uanr)7r=zY4G(uMG{t%$%6x zP39I;VeT_|<^i+8@E)u(Hcg*7O&<#fU{FQBc)7vhW5#8kFik;OL67SeV=%LfPw({& zK>DLUg@}UTJ!M(MCLIsohAu4?pTjnNEl#UDe7Vi6#D9xTfF3nRe&#|>CjQU%s9KtY zNA#l9-K}%#Fq@3dycQ%{t!8^x7}}2K)oM7z1EvhP&vk+62|kMx%7Qdxm3~ql$1Tz0 z9cG2uWmG|uZd>TPej~(d+>Wbu;CYT6xq%;|#Mf2@rs8k6wmK#qj0Tg7oDML^L`;_% P5hMdf|1tK&53l_Ln<|_C diff --git a/torchlight/build/lib/torchlight/__init__.py b/torchlight/build/lib/torchlight/__init__.py index 8b13789..d3f5a12 100644 --- a/torchlight/build/lib/torchlight/__init__.py +++ b/torchlight/build/lib/torchlight/__init__.py @@ -1 +1 @@ - + diff --git a/torchlight/build/lib/torchlight/gpu.py b/torchlight/build/lib/torchlight/gpu.py index 306c391..e086d4c 100644 --- a/torchlight/build/lib/torchlight/gpu.py +++ b/torchlight/build/lib/torchlight/gpu.py @@ -1,35 +1,35 @@ -import os -import torch - - -def visible_gpu(gpus): - """ - set visible gpu. - - can be a single id, or a list - - return a list of new gpus ids - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) - return list(range(len(gpus))) - - -def ngpu(gpus): - """ - count how many gpus used. - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - return len(gpus) - - -def occupy_gpu(gpus=None): - """ - make program appear on nvidia-smi. - """ - if gpus is None: - torch.zeros(1).cuda() - else: - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - for g in gpus: - torch.zeros(1).cuda(g) +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. 
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/build/lib/torchlight/io.py b/torchlight/build/lib/torchlight/io.py index c753ca1..5b43720 100644 --- a/torchlight/build/lib/torchlight/io.py +++ b/torchlight/build/lib/torchlight/io.py @@ -1,203 +1,203 @@ -#!/usr/bin/env python -import argparse -import os -import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict -import yaml -import numpy as np -# torch -import torch -import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py - -class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - - def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - self.session_file = '{}/config.yaml'.format(self.work_dir) - - # save arg 
- arg_dict = vars(arg) - if not os.path.exists(self.work_dir): - os.makedirs(self.work_dir) - with open(self.session_file, 'w') as f: - f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) - yaml.dump(arg_dict, f, default_flow_style=False, indent=4) - - def print_log(self, str, print_time=True): - if print_time: - # localtime = time.asctime(time.localtime(time.time())) - str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str - - if self.print_to_screen: - print(str) - if self.save_log: - with open('{}/log.txt'.format(self.work_dir), 'a') as f: - print(str, file=f) - - def init_timer(self, *name): - self.record_time() - self.split_timer = {k: 0.0000001 for k in name} - - def check_time(self, name): - self.split_timer[name] += self.split_time() - - def record_time(self): - self.cur_time = time.time() - return self.cur_time - - def split_time(self): - split_time = time.time() - self.cur_time - self.record_time() - return split_time - - def print_timer(self): - proportion = { - k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) - for k, v in self.split_timer.items() - } - self.print_log('Time consumption:') - for k in proportion: - self.print_log( - '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) - ) - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def str2dict(v): - return eval('dict({})'.format(v)) #pylint: disable=W0123 - - -def _import_class_0(name): - components = name.split('.') - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def import_class(import_str): - mod_str, _sep, class_str = import_str.rpartition('.') - __import__(mod_str) - try: - return getattr(sys.modules[mod_str], class_str) - except AttributeError: - raise ImportError('Class %s cannot be found (%s)' % - (class_str, - traceback.format_exception(*sys.exc_info()))) - - -class DictAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(DictAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 - output_dict = getattr(namespace, self.dest) - for k in input_dict: - output_dict[k] = input_dict[k] - setattr(namespace, self.dest, output_dict) +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as 
f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True 
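The str2bool helper (its remaining branches continue just below) is what lets options such as --nesterov accept 'yes'/'no' style values on the command line. A typical wiring, restated standalone:

import argparse

def str2bool(v):
    # Same contract as the helper defined in torchlight/io.py.
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--nesterov', type=str2bool, default=True)
print(parser.parse_args(['--nesterov', 'no']).nesterov)  # False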
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/gpu.py b/torchlight/gpu.py index 462faa9..76566a9 100644 --- a/torchlight/gpu.py +++ b/torchlight/gpu.py @@ -1,36 +1,36 @@ -import os -import torch - - -def visible_gpu(gpus): - """ - set visible gpu. - - can be a single id, or a list - - return a list of new gpus ids - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) - #os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2" - return list(range(len(gpus))) - - -def ngpu(gpus): - """ - count how many gpus used. - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - return len(gpus) - - -def occupy_gpu(gpus=None): - """ - make program appear on nvidia-smi. - """ - if gpus is None: - torch.zeros(1).cuda() - else: - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - for g in gpus: - torch.zeros(1).cuda(g) +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + #os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2" + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. 
+ """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/io.py b/torchlight/io.py index c753ca1..5b43720 100644 --- a/torchlight/io.py +++ b/torchlight/io.py @@ -1,203 +1,203 @@ -#!/usr/bin/env python -import argparse -import os -import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict -import yaml -import numpy as np -# torch -import torch -import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py - -class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - - def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - self.session_file = '{}/config.yaml'.format(self.work_dir) - - # save arg - arg_dict = vars(arg) - if not os.path.exists(self.work_dir): - 
os.makedirs(self.work_dir) - with open(self.session_file, 'w') as f: - f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) - yaml.dump(arg_dict, f, default_flow_style=False, indent=4) - - def print_log(self, str, print_time=True): - if print_time: - # localtime = time.asctime(time.localtime(time.time())) - str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str - - if self.print_to_screen: - print(str) - if self.save_log: - with open('{}/log.txt'.format(self.work_dir), 'a') as f: - print(str, file=f) - - def init_timer(self, *name): - self.record_time() - self.split_timer = {k: 0.0000001 for k in name} - - def check_time(self, name): - self.split_timer[name] += self.split_time() - - def record_time(self): - self.cur_time = time.time() - return self.cur_time - - def split_time(self): - split_time = time.time() - self.cur_time - self.record_time() - return split_time - - def print_timer(self): - proportion = { - k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) - for k, v in self.split_timer.items() - } - self.print_log('Time consumption:') - for k in proportion: - self.print_log( - '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) - ) - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def str2dict(v): - return eval('dict({})'.format(v)) #pylint: disable=W0123 - - -def _import_class_0(name): - components = name.split('.') - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def import_class(import_str): - mod_str, _sep, class_str = import_str.rpartition('.') - __import__(mod_str) - try: - return getattr(sys.modules[mod_str], class_str) - except AttributeError: - raise ImportError('Class %s cannot be found (%s)' % - (class_str, - traceback.format_exception(*sys.exc_info()))) - - -class DictAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(DictAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 - output_dict = getattr(namespace, self.dest) - for k in input_dict: - output_dict[k] = input_dict[k] - setattr(namespace, self.dest, output_dict) +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # 
session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return 
False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/setup.py b/torchlight/setup.py index 95bbce2..a8a1647 100644 --- a/torchlight/setup.py +++ b/torchlight/setup.py @@ -1,8 +1,8 @@ -from setuptools import find_packages, setup - -setup( - name='torchlight', - version='1.0', - description='A mini framework for pytorch', - packages=find_packages(), - install_requires=[]) +from setuptools import find_packages, setup + +setup( + name='torchlight', + version='1.0', + description='A mini framework for pytorch', + packages=find_packages(), + install_requires=[]) diff --git a/torchlight/torchlight.egg-info/PKG-INFO b/torchlight/torchlight.egg-info/PKG-INFO index 53cafc2..4020517 100644 --- a/torchlight/torchlight.egg-info/PKG-INFO +++ b/torchlight/torchlight.egg-info/PKG-INFO @@ -1,10 +1,10 @@ -Metadata-Version: 1.0 -Name: torchlight -Version: 1.0 -Summary: A mini framework for pytorch -Home-page: UNKNOWN -Author: UNKNOWN -Author-email: UNKNOWN -License: UNKNOWN -Description: UNKNOWN -Platform: UNKNOWN +Metadata-Version: 1.0 +Name: torchlight +Version: 1.0 +Summary: A mini framework for pytorch +Home-page: UNKNOWN +Author: UNKNOWN +Author-email: UNKNOWN +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN diff --git a/torchlight/torchlight.egg-info/SOURCES.txt b/torchlight/torchlight.egg-info/SOURCES.txt index 1ee6009..4c2ca9d 100644 --- a/torchlight/torchlight.egg-info/SOURCES.txt +++ b/torchlight/torchlight.egg-info/SOURCES.txt @@ -1,8 +1,8 @@ -setup.py -torchlight/__init__.py -torchlight/gpu.py -torchlight/io.py -torchlight.egg-info/PKG-INFO -torchlight.egg-info/SOURCES.txt -torchlight.egg-info/dependency_links.txt +setup.py +torchlight/__init__.py +torchlight/gpu.py +torchlight/io.py +torchlight.egg-info/PKG-INFO +torchlight.egg-info/SOURCES.txt +torchlight.egg-info/dependency_links.txt torchlight.egg-info/top_level.txt \ No newline at end of file diff --git a/torchlight/torchlight.egg-info/dependency_links.txt b/torchlight/torchlight.egg-info/dependency_links.txt index 8b13789..d3f5a12 100644 --- a/torchlight/torchlight.egg-info/dependency_links.txt +++ b/torchlight/torchlight.egg-info/dependency_links.txt @@ -1 +1 @@ - + diff --git a/torchlight/torchlight.egg-info/top_level.txt b/torchlight/torchlight.egg-info/top_level.txt index c600430..09e8d4c 100644 --- a/torchlight/torchlight.egg-info/top_level.txt +++ b/torchlight/torchlight.egg-info/top_level.txt @@ -1 +1 @@ -torchlight 
+torchlight diff --git a/torchlight/torchlight/__init__.py b/torchlight/torchlight/__init__.py index 8b13789..d3f5a12 100644 --- a/torchlight/torchlight/__init__.py +++ b/torchlight/torchlight/__init__.py @@ -1 +1 @@ - + diff --git a/torchlight/torchlight/gpu.py b/torchlight/torchlight/gpu.py index 306c391..e086d4c 100644 --- a/torchlight/torchlight/gpu.py +++ b/torchlight/torchlight/gpu.py @@ -1,35 +1,35 @@ -import os -import torch - - -def visible_gpu(gpus): - """ - set visible gpu. - - can be a single id, or a list - - return a list of new gpus ids - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) - return list(range(len(gpus))) - - -def ngpu(gpus): - """ - count how many gpus used. - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - return len(gpus) - - -def occupy_gpu(gpus=None): - """ - make program appear on nvidia-smi. - """ - if gpus is None: - torch.zeros(1).cuda() - else: - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - for g in gpus: - torch.zeros(1).cuda(g) +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/torchlight/io.py b/torchlight/torchlight/io.py index c753ca1..5b43720 100644 --- a/torchlight/torchlight/io.py +++ b/torchlight/torchlight/io.py @@ -1,203 +1,203 @@ -#!/usr/bin/env python -import argparse -import os -import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict -import yaml -import numpy as np -# torch -import torch -import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py - -class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - 
ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - - def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - self.session_file = '{}/config.yaml'.format(self.work_dir) - - # save arg - arg_dict = vars(arg) - if not os.path.exists(self.work_dir): - os.makedirs(self.work_dir) - with open(self.session_file, 'w') as f: - f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) - yaml.dump(arg_dict, f, default_flow_style=False, indent=4) - - def print_log(self, str, print_time=True): - if print_time: - # localtime = time.asctime(time.localtime(time.time())) - str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str - - if self.print_to_screen: - print(str) - if self.save_log: - with open('{}/log.txt'.format(self.work_dir), 'a') as f: - print(str, file=f) - - def init_timer(self, *name): - self.record_time() - self.split_timer = {k: 0.0000001 for k in name} - - def check_time(self, name): - self.split_timer[name] += self.split_time() - - def record_time(self): - self.cur_time = time.time() - return self.cur_time - - def split_time(self): - split_time = time.time() - self.cur_time - self.record_time() - return split_time - - def print_timer(self): - proportion = { - k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) - for k, v in self.split_timer.items() - } - self.print_log('Time consumption:') - for k in proportion: - self.print_log( - '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) - ) - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def str2dict(v): - return eval('dict({})'.format(v)) #pylint: disable=W0123 - - -def _import_class_0(name): - components = name.split('.') - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def import_class(import_str): - mod_str, _sep, class_str = import_str.rpartition('.') - __import__(mod_str) - try: - return getattr(sys.modules[mod_str], class_str) - except AttributeError: - raise ImportError('Class %s cannot be 
found (%s)' % - (class_str, - traceback.format_exception(*sys.exc_info()))) - - -class DictAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(DictAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 - output_dict = getattr(namespace, self.dest) - for k in input_dict: - output_dict[k] = input_dict[k] - setattr(namespace, self.dest, output_dict) +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in 
state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict)
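To close, a usage sketch for two of the torchlight helpers defined above: import_class, which resolves a dotted path string to a class object (the processors use it to instantiate models by name), and DictAction, which parses 'key=value' argument strings into a dict via eval, so it must only be fed trusted input. The flag name and values below are illustrative:

import argparse
# Assumes the torchlight package shown in this patch is installed/importable.
from torchlight import DictAction, import_class

# import_class returns the class object itself, here a stdlib example.
OrderedDict = import_class('collections.OrderedDict')

parser = argparse.ArgumentParser()
parser.add_argument('--model_args', action=DictAction, default=dict())
args = parser.parse_args(['--model_args', 'num_class=60, dropout=0.5'])
print(args.model_args)  # {'num_class': 60, 'dropout': 0.5}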
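Finally, a sketch of the torchlight.gpu helpers that appear (unchanged except for line endings) several times in this patch: visible_gpu narrows CUDA_VISIBLE_DEVICES and returns remapped device ids, ngpu counts them, and occupy_gpu allocates a tiny tensor per device so the process shows up in nvidia-smi. This assumes a CUDA machine, and note that CUDA_VISIBLE_DEVICES only takes effect if set before torch initializes CUDA:

import torch
from torchlight import visible_gpu, ngpu, occupy_gpu

gpus = visible_gpu([2, 3])   # sets CUDA_VISIBLE_DEVICES=2,3
print(gpus, ngpu(gpus))      # [0, 1] 2  (ids are remapped to start at 0)
if torch.cuda.is_available():
    occupy_gpu(gpus)         # touches each visible device with a 1-element tensor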