From 7ab47c909c260310afaebded2be71cb467ad4603 Mon Sep 17 00:00:00 2001 From: wang shuxi Date: Fri, 7 May 2021 11:06:48 +0800 Subject: [PATCH 1/9] some env problems settled --- .../samples_with_missing_skeletons.txt | 302 ----------------- main.py | 6 +- processor/io.py | 303 +++++++++++------- processor/processor.py | 6 +- processor/recognition.py | 6 +- torchlight/__init__.py | 1 + .../__pycache__/__init__.cpython-36.pyc | Bin 0 -> 122 bytes torchlight/__pycache__/io.cpython-36.pyc | Bin 0 -> 7038 bytes torchlight/build/lib/torchlight/__init__.py | 1 + torchlight/build/lib/torchlight/gpu.py | 35 ++ torchlight/build/lib/torchlight/io.py | 203 ++++++++++++ torchlight/dist/torchlight-1.0-py3.6.egg | Bin 0 -> 8580 bytes torchlight/gpu.py | 35 ++ torchlight/io.py | 203 ++++++++++++ torchlight/torchlight.egg-info/PKG-INFO | 10 + torchlight/torchlight.egg-info/SOURCES.txt | 8 + .../torchlight.egg-info/dependency_links.txt | 1 + torchlight/torchlight.egg-info/top_level.txt | 1 + 18 files changed, 703 insertions(+), 418 deletions(-) delete mode 100644 data/NTU-RGB+D/samples_with_missing_skeletons.txt create mode 100644 torchlight/__init__.py create mode 100644 torchlight/__pycache__/__init__.cpython-36.pyc create mode 100644 torchlight/__pycache__/io.cpython-36.pyc create mode 100644 torchlight/build/lib/torchlight/__init__.py create mode 100644 torchlight/build/lib/torchlight/gpu.py create mode 100644 torchlight/build/lib/torchlight/io.py create mode 100644 torchlight/dist/torchlight-1.0-py3.6.egg create mode 100644 torchlight/gpu.py create mode 100644 torchlight/io.py create mode 100644 torchlight/torchlight.egg-info/PKG-INFO create mode 100644 torchlight/torchlight.egg-info/SOURCES.txt create mode 100644 torchlight/torchlight.egg-info/dependency_links.txt create mode 100644 torchlight/torchlight.egg-info/top_level.txt diff --git a/data/NTU-RGB+D/samples_with_missing_skeletons.txt b/data/NTU-RGB+D/samples_with_missing_skeletons.txt deleted file mode 100644 index 5ad472e..0000000 --- a/data/NTU-RGB+D/samples_with_missing_skeletons.txt +++ /dev/null @@ -1,302 +0,0 @@ -S001C002P005R002A008 -S001C002P006R001A008 -S001C003P002R001A055 -S001C003P002R002A012 -S001C003P005R002A004 -S001C003P005R002A005 -S001C003P005R002A006 -S001C003P006R002A008 -S002C002P011R002A030 -S002C003P008R001A020 -S002C003P010R002A010 -S002C003P011R002A007 -S002C003P011R002A011 -S002C003P014R002A007 -S003C001P019R001A055 -S003C002P002R002A055 -S003C002P018R002A055 -S003C003P002R001A055 -S003C003P016R001A055 -S003C003P018R002A024 -S004C002P003R001A013 -S004C002P008R001A009 -S004C002P020R001A003 -S004C002P020R001A004 -S004C002P020R001A012 -S004C002P020R001A020 -S004C002P020R001A021 -S004C002P020R001A036 -S005C002P004R001A001 -S005C002P004R001A003 -S005C002P010R001A016 -S005C002P010R001A017 -S005C002P010R001A048 -S005C002P010R001A049 -S005C002P016R001A009 -S005C002P016R001A010 -S005C002P018R001A003 -S005C002P018R001A028 -S005C002P018R001A029 -S005C003P016R002A009 -S005C003P018R002A013 -S005C003P021R002A057 -S006C001P001R002A055 -S006C002P007R001A005 -S006C002P007R001A006 -S006C002P016R001A043 -S006C002P016R001A051 -S006C002P016R001A052 -S006C002P022R001A012 -S006C002P023R001A020 -S006C002P023R001A021 -S006C002P023R001A022 -S006C002P023R001A023 -S006C002P024R001A018 -S006C002P024R001A019 -S006C003P001R002A013 -S006C003P007R002A009 -S006C003P007R002A010 -S006C003P007R002A025 -S006C003P016R001A060 -S006C003P017R001A055 -S006C003P017R002A013 -S006C003P017R002A014 -S006C003P017R002A015 -S006C003P022R002A013 -S007C001P018R002A050 -S007C001P025R002A051 -S007C001P028R001A050 -S007C001P028R001A051 -S007C001P028R001A052 -S007C002P008R002A008 -S007C002P015R002A055 -S007C002P026R001A008 -S007C002P026R001A009 -S007C002P026R001A010 -S007C002P026R001A011 -S007C002P026R001A012 -S007C002P026R001A050 -S007C002P027R001A011 -S007C002P027R001A013 -S007C002P028R002A055 -S007C003P007R001A002 -S007C003P007R001A004 -S007C003P019R001A060 -S007C003P027R002A001 -S007C003P027R002A002 -S007C003P027R002A003 -S007C003P027R002A004 -S007C003P027R002A005 -S007C003P027R002A006 -S007C003P027R002A007 -S007C003P027R002A008 -S007C003P027R002A009 -S007C003P027R002A010 -S007C003P027R002A011 -S007C003P027R002A012 -S007C003P027R002A013 -S008C002P001R001A009 -S008C002P001R001A010 -S008C002P001R001A014 -S008C002P001R001A015 -S008C002P001R001A016 -S008C002P001R001A018 -S008C002P001R001A019 -S008C002P008R002A059 -S008C002P025R001A060 -S008C002P029R001A004 -S008C002P031R001A005 -S008C002P031R001A006 -S008C002P032R001A018 -S008C002P034R001A018 -S008C002P034R001A019 -S008C002P035R001A059 -S008C002P035R002A002 -S008C002P035R002A005 -S008C003P007R001A009 -S008C003P007R001A016 -S008C003P007R001A017 -S008C003P007R001A018 -S008C003P007R001A019 -S008C003P007R001A020 -S008C003P007R001A021 -S008C003P007R001A022 -S008C003P007R001A023 -S008C003P007R001A025 -S008C003P007R001A026 -S008C003P007R001A028 -S008C003P007R001A029 -S008C003P007R002A003 -S008C003P008R002A050 -S008C003P025R002A002 -S008C003P025R002A011 -S008C003P025R002A012 -S008C003P025R002A016 -S008C003P025R002A020 -S008C003P025R002A022 -S008C003P025R002A023 -S008C003P025R002A030 -S008C003P025R002A031 -S008C003P025R002A032 -S008C003P025R002A033 -S008C003P025R002A049 -S008C003P025R002A060 -S008C003P031R001A001 -S008C003P031R002A004 -S008C003P031R002A014 -S008C003P031R002A015 -S008C003P031R002A016 -S008C003P031R002A017 -S008C003P032R002A013 -S008C003P033R002A001 -S008C003P033R002A011 -S008C003P033R002A012 -S008C003P034R002A001 -S008C003P034R002A012 -S008C003P034R002A022 -S008C003P034R002A023 -S008C003P034R002A024 -S008C003P034R002A044 -S008C003P034R002A045 -S008C003P035R002A016 -S008C003P035R002A017 -S008C003P035R002A018 -S008C003P035R002A019 -S008C003P035R002A020 -S008C003P035R002A021 -S009C002P007R001A001 -S009C002P007R001A003 -S009C002P007R001A014 -S009C002P008R001A014 -S009C002P015R002A050 -S009C002P016R001A002 -S009C002P017R001A028 -S009C002P017R001A029 -S009C003P017R002A030 -S009C003P025R002A054 -S010C001P007R002A020 -S010C002P016R002A055 -S010C002P017R001A005 -S010C002P017R001A018 -S010C002P017R001A019 -S010C002P019R001A001 -S010C002P025R001A012 -S010C003P007R002A043 -S010C003P008R002A003 -S010C003P016R001A055 -S010C003P017R002A055 -S011C001P002R001A008 -S011C001P018R002A050 -S011C002P008R002A059 -S011C002P016R002A055 -S011C002P017R001A020 -S011C002P017R001A021 -S011C002P018R002A055 -S011C002P027R001A009 -S011C002P027R001A010 -S011C002P027R001A037 -S011C003P001R001A055 -S011C003P002R001A055 -S011C003P008R002A012 -S011C003P015R001A055 -S011C003P016R001A055 -S011C003P019R001A055 -S011C003P025R001A055 -S011C003P028R002A055 -S012C001P019R001A060 -S012C001P019R002A060 -S012C002P015R001A055 -S012C002P017R002A012 -S012C002P025R001A060 -S012C003P008R001A057 -S012C003P015R001A055 -S012C003P015R002A055 -S012C003P016R001A055 -S012C003P017R002A055 -S012C003P018R001A055 -S012C003P018R001A057 -S012C003P019R002A011 -S012C003P019R002A012 -S012C003P025R001A055 -S012C003P027R001A055 -S012C003P027R002A009 -S012C003P028R001A035 -S012C003P028R002A055 -S013C001P015R001A054 -S013C001P017R002A054 -S013C001P018R001A016 -S013C001P028R001A040 -S013C002P015R001A054 -S013C002P017R002A054 -S013C002P028R001A040 -S013C003P008R002A059 -S013C003P015R001A054 -S013C003P017R002A054 -S013C003P025R002A022 -S013C003P027R001A055 -S013C003P028R001A040 -S014C001P027R002A040 -S014C002P015R001A003 -S014C002P019R001A029 -S014C002P025R002A059 -S014C002P027R002A040 -S014C002P039R001A050 -S014C003P007R002A059 -S014C003P015R002A055 -S014C003P019R002A055 -S014C003P025R001A048 -S014C003P027R002A040 -S015C001P008R002A040 -S015C001P016R001A055 -S015C001P017R001A055 -S015C001P017R002A055 -S015C002P007R001A059 -S015C002P008R001A003 -S015C002P008R001A004 -S015C002P008R002A040 -S015C002P015R001A002 -S015C002P016R001A001 -S015C002P016R002A055 -S015C003P008R002A007 -S015C003P008R002A011 -S015C003P008R002A012 -S015C003P008R002A028 -S015C003P008R002A040 -S015C003P025R002A012 -S015C003P025R002A017 -S015C003P025R002A020 -S015C003P025R002A021 -S015C003P025R002A030 -S015C003P025R002A033 -S015C003P025R002A034 -S015C003P025R002A036 -S015C003P025R002A037 -S015C003P025R002A044 -S016C001P019R002A040 -S016C001P025R001A011 -S016C001P025R001A012 -S016C001P025R001A060 -S016C001P040R001A055 -S016C001P040R002A055 -S016C002P008R001A011 -S016C002P019R002A040 -S016C002P025R002A012 -S016C003P008R001A011 -S016C003P008R002A002 -S016C003P008R002A003 -S016C003P008R002A004 -S016C003P008R002A006 -S016C003P008R002A009 -S016C003P019R002A040 -S016C003P039R002A016 -S017C001P016R002A031 -S017C002P007R001A013 -S017C002P008R001A009 -S017C002P015R001A042 -S017C002P016R002A031 -S017C002P016R002A055 -S017C003P007R002A013 -S017C003P008R001A059 -S017C003P016R002A031 -S017C003P017R001A055 -S017C003P020R001A059 diff --git a/main.py b/main.py index ee1a0a2..4575f4a 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,7 @@ import argparse import sys import torchlight -from torchlight import import_class +from torchlight.io import import_class if __name__ == '__main__': @@ -10,7 +10,8 @@ processors = dict() processors['recognition'] = import_class('processor.recognition.REC_Processor') - processors['demo'] = import_class('processor.demo.Demo') + #processors['recognition'] = import_class('processor.processor.Processor.io.__init__') + #processors['demo'] = import_class('processor.demo.Demo') subparsers = parser.add_subparsers(dest='processor') for k, p in processors.items(): @@ -20,5 +21,6 @@ # start Processor = processors[arg.processor] + print(sys.argv[:]) p = Processor(sys.argv[2:]) p.start() diff --git a/processor/io.py b/processor/io.py index fb9e4f8..c753ca1 100644 --- a/processor/io.py +++ b/processor/io.py @@ -1,116 +1,203 @@ -import sys -import os +#!/usr/bin/env python import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict import yaml import numpy as np - +# torch import torch import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable -import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class - +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py class IO(): - - def __init__(self, argv=None): - - self.load_arg(argv) - self.init_environment() - self.load_model() - self.load_weights() - self.gpu() - - def load_arg(self, argv=None): - parser = self.get_parser() - - # load arg form config file - p = parser.parse_args(argv) - if p.config is not None: - # load config file - with open(p.config, 'r') as f: - default_arg = yaml.load(f) - - # update parser from config file - key = vars(p).keys() - for k in default_arg.keys(): - if k not in key: - print('Unknown Arguments: {}'.format(k)) - assert k in key - - parser.set_defaults(**default_arg) - - self.arg = parser.parse_args(argv) - - def init_environment(self): - self.save_dir = os.path.join(self.arg.work_dir, - self.arg.max_hop_dir, - self.arg.lamda_act_dir) - self.io = torchlight.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) - self.io.save_arg(self.arg) - - # gpu - if self.arg.use_gpu: - gpus = torchlight.visible_gpu(self.arg.device) - torchlight.occupy_gpu(gpus) - self.gpus = gpus - self.dev = "cuda:0" - else: - self.dev = "cpu" - - def load_model(self): - self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) - self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) - - def load_weights(self): - if self.arg.weights1: - self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) - self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) - - def gpu(self): - # move modules to gpu - self.model1 = self.model1.to(self.dev) - self.model2 = self.model2.to(self.dev) - for name, value in vars(self).items(): - cls_name = str(value.__class__) - if cls_name.find('torch.nn.modules') != -1: - setattr(self, name, value.to(self.dev)) - - # model parallel - if self.arg.use_gpu and len(self.gpus) > 1: - self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) - self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) - - def start(self): - self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) - - @staticmethod - def get_parser(add_help=False): - - #region arguments yapf: disable - # parameter priority: command line > config > default - parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') - - parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') - parser.add_argument('-c', '--config', default=None, help='path to the configuration file') - - # processor - parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') - parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') - - # visulize and debug - parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') - parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') - - # model - parser.add_argument('--model1', default=None, help='the model will be used') - parser.add_argument('--model2', default=None, help='the model will be used') - parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--weights', default=None, help='the weights for network initialization') - parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') - #endregion yapf: enable - - return parser + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/processor/processor.py b/processor/processor.py index 03fc2cf..690d846 100644 --- a/processor/processor.py +++ b/processor/processor.py @@ -8,9 +8,9 @@ import torch.optim as optim import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class from .io import IO diff --git a/processor/recognition.py b/processor/recognition.py index d3905af..06fb642 100644 --- a/processor/recognition.py +++ b/processor/recognition.py @@ -13,9 +13,9 @@ import torch.optim as optim import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class from .processor import Processor diff --git a/torchlight/__init__.py b/torchlight/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/torchlight/__init__.py @@ -0,0 +1 @@ + diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3c790e409f2747726cc89be56d654bae0ca6888 GIT binary patch literal 122 zcmXr!<>k8YcX9$F0|UcjAcg~wfCCU0vjB+{hF}IwM!%H|MId1W@k?DlBR@A)zr46Y y-!WL%-PupSB)=#*BPTOGqeMSGJ~J<~BtBlRpz;=nO>TZlX-=vg$be!XW&i-0_#0{f literal 0 HcmV?d00001 diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..967d1892bfc511e3070341c8edcd71a9781033d2 GIT binary patch literal 7038 zcmaJ`O^_SMb?%-&3Gzxk_McRiz|3<>Zo6u0FVOihao;zWD50s`{D}bM#SFF8h0L@Uy#)jjjIg z>3;qCz3;u(`?Zyo;5Wbf=?{K+Nr?XyCw?W=-$Y9P51A55_C%PMGL+o5LW@f~w7GOb zhf6nfx%5I0rPV9+{m_^Bx5co?V?kKrvK*GVT;OsxY;}3(8l;XCl0)0#!o!qAIHu zl$TURt)hGhv)9x*X0J!0dwFVyFDv;#)Ha^uykfg1P3c{&A|0uFomPgX|9(?MCyor>4j4TqwF0kqqWhwia1hX9Ff&eEc7R;)uq6#qL zsgf#Vq@Y$*1*NZ6)f&p8T2~t=1NEvpr#3OOq+U?x(Nk6z)J2pl>P2-4WktQDE~8vk zFRLpk*OZSVT{YI*?>=lfpoV@_T1a=1(!WMFmSfQnV+*+)+sLiV?%GpB~T#m!lP?eLlXu5U(=kP-5|VAN(2v9u9n6-V+lW=dm+6we;n&y+`K~ zT}#QXJ+`JejTD*Fb*VmaGLQ8|jGb;_D((pypEyUmWBZ0VCmhl*A2!6TJiDjjySqD_AskD@v&NI>{CDQ>|$pZZS07dpX`Xo?_u_3 zUUL&`5_KNtU?raMiyd)X8Ov^Y>`;40#F--`zmn$?JTU;ONRY09EHFCEllieM%$6^s-1-KYTiQRMk;Gc@izoNPYFu z#;J4XeV6Vw<7%8_)pjRd@`T4~n!aVyd%C!2sFt^a6Lah| z!eR%}acTn9X}4iqK#%c;1BG>JW!6p&LkwV_aj7wAX8U0!H#m)1rzy-UvJ;kOV{sGk zDi4i}aZyAU3=H-#Qk}UVQ)b#{OZ*LXVi{)AS?I_@OQ}yjbLi&6IkYINmSRi`}n+{tp za$C;OYHyF+r<3bXC&oJ5t+~cYa07Y^OR)(s-k{SufQ@j}us@iio2R36*u#yG`N+A_ zKFfmKP7MxvA7L$qazT(rsLvJ*Ba34Y9+NY~$GC?|#`udv2*$1aAg@ns;D+ovgjJ8$ zCUhhRnq`ljTjKH6oj-Xr{a09$&bI%?;5c+3`_b}*`YSw`3&!7}uHU4L^h&=<8JWC1 zWBXEI>i6IH0lHHUnUEK8W20bxu>Cg{=b~Y{Dn8dcsNix46(wB`kXd#uV3?EH3moRN zF?P9`+x38)d};Wmpd5T_=jg>(=gzP06dDmP8I7)?!e~^@(MSn=58+V2*(oTXy%%_7 z=>n!NKndXEbdyf3ub@}If(*8|N_A$*RjQrli+&A#`$%&zdUehZwr&3(6pdbaxF1z{ zxVqm=tGn=Fs>EGYMGGif^G{sUb=s4FL({=DVQuKpY{r?Ix72LokVwvLFKM^tkH1de zou=9Ze49)|gyEE|s=#&nq@J>zo8aY~Q(nXhrt-}`L1lik;4C|L>{D1|YjKOX@u^I2 zphfl$by#5^wOp=mknhngDjU|q=>}P;Tc#W0S#a9nG}_7~Pxe#ikJJ|yssi)nmjQg4B;?>S#<(K7PH|qm2}ifw93OA(n)MYVS>I-+3Aho zq)=4C%RR&J%r5!4o-fB1dQODUIa$xiZ6F4Xr7K`}NRts8*y0vUD%c~pu`>|>R9OUG zf0mfme$+Z(YCged^N5AK{^iJmb9D%bRXdZK(}h zW2Taub{>fl5TSR;U-!M>HQ(`F-}lRCJHAaVO*0pgNeW3Fcj?nQvK%~BNQvIVU8yEOLr*iQg(n0Lp`dK70qo1t)^odvW*MdkJliw z!~`Wj03~S9GQ)ef4B)nc=Hi-${3wga0 znc)?OqH)wMybZw2EAfwL11%~Dlc3n-Z%?epuj6gN0(H)|k#&xlHJL}~19`+GrL*^+ zhAHethz0rShb&GQkQqy2FflbYWZ5{>ah8Fr^9A?;sd3}y_;kj{09@RF_o+wdTkw%i zMA*)JB>bG8>LZ)p8bCeHM(B6EmeJ!P92bg>rZ@%OPuzn*uI)@w9=jg@A~NkJjM$(&|PK9B9$Hf+5jiX+O}l(Fgh6g9Il<9jQj zb(hi#nEixi@;%Uv{x=f4Xb5>+zT?^zJ)kM?P(})-e@Gcg7ydoa)G=MQ{#$C6T1l@5 z+W>J)^O*Sox_ literal 0 HcmV?d00001 diff --git a/torchlight/build/lib/torchlight/__init__.py b/torchlight/build/lib/torchlight/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/torchlight/build/lib/torchlight/__init__.py @@ -0,0 +1 @@ + diff --git a/torchlight/build/lib/torchlight/gpu.py b/torchlight/build/lib/torchlight/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/torchlight/build/lib/torchlight/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/build/lib/torchlight/io.py b/torchlight/build/lib/torchlight/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/torchlight/build/lib/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/dist/torchlight-1.0-py3.6.egg b/torchlight/dist/torchlight-1.0-py3.6.egg new file mode 100644 index 0000000000000000000000000000000000000000..d7e16926c97cdb66afe323750bcd72b7fa41499c GIT binary patch literal 8580 zcma)>bx>T(`n3lQ?jAI_YjA=EcV~d$?(Py?g1fuh;O-U%CxPG)+$FdK_{i_py(cI4 zo~o~>x@LCOJge8@#h2B@ULZDa<==u^4}BND+p1VaS7AA67i%c#^&k;-$UZJZk-bLFZ{*ET++Y zQ_}7>t8eq=Stpj%4N<^8H?_NgaXWW;>1WTYY`%v5OolCF#lN`@JgyEjd~rMSGJoC6 ze_X35s40tzsW3UaJG+eZA_xhgeOzyMt0fkZM2Nc}#Ivd!mq+?>?MP9*_6#b`z~h0s zk`d&oOArSP6atu=>Nt%*4V_>$V1QC!q&>^m$^+yz~5U5=PZSxCPF=#dAWo|B=GK-69j_cq9h#uxZ3QaCw z>uP&2_6$rHwn9na5E#}=#ZIl3)ogIv72>4;OnoeuMkE2QiT(-3y)M)m&|c6m`DKpn zZIb>`51_{RjIrEfqz}pdGNk-h@*5ai)tP#keZ?frxia=9QkD6H0v{04tsylLJ?kMB zO6rUyC2UeTr;lu&8nl@;Mt4&NZM!s?%|59$xC|)~n4DJ9UW217jcTc`OcIRGT6;YG zl~fIM$Oj7yxjF8)8yruD`}`HAXn0YHTmr%-+G*lHa5{mJpFlfmy9wZ-&l_mVpnnF;z@q> zF^5>2mN%A9-VSG~GF8|zO{JBl9{(4?Vk_z4m=jOdb+$rN% zuuNYntjQ**1LDqJlsS^Zl-de=Exd9s`0g#H+BMP_rir9QZMArQmAiMxdi~~&q-XQT zb>-ODt`F6VtW0|!kj$Mw51KsIs#WE7%-F^CN7o38nLc1`wmu2dfAf-V?ZFpoPT1ItK<=4-Pk!FO@+ z5J)_N^Mr)=lqFGGy;e?TWUn;z?ZdAuAD7OVM{kH-h%)3k%*4v0reGR%$w`dxnQNZ? zNqP~ec9cKliXihCU)Bar*Q|$4%{Kgm4Y*)`7=t`Vj+4HnfC&}>Bbg{%vsbx#QGWj7 zM)H+P#fzn(x?Nz4z-degp3MP?WkM8@GX$70?)qLtUA8=On}bkvs8r$dt=1;<(x6V& zAPt-^Y?jxRteU%rTtMO0MrzsAEFZxb7Yo}XGBQ_47K970eaix&Y!nJqjlK3)iZIT# zcQ%`GT27O+;0(^lCTI$Z(A3PDYFF8K>1pzYaT|_3F9q4TBZDK@4UI5<&e=)>VDXaN z2dQo6^z4xZbc44_0B?_wiSdYSO45!X*u!DjXf~K#qm$ zmFH}Vf`Mxqa4_vW@5Hui1m2U%r6KFN+dt;BoYHP)5xHttN}It!a< zSKwVl7$WiR5T$#ao!{S{TW(+RCW!?n`}q=I{d{8G^j7ZCA=ZkHqn>$s;2H_qw1cdb zf5z=K-Tq?o%Ev;}l{kmDahb3Q?Hw$_x9lx;qAv+c9PA9m+~zBZT1F_?t8F)~v3eC_ z!m&S7G#E0~_x7F{CZW7aL~ihk19O!`meyfcI#EuI!6gF6RDlBS`}i?Pl7&J!t&C%N z+o7nDXjzE*FvZ;-2tv3$-__LHp;4nRy71l-Z}}%4Q)}M%^7O8Ay|1>N{9L%J*e3if z$k)G*T;O1n1!~I3HMeHHf1F}#xxt6BFQFd572GS3SrGFXDFO}2-x)@I7E=x6F4QLE zkLd+jp82yb)RYZ=nH~imr{Io#IJ}|q4&QPbW@sb^$ViJ)17l+On6+8byFR4E@^C4k zRC)RYY@o{;uhqI?S=|O6XA`og9rhRa*0$$4m1iabr#JiRW?0-ZLXFN9DNs|^NS^8` z-B^DNi^Zi0<5nj+N$7$A8>V6PtI!}pp&0a`LEI+HE3psoD)zpdhdLcf2g)Hyna^mLdCu3ZZAPH_h&_XRaa= zm>S~H7yCZj>3LR3~ z_*tD+7qc!&NH40P@EOj|G^U3{cCyU5q_G#0PXTpFrEOW7Z}Db~4a__BIW&dWpQ1KwxA^`PMA+^MSg(~-O-V#82|hI4leyU=6jFCmKB zdy0x1kL^xQY2Dx>>0L4TtBAU3oLHm!;UCy1;2u;M2yWa zbwi$6#GeMY^_h8v6}A%6ytB&2Nrl%ToOh(I$LI&=B{D`k`z*Kgd!h!?vrM7;?Rl;F zFtTvb6P}BEE+^c|LU2%>7v}^(DAC>2PAhlM=iO-igH_bR(J5NMQ~?3E4!Zz&Z&QV< ziQ zmB}dxwutCpK(xBBL`vla_aa?u@GN0^*ODUMmR zEU}p)!ZL>9jFLN$R`>x{!?Dtujj5cA*T8Ix>lI*Kr3=*}$%vM0tIU^XckziG$nbxa zzti&;u#q>abq(7*Pl>_!^jm52>llI77fodSQj&zcOs2n=CVG1I9!3U6<|cZ2f7O*n z_8!jWcD9V{oG*2Y5&O>ewk1=AoW8NXzWdlmXdy@xA)PTu3T3_Aw&mz}b<$+S<8Tb=nzyp7@plA2s6zUgh} z|Bl?{EaydB5=x=fB^(q;mSvE4V{k&c%84gRH83@zRDl_Wvb?TqN|u4lh{Px7!aEEt6H!OBnuTNph7A~$En)NjnrK&} z)5|N($Sq2N4Br;rR`j~FF>A7&ntqolfCXnlCvFMG&t5*r7*i$=A z0yJME-CzlY7J+o57vmSrE50{!5)@LOw=CT=i@>$}q_Z1+?^2ac`w7&(A zub<{{w)ideJbi`T%zqWPzx}B-(SsJ5Z(?$uI~iBSMVisQJvxj(cxps6P^~mn|4n)N zZMUNtv43rL&=IvJOb8}OY<6={Bx_Ja&_Atblp9`kG#0(1Ga@<@VSbni$C97)%u&22 z2ZP)19a;O<;UadHne6u!#q)xP;IsV)4;Kj=!W=mDQhuc$se@uI}*AaOMm92$ij#nybsELFBxR!*cL6wk@MivfonWuV% zE83kUzDjm}FV#25*%HK$PBZ#enHOxBr!^u(`p*{MVaG{tq5ghS1XvfRd_VyJB3=Uk zH2=J1Urq#n>{#Gv!t2i-r`fMq@bqg@t8X#HgcJAP=&(vI#UIt+E!G*2&@ENGU)-QG zVnxFd=z&BW)AxdX0Qfs5k!mxvQ!H$6#j8}mD|D201ow8`>e|* z+RE8H6_b{Ia(egB1-ctg!gllJe%R{i`;^`s633%SV(b1o z^Un3TFVn|CE{ZdFTafx_GSa#G)9N76<#W%f5lKyR8co;+wJg?dh?^95{w4ec6QV}q zv>rNkZNO7`X@uJ4q|sHn=H~;KoGZ|{+$l`&n%PxU$-tv}7WKAU(OD7m>OrHCY*r}a zg4Hp~&%&ug*u2YqU_181ieh#{ikO*c{sUgrq@)ho?594o9cse9x*%jo204ZxHB8F& z4LRXNF&KpiLowB9K0~$2JmzqhTK`pS=5RzqO#0Cc9tU&bd70^!HW%o zd$k_{Ez!Pm-1__Ch8-6tu1nJ3g4Lc=xi>$S%!|E8(yc9Xw^cB!&@+6TCRpk`XGC($ z5z847w2t9IqN1vTW@W;iutpw>_rL?;TA{v9Xs!vfaYrx?>?j#Z+V{S(nhth*qWVS) zss)Z51ebW-jHKi^eFY2O`UO&jKLvC=;n!3tFp7|G}}@ot=eOav@8%ev-XzYI%lkUm99O-O*@S8ErbaQ& z&YbC>MXg@?q7zf4kLZa!qK}(aArY#p+x(9z3y2C!_E+smB`5+K)Jxn8L(7WvIY%{t zH#PS!|Ht(Dk`DO|>#fVH=f-~da+Ot(EMxj<4QZ2R2ZjdUP}c)J({ncaeSd_SqDDlP zyN%WXXj0h656Va;2Alch$5q)UC#H#!gUmM*w1-ZUn5IN;du1+@-Yz9tD3^ybiDG;9^}WV6`4IUzWP{?L(c!8&5gugjlX2 zLXo&x8|D_Z0KG3Mn zA;RwQ=pdsi7bnP}>MPJ!w zhJJYK@RBkNi)^$F;^ZWqJO~o2<%Te?hu#U*5K&K{J|>*X;6skAzJ3*C@43TN#Y#?+ zar{87v=O0O5-0@a8QqPK590K2=IbrRfc=*0+}Y*y`53ci zl1YqB^JZIWbvWeVjDhfiORj^)8+2v{R^n(jXL?CN0@Pe|1WGNstok=otX2(zw=98< z;<4hTIi28zny&+rwLg*#1J7_}S%S6G<1|dE>;rEY`^qJ(*bTt}W@##i#C$Ea7r4L# zkjfdr&I=SJ&yU!2%FLapr@`E4;2laZ!b)+M-ls`U$T54jP13=)t+XIKdSyU|s@E-46ejrm|+dsNJE6WBVO zGJVA8KP{8SuX@Cb8bQTudwTcws;y}*!m@qX% zAV~!e%PLRoioY!os1$=9OoSabTz#r|GR+K+Io)^A?$D|?{R6QgwujuGBUlc2Z1#=) zaSqu|^fn8}yJZ1C zNJ3}D*98d`fY+1}XY%df7n8n4piC|O){P>`71XESHszf?|TL=nGCj?X{(jQZ@ z$R5BzcrCo2gu-_i1y~@v5zxT5#FFp`MiOP-Lad-Iu~b?|x@D|yW|MEImqGUl8T~_Y z(=8rsop-$jivz)Qu~wWLo80S+SZp{D#+?+S=|!5Zs`0wY6S*%?&Nmn%GMD$(}wpFU|16)Sxt!bOv(!VaTXIB@D=`LecjIr}yqSI98>Qkv8P zg46+$yj)LNNZ`sK`vwh0b&B87Op|-NdpfnRUGO&M`EE||*}+lI>~shtfInd>mW)RA z)+3rs)*y@y4Qo^eFCHX5p_=T=mnZ@PPQRi}>-5R)W^Q^2+I9d>`*yjfJa`P1u$5Mn zeJox!)o5rg;AC0otd#Br(-Wwjzc~V7d|ka*dazu;Y%0|ZOB~G5*~(;`O=&QzbS^5F ziH%1WDyD)TI?&1S9+D)^yW0+MNl_^0wv!5(#FLota@Eh#`^IEErbbmA!-*?VA<8z9 zv5O`=*`@CRs4^wc96nZ+r7!lmBu-ZVhSFa+XMZIA0)f9DDM-fu7MvB+EgkZ(H#DX= z$uL6peQ3CEAMs$sZ0SWy*aT}N=t(t{t(=EJNU8E1?-j<{ z8rOcM=Nt#V=hP)5gyy!xy6O9VV%+&2$oGhbf5NGIJO z8S#0%RHaErH-94sNG$cCz*5oD%7R^6OO>ms8Ek#)>XTKCNMp*C5x!)7KnT4lUGKqS zi+uUIRmyDICBQp7y?w?zhCmu`8@1FU74unP#yc@Bzo~b^LR1GXDe~-%3_4no+XB!H z;bb&=dp>$om8F?5!PIh_Hk85mX<8rOJ>uY9Ji3y*UUORflN=7S9g4A-=9w{`x%)!jao@{;`%Ed2)~W37GkG`BL%#n} zqXObh8NQ;ZunMd!WfNJH1jDF>1(!#zbH;qwW0x!qMhx1zkKH5k0ugE)lyy#|vkMGm zPQr1S=5^GE*bBRXj*0;#|KuO+#@_1E9w<5{pD==##nwu>jVqdoCl7?2MqfECUO$b1 zFMI8PC~4>Rcke$j8eFfaGACY5xFfwPvdna>-#o<22zW$t$`!-zUCq@&Olx&G8$q#$ zoB?XX+m)sRb-z9d;j;oQJfFYHMtmS?%x$~((C#D$Awt2iO3ThW8cqk+TUT(lo0ye^(R4l~Q4 z!LC5HBA$m#m&i|%mC9@0g?YNf>ycnYI}DG;45g|Sgwo5kD=Q)q&|+1FQs)U(baTisXZ?vph8z^H}? z8r0)D_qA=emi6TN#WblyRp8%|g9?MgUYUJ(U`u~CzQf)kfQ<5QLQr|z;^}n|!PBm< zGN|G^sG~?0D%}uo7Lng(zLR}Lo9x9&xgIW)vS8N`zB9;}m8JF0g7KbC#%^EQkx||A zA+sw|B_*GWpV0elK1PAY><*U_|ZYN={eLCG>v@Txhp zVXx8$=fs-#7a)92DQPNS{8Vc3sb;LNA*{T!sXx~6khhzK3wN2Br8n6qxoBXdC7Wh5 zLK)=rg&rP8pM^RCXBv)*vMl#PT^=4v94m6DmX5E(kvqB`+*T&AHQ}BW04;QpwdPCI&4x5j_Cv7^-fT`?or=hSYv%d$ z%Mld<5)0~|W964~>i>-f|6gDKiEIBIGXIJHneYA!1^^oTn|{ZCOMCxR_&wbGmB#*i zH26!uWS9Rj@V~OzKc)VRX8)23`(5g9LG7O!zuWnh&i;Ee_=A7f_;*(OJO00-+Q0CJ z7yP#b_P5~nPmMow+rKo}UK(J(Y5Xh2{S*H)eEbWqeCeJ2hW}p_`8)l;`h$Pb5=eie z|3A8ge`^2P!~0A7`Jc7_yQBA~?4Ny>zhvuB|IzY)bzKx?U|@fdh%X=fmj|Ur{Ojuf E0UVex3;+NC literal 0 HcmV?d00001 diff --git a/torchlight/gpu.py b/torchlight/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/torchlight/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/io.py b/torchlight/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/torchlight.egg-info/PKG-INFO b/torchlight/torchlight.egg-info/PKG-INFO new file mode 100644 index 0000000..53cafc2 --- /dev/null +++ b/torchlight/torchlight.egg-info/PKG-INFO @@ -0,0 +1,10 @@ +Metadata-Version: 1.0 +Name: torchlight +Version: 1.0 +Summary: A mini framework for pytorch +Home-page: UNKNOWN +Author: UNKNOWN +Author-email: UNKNOWN +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN diff --git a/torchlight/torchlight.egg-info/SOURCES.txt b/torchlight/torchlight.egg-info/SOURCES.txt new file mode 100644 index 0000000..1ee6009 --- /dev/null +++ b/torchlight/torchlight.egg-info/SOURCES.txt @@ -0,0 +1,8 @@ +setup.py +torchlight/__init__.py +torchlight/gpu.py +torchlight/io.py +torchlight.egg-info/PKG-INFO +torchlight.egg-info/SOURCES.txt +torchlight.egg-info/dependency_links.txt +torchlight.egg-info/top_level.txt \ No newline at end of file diff --git a/torchlight/torchlight.egg-info/dependency_links.txt b/torchlight/torchlight.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/torchlight/torchlight.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/torchlight/torchlight.egg-info/top_level.txt b/torchlight/torchlight.egg-info/top_level.txt new file mode 100644 index 0000000..c600430 --- /dev/null +++ b/torchlight/torchlight.egg-info/top_level.txt @@ -0,0 +1 @@ +torchlight From ebbf43b8ffb5af2ade97a0e442e5f8be0bfa8058 Mon Sep 17 00:00:00 2001 From: wang shuxi Date: Fri, 7 May 2021 22:35:03 +0800 Subject: [PATCH 2/9] settled some bugs --- README.md | 6 +- data/readme.md | 1 - main.py | 2 - processor/gpu.py | 35 ++ processor/io.py | 303 +++++++----------- torchlight/__init__.py | 9 +- .../__pycache__/__init__.cpython-36.pyc | Bin 122 -> 392 bytes torchlight/__pycache__/io.cpython-36.pyc | Bin 7038 -> 7050 bytes torchlight/dist/torchlight-1.0-py3.6.egg | Bin 8580 -> 8583 bytes 9 files changed, 154 insertions(+), 202 deletions(-) delete mode 100644 data/readme.md create mode 100644 processor/gpu.py diff --git a/README.md b/README.md index 66256bf..601db34 100644 --- a/README.md +++ b/README.md @@ -43,9 +43,9 @@ Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml For Cross-View, ``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml -Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml +PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xview/train_aim.yaml +TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xview/train.yaml +Test: python main.py recognition -c config/as_gcn/ntu-xview/test.yaml ``` # Acknowledgement diff --git a/data/readme.md b/data/readme.md deleted file mode 100644 index 777d39a..0000000 --- a/data/readme.md +++ /dev/null @@ -1 +0,0 @@ -The filepath of data (NTU-RGB+D) diff --git a/main.py b/main.py index 4575f4a..bd402cb 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,6 @@ processors = dict() processors['recognition'] = import_class('processor.recognition.REC_Processor') - #processors['recognition'] = import_class('processor.processor.Processor.io.__init__') #processors['demo'] = import_class('processor.demo.Demo') subparsers = parser.add_subparsers(dest='processor') @@ -21,6 +20,5 @@ # start Processor = processors[arg.processor] - print(sys.argv[:]) p = Processor(sys.argv[2:]) p.start() diff --git a/processor/gpu.py b/processor/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/processor/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/processor/io.py b/processor/io.py index c753ca1..b8d11ca 100644 --- a/processor/io.py +++ b/processor/io.py @@ -1,203 +1,116 @@ -#!/usr/bin/env python -import argparse -import os import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict +import os +import argparse import yaml import numpy as np -# torch + import torch import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - - def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - self.session_file = '{}/config.yaml'.format(self.work_dir) - - # save arg - arg_dict = vars(arg) - if not os.path.exists(self.work_dir): - os.makedirs(self.work_dir) - with open(self.session_file, 'w') as f: - f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) - yaml.dump(arg_dict, f, default_flow_style=False, indent=4) - - def print_log(self, str, print_time=True): - if print_time: - # localtime = time.asctime(time.localtime(time.time())) - str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str - - if self.print_to_screen: - print(str) - if self.save_log: - with open('{}/log.txt'.format(self.work_dir), 'a') as f: - print(str, file=f) - - def init_timer(self, *name): - self.record_time() - self.split_timer = {k: 0.0000001 for k in name} - - def check_time(self, name): - self.split_timer[name] += self.split_time() - - def record_time(self): - self.cur_time = time.time() - return self.cur_time - - def split_time(self): - split_time = time.time() - self.cur_time - self.record_time() - return split_time - - def print_timer(self): - proportion = { - k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) - for k, v in self.split_timer.items() - } - self.print_log('Time consumption:') - for k in proportion: - self.print_log( - '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) - ) - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def str2dict(v): - return eval('dict({})'.format(v)) #pylint: disable=W0123 - - -def _import_class_0(name): - components = name.split('.') - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def import_class(import_str): - mod_str, _sep, class_str = import_str.rpartition('.') - __import__(mod_str) - try: - return getattr(sys.modules[mod_str], class_str) - except AttributeError: - raise ImportError('Class %s cannot be found (%s)' % - (class_str, - traceback.format_exception(*sys.exc_info()))) - - -class DictAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(DictAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 - output_dict = getattr(namespace, self.dest) - for k in input_dict: - output_dict[k] = input_dict[k] - setattr(namespace, self.dest, output_dict) + + def __init__(self, argv=None): + + self.load_arg(argv) + self.init_environment() + self.load_model() + self.load_weights() + self.gpu() + + def load_arg(self, argv=None): + parser = self.get_parser() + + # load arg form config file + p = parser.parse_args(argv) + if p.config is not None: + # load config file + with open(p.config, 'r') as f: + default_arg = yaml.load(f) + + # update parser from config file + key = vars(p).keys() + for k in default_arg.keys(): + if k not in key: + print('Unknown Arguments: {}'.format(k)) + assert k in key + + parser.set_defaults(**default_arg) + + self.arg = parser.parse_args(argv) + + def init_environment(self): + self.save_dir = os.path.join(self.arg.work_dir, + self.arg.max_hop_dir, + self.arg.lamda_act_dir) + self.io = torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) + self.io.save_arg(self.arg) + + # gpu + if self.arg.use_gpu: + gpus = torchlight.gpu.visible_gpu(self.arg.device) + #torchlight.occupy_gpu(gpus) + self.gpus = gpus + self.dev = "cuda:0" + else: + self.dev = "cpu" + + def load_model(self): + self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) + self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) + + def load_weights(self): + if self.arg.weights1: + self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) + self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) + + def gpu(self): + # move modules to gpu + self.model1 = self.model1.to(self.dev) + self.model2 = self.model2.to(self.dev) + for name, value in vars(self).items(): + cls_name = str(value.__class__) + if cls_name.find('torch.nn.modules') != -1: + setattr(self, name, value.to(self.dev)) + + # model parallel + if self.arg.use_gpu and len(self.gpus) > 1: + self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) + self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) + + def start(self): + self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) + + @staticmethod + def get_parser(add_help=False): + + #region arguments yapf: disable + # parameter priority: command line > config > default + parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') + + parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') + parser.add_argument('-c', '--config', default=None, help='path to the configuration file') + + # processor + parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') + parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') + + # visulize and debug + parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') + parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') + + # model + parser.add_argument('--model1', default=None, help='the model will be used') + parser.add_argument('--model2', default=None, help='the model will be used') + parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--weights', default=None, help='the weights for network initialization') + parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') + #endregion yapf: enable + + return parser diff --git a/torchlight/__init__.py b/torchlight/__init__.py index 8b13789..07e70f1 100644 --- a/torchlight/__init__.py +++ b/torchlight/__init__.py @@ -1 +1,8 @@ - +from .io import IO +from .io import str2bool +from .io import str2dict +from .io import DictAction +from .io import import_class +from .gpu import visible_gpu +from .gpu import occupy_gpu +from .gpu import ngpu diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc index b3c790e409f2747726cc89be56d654bae0ca6888..d75de414854592b86f7c3e4bf36812c1d7880aaf 100644 GIT binary patch literal 392 zcmYk%Jx;?g7=U3r|4Cc!ks(;n6(Oo3A(nRN=Eby5)L3%j$aa7pgcC4v7fz6si7PPi z*#bl&zj;2(Qsm`gQM|5S9^Mf`AF%$O0CWS()Q}q0SfUP7jGY5;Km+d(9MaG`0!K9R zj=?dFy%U*qDNVgoa7Hul44l&(obOOsd}8=jiKpOx3(btxmj|WfuWP|AxVVMyid&(2 znC7Azl(vk^y)iJ)9)%GH*|6gKgK7o{8-Svf5R5>rFN6EpMv}VwL8t6vx&Q#zZeHW9P$Z%C0(9 QkFvRTcQj{^IsX5C13uPUZ~y=R delta 93 zcmeBRt_m>b<>k8YcX9$F0|UcjAcg~wfCCU0vjB+{hF}IwM!%H|MId1W@k?DlBR@A) Wzr46Y-!WL%-Pvz)0HYd62Lk~4KMDZb<=Yr9qMI?vn0ZI=<)e zo%z0bzw^)aJ7?ycnS16wf81B7T!b9HmMR)LDF6V#0w`8}Nx-i|`xiY4{{n7^BKDWb zD3X6UgTjrTRmQ3b$3_MH4Rz9}UMT-b-_cUBu>QHEqZ@VvhlBAXq2Fsugi{iNg4TJ_ zhymEb05@LT&u=WTY{zU6Z~L{-i4}=DU~m<0B!b-eCP6BQ%}vk+pX$aY5{Dh)Hn;WXOVXdBg&-IkCKwa&Ulr(ls>21oM+E@Z zG0*@2AON1gKncfWB!(;EGeh;qSp+lHg;suEUKS&g9w}c15}haOa^HW-Nt+L4L3}dI z0WX4bckX^wA0SV62m|gH8Mny66kJ1BJf0+YtXZ~2AY{EQ_p ziJ-~G2wBi_C#N-OplgI*dYH3dJl$9p24`2d>?ls$m1q1)c&j&Z;_ zZA(`Y3QW#&sk`TZZ{Q{HzuINrPkH}g#!K=faF_0g{ulyXIygS2Jvuq0Rm{g@#oo`M}}C#yK+>oeCTj&N$mht~Z|8ThPs#B-*Y~_PZvn-38fnU2l2oGhM9t>ypIPOPSmoz- z!q(qqp_g0!CR$kuCCJ9ENpRtu@{wNKhJBJ^c760$^Xflac#&-fTVK{XuX+eO-Mjnf zspLYMi)BzMu6-8yQ-RUyr+a^%zBn{i2+5Z1F5cXQT9|>XU@>~qDK+tK4ea0aJ7oIu z^1Xw5u3y^K-8;9SVPhp<%VKG1!)#3aGGBu-FQOYqO5L1-E*1k+{V%wmJfU_)KEu04xuLrr(7e-2})9qWOdPZ*X^+oC(uORkf=Cyz^%yU4adc^Oiv zcLl;itsLFYi#5}|^_vS|j~S(rFA%c|`jrLH=M}TPq6XD>l~C*VzT8mkN5%?>IMZ>{ zzBsUhJB&NvOrx6CdSCo~znK$cjWx_n2&-6gbfIeQB>cR-jZ+jvvdn7luecFhYDN~$ zBehMG!t6|B;b6sBnZbrl4#6qlWcSe_8PK-wjEyeZ z>)Z_c*tx?@kMM-)Lb6-=w^}aRV~$EZ&d$!vm1Zhgn_6`I3Ysh$Q=;dYAvKwmaW>qM znQn8ig^8<>-e^BLdQ|Gc5~}3F@jUX1TEsx%jCOqc!7X2zJT#CL*_}7PCo^}eqor%7+!BxJNAXf9h0LT(` zvjR8qM>P^xr>dmyA`WxMuEkr98Jwd^j4xjo{MObJBZ5x8lLTwm+nqGI>sAy=y>pCb ztOn+7bS5NoWu~w7!ETm53zgWWURwpoowkbz{8R_3S~ zcyVab`U-#5XmtA)+mde}xojChZ-K;b*kQRW62PbgI_?udn|d@VbXTg?UjOqtL!(hE zv7y+npgU@=Cm6hwgc906To&WOH9D}L^>PGw;1SqE2s!5Q|EQc>sKGhD%WY)g-<#B! zUDb~(6@Ks3Cpqt>cktWQ$dAt{tJZL0AfhO$I&UxdU^$bX8YCHlL|ArwPIbdincnf3 z5qMk>uCJ*R0m^>a3|jS3d@FG$bzeapRn;oQyI9EUJY>J5dJUE7jvt{MPCX7mS*oEW zP$FC%P@|<@6XPagJ z*eQAuGnq>|Db%j0vJ&8;oyF&oCE9vZM64M_7c1@{S&cCMzuAkvfx>5I;xdX?)DCVz0JiO_EI;24rbv>7=vNMVfhLwq`1~?! zU9a#;#bVS8S+sLOw+*(3Sl=VGUR+gJL?PZX&P4n#H2YB5s1%i*R;&+0RJ3@KBY9cR zi=OY>L7?t}hy1)c&aH)W4EGo$YU?SsTMg;hOKSl!HICjcik#q6P7g;3i-;z#M+T>J zbsV}im#5hAGP|zJ2aJew>iuFtm{fkaJDzeNvw1(*_*7qS>ZIe$qo5&g?mJGb`Vy@6 zLCC!w2Pn(d9;h~s7YWD_K0~*`jOgUy5cHsFgu2r-TK_NyxU(0pXBO0)R+Wvbp4r*m zXUi~&$rXxGQcd$yVb8E8g_t-Ht|iolDur_rbJ^sfpVr=5a(@fP%w{Wog~(iz<)m6xCc zkt>DJ@@~up%dtIA>MvF6(v){L7t>TWP;qCJFA0@5SBx674~esM5=;XJ;Z2P98n}{6 zhj`uAd52{xooMUN+yXq65A~VH3R7UjOz6gNVR$eUm02+gKEe#J0am5wj8M1}(IgP1} z_>D2(O;%^*ek!TfsDIjyggnX%{vXG9xL@ue{3WNygikl-Pmj@%lV8Na;;Cbw z8QJu%D`@1>x{34|*iuq?UYW;cm2p0Jv4xq{6%gw1`CYw)nmbTMBdZ%q49nO@ve5zBp^w5&>g`T7HYXE$4O^ zHQH0rPi0_c`lO~#44ZC|yva&9EVw}id!}lZ3qzE{M=6y?{Jbo3&8N9+hV)s!#tV`w zRmky;WRerfjJ8{R1XS5lL5+SNsi|?41Rzi{%)b#B{@D}9kSIdqo1lcM_&6qt<@zV% zd0)*8X^nCYFs}Xj-MxpqzvH;Po}E3EJW8V!r(~lDu?=4~3LQ%ukR{d21>_9G zl20Zw2?bi~_`5%?2>m>~1a;~8zs%K8^RciRbX5bt=cjY@1JX2L@>xsya{bED&L^f} z)`h0}Me?>J|3K1zNa(*6G-ISe7lx2tb+0|zw8?|T9*8K9h0IRiPn2{h5I7wbD^&PB zK7zVM*>2o3*gqlTZpuHGN`+#Zu*S_Y-thO>ENhY)ut@)KK@+9B0w_sqE za~qjhOT!eR_*+>0TBwnP35(IK3@>_wpR3N2c?{WZ2ZksJ`$L`iy;nMQdlh&wyE&W8QPcj7=%9aSh#%r z-H+&kfCGJ_vbJ@OP@bgkr>;HI~@T`v&0 zqE3O|Vna$Zz$;^q(v-Yl2j8bjwP;iNhP>~$9}U}>s&UbsN)4d~8_Mu3E+^MXX*p(aKgYWG`yB9kYB5 z-R_kL!(;8GWoAfp6;+OjVk>LMI73|iax<<;UDd5TPzI({HFnH8(qp}Hy8W8Bms<~4 zrlBL-F}-8YfLh+X*rGy>4Qj;E6=?SL`vVW4p@6D#UOB6DlcaLhtvZqn$GjC4Yw({i zM(C%kOeAj^S0XO44fD}~zzvcD-M%#AT)PI^Rx61-Cppj6398S9(chDsnT=tzV@b)5X)pE<4%8*;fCO6qB-Yr&p>`kjN~rAJ9lGW> z8lf{9A|FyDj>fy2hUnYeaz zqQ8v;rx0^&`j*2uRN6+At6!0R>dIm>mS;*!DheKXE5d|Fq0K|pSAKXpur55pLY{IP zse)A!VtXM)HlmD-_zoB0plH+hS@*j4mJ50$GhK7|E;u=NGKHgTX;HGuoCd4rPXFt( ziQSG`eg(OH>k3w6CSO65R6mW{#8XI7@lQ8Z%p&I-iHwxRb&EBQg_ouv$fUKHG^%U# zHJ{FQ-A<`kiz3mQ_F6I$wk3dYUe-s}%?K0x_{RJy-9yKvY2z{wB*s_fHTBL)(p={w ze_Y?h>bQsUuYA-m(YQF}FTMd!E^v%P7oxebMxrjHdw$Dp$c2F;xS#Y(1!z2j40w0> zTzX0$nK{Im<(>P0Pf}W{D5#`p|LaZQHgw{D4M`!L)L(AXz4=QXdYQlcKrj9ub5c() z{hx46FZQ1x#31^A-sB^L9`?Ty<*&tn&oF2({JZ=ABj-O;=igNyR8RO9%Tv25JiM5`ZqKw;`-tIm)_xJ5)%GdGNkBFWI~xgkVm)(RO{F} z+S&$;zrLr4OQvWwj>~RH7e+5ee`_j26B%hytPT8-d zJw7@xV0kpVMUPTMIBT9=(WTuX%c?k?3TSX@)S#L=RewQfBEtBbcuKv_3lR^2gJo!I zE1csq+M85uC4|Q4BAJ)ht({nNNckVg==e&H+|xofGOgx4aKI3R4QZE)gwCXS)LZpQ zYqexWbK7ZhK$iIKCf-V_u#*zKOMF-QxVt59vXl8n9UF~ZV!piTgsEd?V9^FIRoX3F zsK<)x-FXr&D29VdWYa%08GeZLZ87+aX?a0)DrVBr+=KiFcReeavIV$$fvTXL+G&;J znu+%(U{qW*=NUWuc-FQOa;c-vuJK{ZB_OvH)3D1jx6v>E#^7jx!Q?aPtnABaR5uv| zLS|xeUNYm6k*_9yKxbl%Y~<9aO1McywC$6s?5kcV1Q^Dd!C- zD}-jfACn|99!sX|>P9B!63>qaJamy|IfkmeFQAe{$g_5C?Jv^iLA1ZF=%1EgBae5n zfxham7z+qZD`X8$-VCX-+7Cf7g(k9y1c?jDJg97YA8tw1#XbmpXINo%*=2{>gdMPz zgWR2=1+X{d(f@dbezZ^zva3~8G|PKPS{Yry0OVr;fa}kf@$rD7=E&s`!q1md59;X_8y!b^ zma3mEZtyq=Q#_P=k4rLc3nRb+LZRs_mVBM;3)@?%Ms@OUp_))kU(c0|=bV*yagzLA z>aU)|EF8ZM@wTa?to*~HYitkt`eQn+e~2XZdoNn#XRq6?4Sny?#-m%J1lut3H|4NH zuerzhML{#7CLV!^jJ;Qrronx@g%SiqMVnGq2+LE=D*x`rJzi_o3*Or*Gn10J-`QXqOsF~5H*}QJRPptY_NvCyGWI4lMS`)u4WX;Fg0Up_)VyOGQeKpLJR7O)@gb^K zlxEXrLuVTZh8vTUu=d0d9ZB0=HT$l!L%$_WOv&o|Bb_HFOE1gdquFjQh1*8dwUjx* z-V;KNL9@yQFG;HSi7kE*MkOTFM$Bo&dee;FmVL(z##%&&cvJYH=28yuy+p}!3@mR# zlEGdc-&JfK7K}@vVhrc@Pcy6oXQdCO{a@IC(npdI#nyCUGIE4VOi}&JVh(qbb(XQT z*Wrmiip_jQ8k#qB@M}wu#PQY``bQ!R&kJUDUfeGH9e;BZ&4j>S?Chgge|D}cd2ubl z!ma%Udsavum!Jln#bG$#Us^wLu1SFhDBX5AD{%LoxwBc$?Py7tophdOm0I1yUAIKW zyyHeVmk+vN+|y`&CRLbu_;`m}XS3W81iz|K;lyVQSAXT8p(7rm@b)Pp928Z&19#H> zNS1aXRiH%qaNoPnA$--!el5(pNV(2fRM@J`9hAQ6Ni*8M0BkOPURQ-G0e|M#DXtyd zFEs>7%V9r=$!pcg5Ly;jW(1?Spfq!2gFc89JyV(=3)tG_S)N_JtVNX(SK^9!f5*}> zC(xC!gS^C6OXUJb4|t8D%R@Ios(qcW^+r4aVrr$|Gf-yQD+%!s4pL}J^e^ABJV=pf za)3-YMQ)mvJ>2=?H00c1lfCFIP!p_js7n&;Z}FCi+|Pghq}Cay+C}t!XL>oAoEd_1 zNpfL$S)aGypg#Pv{^rjDHSAM1?q`~(zW48126U^8R?*u0*+=y(t>7*yGnwHYEU&;_ zzUa*W(nQ}3nHTWNVh=s36pX-g&`P2eJUqFmjPn+l7~RW#GQqv?9Wq2-cO(Q4R4ORZ z<(vM_Y;FsGpVpdJi}6+n`{k-Gk67?C-XHMp_wWsti5Gjil;75&f*(m2F}7|D zno^{IpB0Fu#hu$D=+=pzYw<0La5t;5F?UN5tFKns<3ApjSFC5^oFKnN9cj_yj;=nw zAK@9aBTy^M#+>s5%W1HIL|B)HE8+#E_6M-nmsuOEa#N<^oO0uRG~iu5tCGHQEY?CU zyuR#$-Uz(JvAs6v$KygyJge`cCwc_we%DwN)%y___qs(D@Uo`JH(<_(mc{hs+s0AV z`l-&qdpBwiH17k?YlecFd;_;7D_l_tyf7K!@vnQi@k|2lgui;`0SNXt!6efoxsSK> z!*<*23g?!y=aV39%#&1`FWcNXO$kXyat32d&L#F*E-3~1gn@}7KD^Hr7|07Li8(BI z@|vDZ3A;8cTnU9k)sodJ3c4{1^)rLd8%{Fp!;c?n3q@LHrWJlMeQ+R zkSwEppiFzi*&}ll8ewz{@PMHcbY)3ej|3&tY|I2(?BLN1qr&Xh>2Pj{bDx=kLGR_F z`ZHIh248xHoOXH#7b4{4wk_%Q!LJvvWvn?M zp=!5oX;hNAF*$K16RGAGrPd}QR^=>4RinpszvC0|J(~j&BI$uSyVZyy(N~9y0c#dC z^_t+`yyciGbkn?JBy;+JKXh6vOV;>6nj8rdbU(U&dhxFH3(}=?1mbcYTVY3#Rx6|B z@=k~J2Y;-v1~{%6znXSpmV{(dmrn%4_cPUhKgmW;J?Aj&SwRe%E7Ej5#8c0Dnl+$d z$M|7Gs`F1zghrEh9kpD#5gmav-i9VzG5VL8xp z0h(!FNIg`p@*cw57w99@6&I+}U*GR1-?-UJov+p`L`3t`bEpwEgS?CZBLFqi;|-1H(Gg?SUfQ*L@|6gONPRuHqY;lI#35A`fvKyV=LGOTm7|sn&)Q zg-{NN|;)TmrWn-u)cVU;0xTL^0o|1(^z>?d&({K7a6|R!*uD>YU zc|mO;N5fQMvZ$$KRxabKz(iJUyBHn{nlUZ9RJ7WJaYl$tnljpa`aXA7cW{2MU@JCa z+Y2)t(i4!04IC~Psi>}eUAF9G*4$PiF0{~HquGbyWiUB?a)73qSv_0Ax-3w)R)7&P zX+t?yTec*S_&KoXdOOTFQ?E+WLnCUEPJIIAXIo(N zS-^4J1XPzK{z$!AMPwr98->zjk8L2J)`rOW&hVh$L?q5GPe};hfl<RpPM3u`!A4u*#Rv5KfSr=~t(_ghaR-(kz7{N>#Hak5omL?vL>kW zjX66@1Qz<_%*WjJXme2|QLG1h+>$OZk*DCf6~z1a;&T+BepG^gCcH>K^BC-^9yAEe zI1zP(n`#D4nXm zgYIUukuR8A3g6uXT6VLcN$?1THMQrV_S4}_ZZ+bWg2~zbzk^xT9~mwKK%)!hb*HmC zhyji|c`>4FMY%7Sag2rA_pl_F5~=oju-j4RX;F_2fc&I1bq*Mx<>l+^zUt_oi ze_Lb?T#v_)cu)nl-UNuTdLdU8kJcL*4`T6X3UW}Fujr>+Y0R2+FpxjF(S+RM-Nti{7SUe7E_S? za-f9xLuPqv)nZ(Qb6ov+e{)P#ck4j1{l0Xski?^9L7~13hm6v}(e`|2EJZ39-`RHj<;J$ax;>YvB|8>8AjaKZpL01~CEtJN*xhPq~Qz From 56a2e1d9f1f014c032422446720c580a5cd36dc3 Mon Sep 17 00:00:00 2001 From: wsx Date: Sat, 8 May 2021 19:33:03 +0800 Subject: [PATCH 3/9] the worked version with the evn yaml file --- README.md | 13 +- data_gen/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 106 bytes .../__pycache__/preprocess.cpython-36.pyc | Bin 0 -> 1811 bytes data_gen/__pycache__/rotation.cpython-36.pyc | Bin 0 -> 1789 bytes data_gen/gpu.py | 35 +++ data_gen/io.py | 203 ++++++++++++++++++ data_gen/ntu_gen_preprocess.py | 3 +- environment.yml | 88 ++++++++ log/data_tree.log | 16 ++ log/train_aim.log | 21 ++ .../__pycache__/__init__.cpython-36.pyc | Bin 392 -> 376 bytes torchlight/__pycache__/io.cpython-36.pyc | Bin 7050 -> 7034 bytes torchlight/gpu.py | 1 + 13 files changed, 376 insertions(+), 4 deletions(-) create mode 100644 data_gen/__pycache__/__init__.cpython-36.pyc create mode 100644 data_gen/__pycache__/preprocess.cpython-36.pyc create mode 100644 data_gen/__pycache__/rotation.cpython-36.pyc create mode 100644 data_gen/gpu.py create mode 100644 data_gen/io.py create mode 100644 environment.yml create mode 100644 log/data_tree.log create mode 100644 log/train_aim.log diff --git a/README.md b/README.md index 601db34..411cfdb 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,16 @@ In this repo, we show the example of model on NTU-RGB+D dataset. # Environments We use the similar input/output interface and system configuration like ST-GCN, where the torchlight module should be set up. +``` +cd torchlight +cp torchlight/torchlight/_init__.py gpu.py io.py ../ +``` +change all "from torchlight import ..." to +"from torchlight.io import ..." Run ``` -cd torchlight, python setup.py, cd .. +cd torchlight, python setup.py install, cd .. ``` @@ -30,13 +36,14 @@ For NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu ``` Then, run the preprocessing program to generate the input data, which is very important. ``` -python ./data_gen/ntu_gen_preprocess.py +cd data_gen +python ntu_gen_preprocess.py ``` # Training and Testing With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipleline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D, ``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml +PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml ``` diff --git a/data_gen/__pycache__/__init__.cpython-36.pyc b/data_gen/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..600c7f33ca097297a747c4e361f68e8da10a3d08 GIT binary patch literal 106 zcmXr!<>iXIKRJPsfq~&M5W@jTzyXMhS%5?eLokCTqu)w~B9JhG_$98Vr=OBok{F+! hnx`KhpP83g5+AQuP(n5yc_$5JN*eB$V{=3}fxJLhC^Hc4&`m&o(}%PiTP_`=m#xOH1El z&tU~xJ|VovibI!HX4oq+hm~{=@RVs4#st>lFVRlz9QH%|Q>)Y2-N6dK1I89(Iy%ST zKXY2ngq`IX9(S{^A*N*mBh|(lZnQ{^iiB*Y3 zlM&+qY^60QE1is#eH16BUXe4i(sPb@oTgqai25P(yG(q`SmdqnSOj7aNB$@fe2^)N z#$xBTvS}a!&zT0TV3H>Ril zylx$kc!h*PrIE5nL3-j<4GV}8A|83A`J0FnWzY9A`aE|o$DisHJ>~RxFk*T<^c=Iq zTbZxq3t9_DxO6OXWWOT*8esXm4S#*QF2B#B-GsA*$44wpJISe9gI7O_`6vhnQv*c* zA*PSOAiRp3#Klcq#RT8M4b!*qo0f}PxDLJs{F-r&D!C2NKAOY>(hCZsc#+% zCF#hW3Y>@2RM|-&j#YtWf<@F@TY%^+PJlT&OED4f_I+N_+Uzh6C~ufy8a!r(<`wh0 eDPV6D(@Dra(w8D#1>=%Bc5&5ee&_-k7yk?CCg7p~ literal 0 HcmV?d00001 diff --git a/data_gen/__pycache__/rotation.cpython-36.pyc b/data_gen/__pycache__/rotation.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..722593f2ca1ecb0d44a88bf840ab7c3a85f35edd GIT binary patch literal 1789 zcmcIkJ(Jr+7~U^^*gpF*6DN1%0;ISiqPR=~WhUGV6jWCLmF8v~St}GDwtQL1$C+7G zhLVO~KubqMO~;Q=OLYY^bj(okylcz3Tn2^)+53J=dY^Zn)%*Q^_m24f`L`Zpzp-nV zfPIWrehi_R=2I4Pr$sEZ(DEA=ORcmAt+cNLXitZ_1MTas?m-87NB5ya%|Bz&;0jM- zCn-Idu3K0E)<;<7F=wn|XWX(CKNBmw?@Lxba1p-Z zM^Um-9ilqeR`1$ARl6G0Z|+2NQ{CBAx2mB$Vp;QX$xQjHwc7!5! zfgGUJ83;hMz<~HkB#lBQ_NS8~ndY%7=4KZ0#=pqYYHnP6YjFoa?5s?j=k8-dc1 zmz#aqTO!i~sZsad&tcG8UnBCNi-5>2war5m6m55U+dXyFWmcW&DZ mVK#KbwJFgT#<fd literal 0 HcmV?d00001 diff --git a/data_gen/gpu.py b/data_gen/gpu.py new file mode 100644 index 0000000..306c391 --- /dev/null +++ b/data_gen/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/data_gen/io.py b/data_gen/io.py new file mode 100644 index 0000000..c753ca1 --- /dev/null +++ b/data_gen/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/data_gen/ntu_gen_preprocess.py b/data_gen/ntu_gen_preprocess.py index 6323b30..9bc8423 100644 --- a/data_gen/ntu_gen_preprocess.py +++ b/data_gen/ntu_gen_preprocess.py @@ -140,4 +140,5 @@ def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set if not os.path.exists(out_path): os.makedirs(out_path) print(b, sn) - gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) \ No newline at end of file + #gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) + gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, set_name=sn) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..9d0ea0f --- /dev/null +++ b/environment.yml @@ -0,0 +1,88 @@ +name: asgcn +channels: + - pytorch + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ + - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - blas=1.0=mkl + - ca-certificates=2021.4.13=h06a4308_1 + - certifi=2020.12.5=py36h06a4308_0 + - cffi=1.14.5=py36h261ae71_0 + - cuda90=1.0=h6433d27_0 + - cudatoolkit=10.0.130=0 + - cudnn=7.6.5=cuda10.0_0 + - cycler=0.10.0=py36_0 + - dbus=1.13.18=hb2f20db_0 + - expat=2.3.0=h2531618_2 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.68.1=h36276a3_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - icu=58.2=he6710b0_3 + - intel-openmp=2019.4=243 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.1=py36h2531618_0 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.2.0=h3942068_0 + - libuuid=1.0.3=h1bed415_2 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.10=hb55368b_3 + - lz4-c=1.9.3=h2531618_0 + - matplotlib=3.3.2=h06a4308_0 + - matplotlib-base=3.3.2=py36h817c723_0 + - mkl=2018.0.3=1 + - mkl_fft=1.0.6=py36h7dd41cf_0 + - mkl_random=1.0.1=py36h4414c95_1 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py36hff7bd54_0 + - olefile=0.46=py36_0 + - openssl=1.1.1k=h27cfd23_0 + - pcre=8.44=he6710b0_0 + - pillow=8.1.2=py36he98fc37_0 + - pip=21.0.1=py36h06a4308_0 + - pycparser=2.20=py_2 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py36h05f1152_2 + - python=3.6.13=hdb3f193_0 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py36h06a4308_0 + - sip=4.19.8=py36hf484d3e_0 + - six=1.15.0=py36h06a4308_0 + - sqlite=3.35.1=hdfb4753_0 + - tbb=2021.2.0=hff7bd54_0 + - tbb4py=2021.2.0=py36hff7bd54_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py36h27cfd23_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - argparse==1.4.0 + - cached-property==1.5.2 + - dataclasses==0.8 + - h5py==3.1.0 + - imageio==2.9.0 + - numpy==1.19.5 + - opencv-python==4.5.1.48 + - pyyaml==5.4.1 + - scikit-video==1.1.11 + - scipy==1.5.4 + - torch==1.7.1 + - torchvision==0.9.0 + - tqdm==4.60.0 + - typing-extensions==3.7.4.3 diff --git a/log/data_tree.log b/log/data_tree.log new file mode 100644 index 0000000..3bfbe88 --- /dev/null +++ b/log/data_tree.log @@ -0,0 +1,16 @@ +data +|-- NTU-RGB+D +| `-- samples_with_missing_skeletons.txt +`-- nturgb_d + |-- xsub + | |-- train_data_joint_pad.npy + | |-- train_label.pkl + | |-- val_data_joint_pad.npy + | `-- val_label.pkl + `-- xview + |-- train_data_joint_pad.npy + |-- train_label.pkl + |-- val_data_joint_pad.npy + `-- val_label.pkl + +4 directories, 9 files diff --git a/log/train_aim.log b/log/train_aim.log new file mode 100644 index 0000000..1f84ef7 --- /dev/null +++ b/log/train_aim.log @@ -0,0 +1,21 @@ +$python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 + +/root/AS-GCN/processor/io.py:34: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details. + default_arg = yaml.load(f) +/root/AS-GCN/net/utils/adj_learn.py:18: UserWarning: This overload of nonzero is deprecated: + nonzero() +Consider using one of the following signatures instead: + nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) + offdiag_indices = (ones - eye).nonzero().t() +/root/anaconda3/envs/stgcn/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported. + warnings.warn("Setting attributes on ParameterList is not supported.") +/root/AS-GCN/net/utils/adj_learn.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. + soft_max_1d = F.softmax(trans_input) + +[05.08.21|19:25:22] Parameters: +{'work_dir': './work_dir/recognition/ntu-xsub/AS_GCN', 'config': 'config/as_gcn/ntu-xsub/train_aim.yaml', 'phase': 'train', 'save_result': False, 'start_epoch': 0, 'num_epoch': 10, 'use_gpu': True, 'device': [0, 1, 2], 'log_interval': 100, 'save_interval': 1, 'eval_interval': 5, 'save_log': True, 'print_log': True, 'pavi_log': False, 'feeder': 'feeder.feeder.Feeder', 'num_worker': 4, 'train_feeder_args': {'data_path': './data/nturgb_d/xsub/train_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/train_label.pkl', 'random_move': True, 'repeat_pad': True, 'down_sample': True, 'debug': False}, 'test_feeder_args': {'data_path': './data/nturgb_d/xsub/val_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/val_label.pkl', 'random_move': False, 'repeat_pad': True, 'down_sample': True}, 'batch_size': 32, 'test_batch_size': 32, 'debug': False, 'model1': 'net.as_gcn.Model', 'model2': 'net.utils.adj_learn.AdjacencyLearn', 'model1_args': {'in_channels': 3, 'num_class': 60, 'dropout': 0.5, 'edge_importance_weighting': True, 'graph_args': {'layout': 'ntu-rgb+d', 'strategy': 'spatial', 'max_hop': 4}}, 'model2_args': {'n_in_enc': 150, 'n_hid_enc': 128, 'edge_types': 3, 'n_in_dec': 3, 'n_hid_dec': 128, 'node_num': 25}, 'weights1': None, 'weights2': None, 'ignore_weights': [], 'show_topk': [1, 5], 'base_lr1': 0.1, 'base_lr2': 0.0005, 'step': [50, 70, 90], 'optimizer': 'SGD', 'nesterov': True, 'weight_decay': 0.0001, 'max_hop_dir': 'max_hop_4', 'lamda_act': 0.5, 'lamda_act_dir': 'lamda_05'} + +[05.08.21|19:25:22] Training epoch: 0 +[05.08.21|19:25:29] Iter 0 Done. | loss2: 876.8732 | loss_nll: 832.9621 | loss_kl: 43.9111 | lr: 0.000500 +[05.08.21|19:26:56] Iter 100 Done. | loss2: 118.9051 | loss_nll: 110.3876 | loss_kl: 8.5176 | lr: 0.000500 +[05.08.21|19:28:14] Iter 200 Done. | loss2: 76.1775 | loss_nll: 71.7404 | loss_kl: 4.4371 | lr: 0.000500 diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc index d75de414854592b86f7c3e4bf36812c1d7880aaf..9f75a6bb9658ea632019f7c35cb9c2b299d1129d 100644 GIT binary patch delta 33 ocmeBR{=vj%%*)Hg6ts~|myuUlzbHSyMBg!3*WKA~asZqHfxcs~uDi3JZb4#6a)z$qWCunI E09}C&RsaA1 diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc index 6643cdd031af6d93a7ca34b6d5c5652079d894fa..4b8a85cc437e91d0559b5d5b476091e9e1f1731d 100644 GIT binary patch delta 34 pcmeA&|7FHz%*)Hg6tt1;G!w6^eo=mYiN0g7uDi3}=66hUBmkYe3jLDY0)59|U3X_c-GaoD Date: Sun, 9 May 2021 19:51:19 +0800 Subject: [PATCH 4/9] train main pipeline worked --- README.md | 4 +++- net/as_gcn.py | 24 +++++++++++++----------- processor/recognition.py | 2 ++ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 411cfdb..59b8bb0 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ In this repo, we show the example of model on NTU-RGB+D dataset. * pyyaml * argparse * numpy +* torch 1.7.1 # Environments We use the similar input/output interface and system configuration like ST-GCN, where the torchlight module should be set up. @@ -44,7 +45,8 @@ python ntu_gen_preprocess.py With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipleline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D, ``` PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml +TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml --device 0 --batch_size 4 +# only can use one gpu otherwise got the error "Caught RuntimeError in replica 0 on device 0"" Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml ``` diff --git a/net/as_gcn.py b/net/as_gcn.py index 7468be4..fab4925 100644 --- a/net/as_gcn.py +++ b/net/as_gcn.py @@ -52,16 +52,16 @@ def __init__(self, in_channels, num_class, graph_args, def forward(self, x, x_target, x_last, A_act, lamda_act): N, C, T, V, M = x.size() - x_recon = x[:,:,:,:,0] # [2N, 3, 300, 25] - x = x.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 300] - x = x.view(N * M, V * C, T) # [2N, 75, 300] + x_recon = x[:,:,:,:,0] # [2N, 3, 300, 25] wsx: x_recon(4,3,290,25) select the first person data? + x = x.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 300] wsx: x(4,2,25,3,290) + x = x.view(N * M, V * C, T) # [2N, 75, 300]m wsx: x(8,75,290) - x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) + x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) #(8,3,1,25) x_bn = self.data_bn(x) x_bn = x_bn.view(N, M, V, C, T) x_bn = x_bn.permute(0, 1, 3, 4, 2).contiguous() - x_bn = x_bn.view(N * M, C, T, V) + x_bn = x_bn.view(N * M, C, T, V) #2N,3,290,25 h0, _ = self.class_layer_0(x_bn, self.A * self.edge_importance[0], A_act, lamda_act) # [N, 64, 300, 25] h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] @@ -74,10 +74,12 @@ def forward(self, x, x_target, x_last, A_act, lamda_act): h7, _ = self.class_layer_7(h6, self.A * self.edge_importance[7], A_act, lamda_act) # [N, 256, 75, 25] h8, _ = self.class_layer_8(h7, self.A * self.edge_importance[8], A_act, lamda_act) # [N, 256, 75, 25] - x_class = F.avg_pool2d(h8, h8.size()[2:]) - x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) - x_class = self.fcn(x_class) - x_class = x_class.view(x_class.size(0), -1) + x_class = F.avg_pool2d(h8, h8.size()[2:]) #(8,256,1,1) + x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) #(4,256,1,1) + #x_class = x_class.view(N, M, -1, 1, 1) #(4,2,256,1,1) + #x_class = x_class.mean(dim=1) #(4,256,1,1) + x_class = self.fcn(x_class) #(4,60,1,1) Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1)) + x_class = x_class.view(x_class.size(0), -1) #(4,60) r0, _ = self.recon_layer_0(h8, self.A*self.edge_importance_recon[0], A_act, lamda_act) # [N, 128, 75, 25] r1, _ = self.recon_layer_1(r0, self.A*self.edge_importance_recon[1], A_act, lamda_act) # [N, 128, 38, 25] @@ -85,8 +87,8 @@ def forward(self, x, x_target, x_last, A_act, lamda_act): r3, _ = self.recon_layer_3(r2, self.A*self.edge_importance_recon[3], A_act, lamda_act) # [N, 128, 10, 25] r4, _ = self.recon_layer_4(r3, self.A*self.edge_importance_recon[4], A_act, lamda_act) # [N, 128, 5, 25] r5, _ = self.recon_layer_5(r4, self.A*self.edge_importance_recon[5], A_act, lamda_act) # [N, 128, 1, 25] - r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act) # [N, 64, 1, 25] - pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze() # [N, 3, 25] + r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act) # [N, 64, 1, 25] wsx:(8,30,1,25) + pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze() # [N, 3, 25] wsx:(8,30,25) pred = pred.contiguous().view(-1, 3, 10, 25) x_target = x_target.permute(0,4,1,2,3).contiguous().view(-1,3,10,25) diff --git a/processor/recognition.py b/processor/recognition.py index 06fb642..0dc333c 100644 --- a/processor/recognition.py +++ b/processor/recognition.py @@ -112,6 +112,7 @@ def train(self, training_A=False): self.epoch_info.clear() for data, data_downsample, target_data, data_last, label in loader: + # data: (32,3,290,25,2) data_downsample:(32,3,50,25,2) target_data:(32,3,10,25,2) data_last:(32,3,1,25,2) label:(32) data = data.float().to(self.dev) data_downsample = data_downsample.float().to(self.dev) label = label.long().to(self.dev) @@ -158,6 +159,7 @@ def train(self, training_A=False): label = label.long().to(self.dev) A_batch, prob, outputs, _ = self.model2(data_downsample) + # wsx x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) loss_class = self.loss_class(x_class, label) loss_recon = self.loss_pred(pred, target) From e1bd0c2d75cd5ef4f2e5e5a4f7f38357ce73eb87 Mon Sep 17 00:00:00 2001 From: wsx Date: Mon, 17 May 2021 20:16:01 +0800 Subject: [PATCH 5/9] some changes --- feeder/feeder.py | 2 +- net/as_gcn.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/feeder/feeder.py b/feeder/feeder.py index dba96f3..7998a14 100644 --- a/feeder/feeder.py +++ b/feeder/feeder.py @@ -51,7 +51,7 @@ def load_data(self, mmap): self.data = self.data[0:100] self.sample_name = self.sample_name[0:100] - self.N, self.C, self.T, self.V, self.M = self.data.shape + self.N, self.C, self.T, self.V, self.M = self.data.shape # (40091, 3, 300, 25, 2) def __len__(self): return len(self.label) diff --git a/net/as_gcn.py b/net/as_gcn.py index fab4925..0be1829 100644 --- a/net/as_gcn.py +++ b/net/as_gcn.py @@ -50,13 +50,12 @@ def __init__(self, in_channels, num_class, graph_args, self.fcn = nn.Conv2d(256, num_class, kernel_size=1) def forward(self, x, x_target, x_last, A_act, lamda_act): - N, C, T, V, M = x.size() x_recon = x[:,:,:,:,0] # [2N, 3, 300, 25] wsx: x_recon(4,3,290,25) select the first person data? x = x.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 300] wsx: x(4,2,25,3,290) x = x.view(N * M, V * C, T) # [2N, 75, 300]m wsx: x(8,75,290) - x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) #(8,3,1,25) + x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) #(2N,3,1,25) x_bn = self.data_bn(x) x_bn = x_bn.view(N, M, V, C, T) @@ -76,8 +75,6 @@ def forward(self, x, x_target, x_last, A_act, lamda_act): x_class = F.avg_pool2d(h8, h8.size()[2:]) #(8,256,1,1) x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) #(4,256,1,1) - #x_class = x_class.view(N, M, -1, 1, 1) #(4,2,256,1,1) - #x_class = x_class.mean(dim=1) #(4,256,1,1) x_class = self.fcn(x_class) #(4,60,1,1) Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1)) x_class = x_class.view(x_class.size(0), -1) #(4,60) From 173673bdd089c7e8e89782b5b0ad834e03fec34c Mon Sep 17 00:00:00 2001 From: wsx Date: Wed, 26 May 2021 22:16:38 +0800 Subject: [PATCH 6/9] the version worked --- processor/recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/processor/recognition.py b/processor/recognition.py index 0dc333c..cf23165 100644 --- a/processor/recognition.py +++ b/processor/recognition.py @@ -86,7 +86,7 @@ def nll_gaussian(self, preds, target, variance, add_const=False): return neg_log_p.sum() / (target.size(0) * target.size(1)) def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): - kl_div = preds*(torch.log(preds+eps)-log_prior) + kl_ddiv = preds*(torch.log(preds+eps)-log_prior) return kl_div.sum()/(num_node*preds.size(0)) From d58d922f68c5121276d38f9af9eadb1dfd844273 Mon Sep 17 00:00:00 2001 From: wsx Date: Sun, 30 May 2021 21:45:29 +0800 Subject: [PATCH 7/9] add requirement.txt --- requirement.txt | 327 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 327 insertions(+) create mode 100644 requirement.txt diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..d282296 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,327 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_ipyw_jlab_nb_ext_conf=0.1.0=py38_0 +_libgcc_mutex=0.1=main +alabaster=0.7.12=py_0 +anaconda=2020.11=py38_0 +anaconda-client=1.7.2=py38_0 +anaconda-navigator=1.10.0=py38_0 +anaconda-project=0.8.4=py_0 +argh=0.26.2=py38_0 +argon2-cffi=20.1.0=py38h7b6447c_1 +asn1crypto=1.4.0=py_0 +astroid=2.4.2=py38_0 +astropy=4.0.2=py38h7b6447c_0 +async_generator=1.10=py_0 +atomicwrites=1.4.0=py_0 +attrs=20.3.0=pyhd3eb1b0_0 +autopep8=1.5.4=py_0 +babel=2.8.1=pyhd3eb1b0_0 +backcall=0.2.0=py_0 +backports=1.0=py_2 +backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 +backports.shutil_get_terminal_size=1.0.0=py38_2 +backports.tempfile=1.0=pyhd3eb1b0_1 +backports.weakref=1.0.post1=py_1 +beautifulsoup4=4.9.3=pyhb0f4dca_0 +bitarray=1.6.1=py38h27cfd23_0 +bkcharts=0.2=py38_0 +blas=1.0=mkl +bleach=3.2.1=py_0 +blosc=1.20.1=hd408876_0 +bokeh=2.2.3=py38_0 +boto=2.49.0=py38_0 +bottleneck=1.3.2=py38heb32a55_1 +brotlipy=0.7.0=py38h7b6447c_1000 +bzip2=1.0.8=h7b6447c_0 +ca-certificates=2020.10.14=0 +cairo=1.14.12=h8948797_3 +certifi=2020.6.20=pyhd3eb1b0_3 +cffi=1.14.3=py38he30daa8_0 +chardet=3.0.4=py38_1003 +click=7.1.2=py_0 +cloudpickle=1.6.0=py_0 +clyent=1.2.2=py38_1 +colorama=0.4.4=py_0 +conda=4.10.1=py38h06a4308_1 +conda-build=3.20.5=py38_1 +conda-env=2.6.0=1 +conda-package-handling=1.7.3=py38h27cfd23_1 +conda-verify=3.4.2=py_1 +contextlib2=0.6.0.post1=py_0 +cryptography=3.1.1=py38h1ba5d50_0 +curl=7.71.1=hbc83047_1 +cycler=0.10.0=py38_0 +cython=0.29.21=py38he6710b0_0 +cytoolz=0.11.0=py38h7b6447c_0 +dask=2.30.0=py_0 +dask-core=2.30.0=py_0 +dbus=1.13.18=hb2f20db_0 +decorator=4.4.2=py_0 +defusedxml=0.6.0=py_0 +diff-match-patch=20200713=py_0 +distributed=2.30.1=py38h06a4308_0 +docutils=0.16=py38_1 +entrypoints=0.3=py38_0 +et_xmlfile=1.0.1=py_1001 +expat=2.2.10=he6710b0_2 +fastcache=1.1.0=py38h7b6447c_0 +filelock=3.0.12=py_0 +flake8=3.8.4=py_0 +flask=1.1.2=py_0 +fontconfig=2.13.0=h9420a91_0 +freetype=2.10.4=h5ab3b9f_0 +fribidi=1.0.10=h7b6447c_0 +fsspec=0.8.3=py_0 +future=0.18.2=py38_1 +get_terminal_size=1.0.0=haa9412d_0 +gevent=20.9.0=py38h7b6447c_0 +glib=2.66.1=h92f7085_0 +glob2=0.7=py_0 +gmp=6.1.2=h6c8ec71_1 +gmpy2=2.0.8=py38hd5f6e3b_3 +graphite2=1.3.14=h23475e2_0 +greenlet=0.4.17=py38h7b6447c_0 +gst-plugins-base=1.14.0=hbbd80ab_1 +gstreamer=1.14.0=hb31296c_0 +h5py=2.10.0=py38h7918eee_0 +harfbuzz=2.4.0=hca77d97_1 +hdf5=1.10.4=hb1b8bf9_0 +heapdict=1.0.1=py_0 +html5lib=1.1=py_0 +icu=58.2=he6710b0_3 +idna=2.10=py_0 +imageio=2.9.0=py_0 +imagesize=1.2.0=py_0 +importlib-metadata=2.0.0=py_1 +importlib_metadata=2.0.0=1 +iniconfig=1.1.1=py_0 +intel-openmp=2020.2=254 +intervaltree=3.1.0=py_0 +ipykernel=5.3.4=py38h5ca1d4c_0 +ipython=7.19.0=py38hb070fc8_0 +ipython_genutils=0.2.0=py38_0 +ipywidgets=7.5.1=py_1 +isort=5.6.4=py_0 +itsdangerous=1.1.0=py_0 +jbig=2.1=hdba287a_0 +jdcal=1.4.1=py_0 +jedi=0.17.1=py38_0 +jeepney=0.5.0=pyhd3eb1b0_0 +jinja2=2.11.2=py_0 +joblib=0.17.0=py_0 +jpeg=9b=h024ee3a_2 +json5=0.9.5=py_0 +jsonschema=3.2.0=py_2 +jupyter=1.0.0=py38_7 +jupyter_client=6.1.7=py_0 +jupyter_console=6.2.0=py_0 +jupyter_core=4.6.3=py38_0 +jupyterlab=2.2.6=py_0 +jupyterlab_pygments=0.1.2=py_0 +jupyterlab_server=1.2.0=py_0 +keyring=21.4.0=py38_1 +kiwisolver=1.3.0=py38h2531618_0 +krb5=1.18.2=h173b8e3_0 +lazy-object-proxy=1.4.3=py38h7b6447c_0 +lcms2=2.11=h396b838_0 +ld_impl_linux-64=2.33.1=h53a641e_7 +libarchive=3.4.2=h62408e4_0 +libcurl=7.71.1=h20c2e04_1 +libedit=3.1.20191231=h14c3975_1 +libffi=3.3=he6710b0_2 +libgcc-ng=9.1.0=hdf63c60_0 +libgfortran-ng=7.3.0=hdf63c60_0 +liblief=0.10.1=he6710b0_0 +libllvm10=10.0.1=hbcb73fb_5 +libpng=1.6.37=hbc83047_0 +libsodium=1.0.18=h7b6447c_0 +libspatialindex=1.9.3=he6710b0_0 +libssh2=1.9.0=h1ba5d50_1 +libstdcxx-ng=9.1.0=hdf63c60_0 +libtiff=4.1.0=h2733197_1 +libtool=2.4.6=h7b6447c_1005 +libuuid=1.0.3=h1bed415_2 +libxcb=1.14=h7b6447c_0 +libxml2=2.9.10=hb55368b_3 +libxslt=1.1.34=hc22bd24_0 +llvmlite=0.34.0=py38h269e1b5_4 +locket=0.2.0=py38_1 +lxml=4.6.1=py38hefd8a0e_0 +lz4-c=1.9.2=heb0550a_3 +lzo=2.10=h7b6447c_2 +markupsafe=1.1.1=py38h7b6447c_0 +matplotlib=3.3.2=0 +matplotlib-base=3.3.2=py38h817c723_0 +mccabe=0.6.1=py38_1 +mistune=0.8.4=py38h7b6447c_1000 +mkl=2020.2=256 +mkl-service=2.3.0=py38he904b0f_0 +mkl_fft=1.2.0=py38h23d657b_0 +mkl_random=1.1.1=py38h0573a6f_0 +mock=4.0.2=py_0 +more-itertools=8.6.0=pyhd3eb1b0_0 +mpc=1.1.0=h10f8cd9_1 +mpfr=4.0.2=hb69a4c5_1 +mpmath=1.1.0=py38_0 +msgpack-python=1.0.0=py38hfd86e86_1 +multipledispatch=0.6.0=py38_0 +navigator-updater=0.2.1=py38_0 +nbclient=0.5.1=py_0 +nbconvert=6.0.7=py38_0 +nbformat=5.0.8=py_0 +ncurses=6.2=he6710b0_1 +nest-asyncio=1.4.2=pyhd3eb1b0_0 +networkx=2.5=py_0 +nltk=3.5=py_0 +nose=1.3.7=py38_2 +notebook=6.1.4=py38_0 +numba=0.51.2=py38h0573a6f_1 +numexpr=2.7.1=py38h423224d_0 +numpy=1.19.2=py38h54aff64_0 +numpy-base=1.19.2=py38hfa32c7d_0 +numpydoc=1.1.0=pyhd3eb1b0_1 +olefile=0.46=py_0 +openpyxl=3.0.5=py_0 +openssl=1.1.1h=h7b6447c_0 +packaging=20.4=py_0 +pandas=1.1.3=py38he6710b0_0 +pandoc=2.11=hb0f4dca_0 +pandocfilters=1.4.3=py38h06a4308_1 +pango=1.45.3=hd140c19_0 +parso=0.7.0=py_0 +partd=1.1.0=py_0 +patchelf=0.12=he6710b0_0 +path=15.0.0=py38_0 +path.py=12.5.0=0 +pathlib2=2.3.5=py38_0 +pathtools=0.1.2=py_1 +patsy=0.5.1=py38_0 +pcre=8.44=he6710b0_0 +pep8=1.7.1=py38_0 +pexpect=4.8.0=py38_0 +pickleshare=0.7.5=py38_1000 +pillow=8.0.1=py38he98fc37_0 +pip=21.1=pypi_0 +pixman=0.40.0=h7b6447c_0 +pkginfo=1.6.1=py38h06a4308_0 +pluggy=0.13.1=py38_0 +ply=3.11=py38_0 +prometheus_client=0.8.0=py_0 +prompt-toolkit=3.0.8=py_0 +prompt_toolkit=3.0.8=0 +psutil=5.7.2=py38h7b6447c_0 +ptyprocess=0.6.0=py38_0 +py=1.9.0=py_0 +py-lief=0.10.1=py38h403a769_0 +pycodestyle=2.6.0=py_0 +pycosat=0.6.3=py38h7b6447c_1 +pycparser=2.20=py_2 +pycurl=7.43.0.6=py38h1ba5d50_0 +pydocstyle=5.1.1=py_0 +pyflakes=2.2.0=py_0 +pygments=2.7.2=pyhd3eb1b0_0 +pylint=2.6.0=py38_0 +pyodbc=4.0.30=py38he6710b0_0 +pyopenssl=19.1.0=py_1 +pyparsing=2.4.7=py_0 +pyqt=5.9.2=py38h05f1152_4 +pyrsistent=0.17.3=py38h7b6447c_0 +pysocks=1.7.1=py38_0 +pytables=3.6.1=py38h9fd0a39_0 +pytest=6.1.1=py38_0 +python=3.8.5=h7579374_1 +python-dateutil=2.8.1=py_0 +python-jsonrpc-server=0.4.0=py_0 +python-language-server=0.35.1=py_0 +python-libarchive-c=2.9=py_0 +pytz=2020.1=py_0 +pywavelets=1.1.1=py38h7b6447c_2 +pyxdg=0.27=pyhd3eb1b0_0 +pyyaml=5.3.1=py38h7b6447c_1 +pyzmq=19.0.2=py38he6710b0_1 +qdarkstyle=2.8.1=py_0 +qt=5.9.7=h5867ecd_1 +qtawesome=1.0.1=py_0 +qtconsole=4.7.7=py_0 +qtpy=1.9.0=py_0 +readline=8.0=h7b6447c_0 +regex=2020.10.15=py38h7b6447c_0 +requests=2.24.0=py_0 +ripgrep=12.1.1=0 +rope=0.18.0=py_0 +rtree=0.9.4=py38_1 +ruamel_yaml=0.15.87=py38h7b6447c_1 +scikit-image=0.17.2=py38hdf5156a_0 +scikit-learn=0.23.2=py38h0573a6f_0 +scipy=1.5.2=py38h0b6359f_0 +seaborn=0.11.0=py_0 +secretstorage=3.1.2=py38_0 +send2trash=1.5.0=py38_0 +setuptools=50.3.1=py38h06a4308_1 +simplegeneric=0.8.1=py38_2 +singledispatch=3.4.0.3=py_1001 +sip=4.19.13=py38he6710b0_0 +six=1.15.0=py38h06a4308_0 +snowballstemmer=2.0.0=py_0 +sortedcollections=1.2.1=py_0 +sortedcontainers=2.2.2=py_0 +soupsieve=2.0.1=py_0 +sphinx=3.2.1=py_0 +sphinxcontrib=1.0=py38_1 +sphinxcontrib-applehelp=1.0.2=py_0 +sphinxcontrib-devhelp=1.0.2=py_0 +sphinxcontrib-htmlhelp=1.0.3=py_0 +sphinxcontrib-jsmath=1.0.1=py_0 +sphinxcontrib-qthelp=1.0.3=py_0 +sphinxcontrib-serializinghtml=1.1.4=py_0 +sphinxcontrib-websupport=1.2.4=py_0 +spyder=4.1.5=py38_0 +spyder-kernels=1.9.4=py38_0 +sqlalchemy=1.3.20=py38h7b6447c_0 +sqlite=3.33.0=h62c20be_0 +statsmodels=0.12.0=py38h7b6447c_0 +sympy=1.6.2=py38h06a4308_1 +tbb=2020.3=hfd86e86_0 +tblib=1.7.0=py_0 +terminado=0.9.1=py38_0 +testpath=0.4.4=py_0 +threadpoolctl=2.1.0=pyh5ca1d4c_0 +tifffile=2020.10.1=py38hdd07704_2 +tk=8.6.10=hbc83047_0 +toml=0.10.1=py_0 +toolz=0.11.1=py_0 +torch=1.8.1=pypi_0 +torchvision=0.9.1=pypi_0 +tornado=6.0.4=py38h7b6447c_1 +tqdm=4.50.2=py_0 +traitlets=5.0.5=py_0 +typing_extensions=3.7.4.3=py_0 +ujson=4.0.1=py38he6710b0_0 +unicodecsv=0.14.1=py38_0 +unixodbc=2.3.9=h7b6447c_0 +urllib3=1.25.11=py_0 +watchdog=0.10.3=py38_0 +wcwidth=0.2.5=py_0 +webencodings=0.5.1=py38_1 +werkzeug=1.0.1=py_0 +wheel=0.35.1=py_0 +widgetsnbextension=3.5.1=py38_0 +wrapt=1.11.2=py38h7b6447c_0 +wurlitzer=2.0.1=py38_0 +xlrd=1.2.0=py_0 +xlsxwriter=1.3.7=py_0 +xlwt=1.3.0=py38_0 +xmltodict=0.12.0=py_0 +xz=5.2.5=h7b6447c_0 +yaml=0.2.5=h7b6447c_0 +yapf=0.30.0=py_0 +zeromq=4.3.3=he6710b0_3 +zict=2.0.0=py_0 +zipp=3.4.0=pyhd3eb1b0_0 +zlib=1.2.11=h7b6447c_3 +zope=1.0=py38_1 +zope.event=4.5.0=py38_0 +zope.interface=5.1.2=py38h7b6447c_0 +zstd=1.4.5=h9ceee32_0 From 66f1e46b5832189f4076f64ec8ad54de5b09d5e0 Mon Sep 17 00:00:00 2001 From: wsx Date: Sun, 13 Jun 2021 21:22:26 +0800 Subject: [PATCH 8/9] add transformer to asgcn and run --- net/model_poseformer.py | 223 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 net/model_poseformer.py diff --git a/net/model_poseformer.py b/net/model_poseformer.py new file mode 100644 index 0000000..e38ac89 --- /dev/null +++ b/net/model_poseformer.py @@ -0,0 +1,223 @@ +## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +import math +import logging +from functools import partial +from collections import OrderedDict +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class PoseTransformer(nn.Module): + def __init__(self, num_frame=9, num_joints=25, in_chans=3, embed_dim_ratio: object = 32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, + num_class=60 + ): + """ ##########hybrid_backbone=None, representation_size=None, + Args: + num_frame (int, tuple): input frame number + num_joints (int, tuple): joints number + in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) + embed_dim_ratio (int): embedding dimension ratio + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + num_class (int): the pose action class amount 30 + """ + super().__init__() + + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio + out_dim = num_joints * 3 #### output dimension is num_joints * 3 + + ### spatial patch embedding + self.Spatial_patch_to_embedding = nn.Linear(3, 32) + self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.Spatial_blocks = nn.ModuleList([ + Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.Spatial_norm = norm_layer(embed_dim_ratio) + self.Temporal_norm = norm_layer(embed_dim) + + ####### A easy way to implement weighted mean + self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1) + + self.head = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim , out_dim), + ) + + # wsx aciton_class_head + self.action_class_head = nn.Conv2d(290, num_class, kernel_size=1) + + # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + self.data_bn = nn.BatchNorm1d(3 * 25) + + + + + + def Spatial_forward_features(self, x): + b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints + x = rearrange(x, 'b c f p -> (b f) p c', ) + + x = self.Spatial_patch_to_embedding(x) + x += self.Spatial_pos_embed + x = self.pos_drop(x) + + for blk in self.Spatial_blocks: + x = blk(x) + + x = self.Spatial_norm(x) + x = rearrange(x, '(b f) w c -> b f (w c)', f=f) + return x + + def forward_features(self, x): + b = x.shape[0] + x += self.Temporal_pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.Temporal_norm(x) + ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame + # x = self.weighted_mean(x) #wsx don't change all frame to one + # x = x.view(b, 1, -1) + return x + + def forward(self, x, x_target): + ''' + # x input shape [170, 81, 17, 2] + x = x.permute(0, 3, 1, 2) #[170, 2, 81, 17] + b, _, _, p = x.shape #[170, 2, 81, 17] b:batch_size p:joint_num + ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data + ''' + + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + x = self.Spatial_forward_features(x) + x = self.forward_features(x) # (2n, 290,800) + + # action_class_head + BatchN, FrameN, FutureN = x.size() + x = x.view(BatchN, FrameN, FutureN, 1) + x_class = F.avg_pool2d(x, x.size()[2:]) + x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) + x_class = self.action_class_head(x_class) + x_class = x_class.view(x_class.size(0), -1) + + + #action_class = x.permute(0,2,1) #[170, 544, 1] + #action_class = self.action_class_head(action_class) + #action_class = torch.squeeze(action_class) + #x = self.head(x) + #x = x.view(b, 1, p, -1) + + x_target = x_target.permute(0, 4, 1, 2, 3).contiguous().view(-1, 3, 10, 25) + + return x_class, x_target[::2] # [170,1,17,3] + From e8706917dc83c410d53ba0ebd89eeb91c382fe1c Mon Sep 17 00:00:00 2001 From: 1suancaiyu Date: Sun, 27 Mar 2022 19:30:40 +0800 Subject: [PATCH 9/9] run asgcn on 3090 cuda11.1 --- ...yml => asgcn_3090_cuda11_1_environment.yml | 9 +- processor/io.py | 2 +- requirement.txt | 327 ------------------ .../__pycache__/__init__.cpython-36.pyc | Bin 376 -> 379 bytes torchlight/__pycache__/io.cpython-36.pyc | Bin 7034 -> 7037 bytes 5 files changed, 4 insertions(+), 334 deletions(-) rename environment.yml => asgcn_3090_cuda11_1_environment.yml (90%) delete mode 100644 requirement.txt diff --git a/environment.yml b/asgcn_3090_cuda11_1_environment.yml similarity index 90% rename from environment.yml rename to asgcn_3090_cuda11_1_environment.yml index 9d0ea0f..5c5699e 100644 --- a/environment.yml +++ b/asgcn_3090_cuda11_1_environment.yml @@ -2,9 +2,6 @@ name: asgcn channels: - pytorch - https://mirrors.ustc.edu.cn/anaconda/pkgs/main - - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ - - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ - - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ - defaults dependencies: - _libgcc_mutex=0.1=main @@ -79,10 +76,10 @@ dependencies: - imageio==2.9.0 - numpy==1.19.5 - opencv-python==4.5.1.48 - - pyyaml==5.4.1 + - pyyaml==6.0 - scikit-video==1.1.11 - scipy==1.5.4 - - torch==1.7.1 - - torchvision==0.9.0 + - torch==1.8.1+cu111 + - torchvision==0.9.1+cu111 - tqdm==4.60.0 - typing-extensions==3.7.4.3 diff --git a/processor/io.py b/processor/io.py index b8d11ca..b778d8c 100644 --- a/processor/io.py +++ b/processor/io.py @@ -31,7 +31,7 @@ def load_arg(self, argv=None): if p.config is not None: # load config file with open(p.config, 'r') as f: - default_arg = yaml.load(f) + default_arg = yaml.safe_load(f) # update parser from config file key = vars(p).keys() diff --git a/requirement.txt b/requirement.txt deleted file mode 100644 index d282296..0000000 --- a/requirement.txt +++ /dev/null @@ -1,327 +0,0 @@ -# This file may be used to create an environment using: -# $ conda create --name --file -# platform: linux-64 -_ipyw_jlab_nb_ext_conf=0.1.0=py38_0 -_libgcc_mutex=0.1=main -alabaster=0.7.12=py_0 -anaconda=2020.11=py38_0 -anaconda-client=1.7.2=py38_0 -anaconda-navigator=1.10.0=py38_0 -anaconda-project=0.8.4=py_0 -argh=0.26.2=py38_0 -argon2-cffi=20.1.0=py38h7b6447c_1 -asn1crypto=1.4.0=py_0 -astroid=2.4.2=py38_0 -astropy=4.0.2=py38h7b6447c_0 -async_generator=1.10=py_0 -atomicwrites=1.4.0=py_0 -attrs=20.3.0=pyhd3eb1b0_0 -autopep8=1.5.4=py_0 -babel=2.8.1=pyhd3eb1b0_0 -backcall=0.2.0=py_0 -backports=1.0=py_2 -backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 -backports.shutil_get_terminal_size=1.0.0=py38_2 -backports.tempfile=1.0=pyhd3eb1b0_1 -backports.weakref=1.0.post1=py_1 -beautifulsoup4=4.9.3=pyhb0f4dca_0 -bitarray=1.6.1=py38h27cfd23_0 -bkcharts=0.2=py38_0 -blas=1.0=mkl -bleach=3.2.1=py_0 -blosc=1.20.1=hd408876_0 -bokeh=2.2.3=py38_0 -boto=2.49.0=py38_0 -bottleneck=1.3.2=py38heb32a55_1 -brotlipy=0.7.0=py38h7b6447c_1000 -bzip2=1.0.8=h7b6447c_0 -ca-certificates=2020.10.14=0 -cairo=1.14.12=h8948797_3 -certifi=2020.6.20=pyhd3eb1b0_3 -cffi=1.14.3=py38he30daa8_0 -chardet=3.0.4=py38_1003 -click=7.1.2=py_0 -cloudpickle=1.6.0=py_0 -clyent=1.2.2=py38_1 -colorama=0.4.4=py_0 -conda=4.10.1=py38h06a4308_1 -conda-build=3.20.5=py38_1 -conda-env=2.6.0=1 -conda-package-handling=1.7.3=py38h27cfd23_1 -conda-verify=3.4.2=py_1 -contextlib2=0.6.0.post1=py_0 -cryptography=3.1.1=py38h1ba5d50_0 -curl=7.71.1=hbc83047_1 -cycler=0.10.0=py38_0 -cython=0.29.21=py38he6710b0_0 -cytoolz=0.11.0=py38h7b6447c_0 -dask=2.30.0=py_0 -dask-core=2.30.0=py_0 -dbus=1.13.18=hb2f20db_0 -decorator=4.4.2=py_0 -defusedxml=0.6.0=py_0 -diff-match-patch=20200713=py_0 -distributed=2.30.1=py38h06a4308_0 -docutils=0.16=py38_1 -entrypoints=0.3=py38_0 -et_xmlfile=1.0.1=py_1001 -expat=2.2.10=he6710b0_2 -fastcache=1.1.0=py38h7b6447c_0 -filelock=3.0.12=py_0 -flake8=3.8.4=py_0 -flask=1.1.2=py_0 -fontconfig=2.13.0=h9420a91_0 -freetype=2.10.4=h5ab3b9f_0 -fribidi=1.0.10=h7b6447c_0 -fsspec=0.8.3=py_0 -future=0.18.2=py38_1 -get_terminal_size=1.0.0=haa9412d_0 -gevent=20.9.0=py38h7b6447c_0 -glib=2.66.1=h92f7085_0 -glob2=0.7=py_0 -gmp=6.1.2=h6c8ec71_1 -gmpy2=2.0.8=py38hd5f6e3b_3 -graphite2=1.3.14=h23475e2_0 -greenlet=0.4.17=py38h7b6447c_0 -gst-plugins-base=1.14.0=hbbd80ab_1 -gstreamer=1.14.0=hb31296c_0 -h5py=2.10.0=py38h7918eee_0 -harfbuzz=2.4.0=hca77d97_1 -hdf5=1.10.4=hb1b8bf9_0 -heapdict=1.0.1=py_0 -html5lib=1.1=py_0 -icu=58.2=he6710b0_3 -idna=2.10=py_0 -imageio=2.9.0=py_0 -imagesize=1.2.0=py_0 -importlib-metadata=2.0.0=py_1 -importlib_metadata=2.0.0=1 -iniconfig=1.1.1=py_0 -intel-openmp=2020.2=254 -intervaltree=3.1.0=py_0 -ipykernel=5.3.4=py38h5ca1d4c_0 -ipython=7.19.0=py38hb070fc8_0 -ipython_genutils=0.2.0=py38_0 -ipywidgets=7.5.1=py_1 -isort=5.6.4=py_0 -itsdangerous=1.1.0=py_0 -jbig=2.1=hdba287a_0 -jdcal=1.4.1=py_0 -jedi=0.17.1=py38_0 -jeepney=0.5.0=pyhd3eb1b0_0 -jinja2=2.11.2=py_0 -joblib=0.17.0=py_0 -jpeg=9b=h024ee3a_2 -json5=0.9.5=py_0 -jsonschema=3.2.0=py_2 -jupyter=1.0.0=py38_7 -jupyter_client=6.1.7=py_0 -jupyter_console=6.2.0=py_0 -jupyter_core=4.6.3=py38_0 -jupyterlab=2.2.6=py_0 -jupyterlab_pygments=0.1.2=py_0 -jupyterlab_server=1.2.0=py_0 -keyring=21.4.0=py38_1 -kiwisolver=1.3.0=py38h2531618_0 -krb5=1.18.2=h173b8e3_0 -lazy-object-proxy=1.4.3=py38h7b6447c_0 -lcms2=2.11=h396b838_0 -ld_impl_linux-64=2.33.1=h53a641e_7 -libarchive=3.4.2=h62408e4_0 -libcurl=7.71.1=h20c2e04_1 -libedit=3.1.20191231=h14c3975_1 -libffi=3.3=he6710b0_2 -libgcc-ng=9.1.0=hdf63c60_0 -libgfortran-ng=7.3.0=hdf63c60_0 -liblief=0.10.1=he6710b0_0 -libllvm10=10.0.1=hbcb73fb_5 -libpng=1.6.37=hbc83047_0 -libsodium=1.0.18=h7b6447c_0 -libspatialindex=1.9.3=he6710b0_0 -libssh2=1.9.0=h1ba5d50_1 -libstdcxx-ng=9.1.0=hdf63c60_0 -libtiff=4.1.0=h2733197_1 -libtool=2.4.6=h7b6447c_1005 -libuuid=1.0.3=h1bed415_2 -libxcb=1.14=h7b6447c_0 -libxml2=2.9.10=hb55368b_3 -libxslt=1.1.34=hc22bd24_0 -llvmlite=0.34.0=py38h269e1b5_4 -locket=0.2.0=py38_1 -lxml=4.6.1=py38hefd8a0e_0 -lz4-c=1.9.2=heb0550a_3 -lzo=2.10=h7b6447c_2 -markupsafe=1.1.1=py38h7b6447c_0 -matplotlib=3.3.2=0 -matplotlib-base=3.3.2=py38h817c723_0 -mccabe=0.6.1=py38_1 -mistune=0.8.4=py38h7b6447c_1000 -mkl=2020.2=256 -mkl-service=2.3.0=py38he904b0f_0 -mkl_fft=1.2.0=py38h23d657b_0 -mkl_random=1.1.1=py38h0573a6f_0 -mock=4.0.2=py_0 -more-itertools=8.6.0=pyhd3eb1b0_0 -mpc=1.1.0=h10f8cd9_1 -mpfr=4.0.2=hb69a4c5_1 -mpmath=1.1.0=py38_0 -msgpack-python=1.0.0=py38hfd86e86_1 -multipledispatch=0.6.0=py38_0 -navigator-updater=0.2.1=py38_0 -nbclient=0.5.1=py_0 -nbconvert=6.0.7=py38_0 -nbformat=5.0.8=py_0 -ncurses=6.2=he6710b0_1 -nest-asyncio=1.4.2=pyhd3eb1b0_0 -networkx=2.5=py_0 -nltk=3.5=py_0 -nose=1.3.7=py38_2 -notebook=6.1.4=py38_0 -numba=0.51.2=py38h0573a6f_1 -numexpr=2.7.1=py38h423224d_0 -numpy=1.19.2=py38h54aff64_0 -numpy-base=1.19.2=py38hfa32c7d_0 -numpydoc=1.1.0=pyhd3eb1b0_1 -olefile=0.46=py_0 -openpyxl=3.0.5=py_0 -openssl=1.1.1h=h7b6447c_0 -packaging=20.4=py_0 -pandas=1.1.3=py38he6710b0_0 -pandoc=2.11=hb0f4dca_0 -pandocfilters=1.4.3=py38h06a4308_1 -pango=1.45.3=hd140c19_0 -parso=0.7.0=py_0 -partd=1.1.0=py_0 -patchelf=0.12=he6710b0_0 -path=15.0.0=py38_0 -path.py=12.5.0=0 -pathlib2=2.3.5=py38_0 -pathtools=0.1.2=py_1 -patsy=0.5.1=py38_0 -pcre=8.44=he6710b0_0 -pep8=1.7.1=py38_0 -pexpect=4.8.0=py38_0 -pickleshare=0.7.5=py38_1000 -pillow=8.0.1=py38he98fc37_0 -pip=21.1=pypi_0 -pixman=0.40.0=h7b6447c_0 -pkginfo=1.6.1=py38h06a4308_0 -pluggy=0.13.1=py38_0 -ply=3.11=py38_0 -prometheus_client=0.8.0=py_0 -prompt-toolkit=3.0.8=py_0 -prompt_toolkit=3.0.8=0 -psutil=5.7.2=py38h7b6447c_0 -ptyprocess=0.6.0=py38_0 -py=1.9.0=py_0 -py-lief=0.10.1=py38h403a769_0 -pycodestyle=2.6.0=py_0 -pycosat=0.6.3=py38h7b6447c_1 -pycparser=2.20=py_2 -pycurl=7.43.0.6=py38h1ba5d50_0 -pydocstyle=5.1.1=py_0 -pyflakes=2.2.0=py_0 -pygments=2.7.2=pyhd3eb1b0_0 -pylint=2.6.0=py38_0 -pyodbc=4.0.30=py38he6710b0_0 -pyopenssl=19.1.0=py_1 -pyparsing=2.4.7=py_0 -pyqt=5.9.2=py38h05f1152_4 -pyrsistent=0.17.3=py38h7b6447c_0 -pysocks=1.7.1=py38_0 -pytables=3.6.1=py38h9fd0a39_0 -pytest=6.1.1=py38_0 -python=3.8.5=h7579374_1 -python-dateutil=2.8.1=py_0 -python-jsonrpc-server=0.4.0=py_0 -python-language-server=0.35.1=py_0 -python-libarchive-c=2.9=py_0 -pytz=2020.1=py_0 -pywavelets=1.1.1=py38h7b6447c_2 -pyxdg=0.27=pyhd3eb1b0_0 -pyyaml=5.3.1=py38h7b6447c_1 -pyzmq=19.0.2=py38he6710b0_1 -qdarkstyle=2.8.1=py_0 -qt=5.9.7=h5867ecd_1 -qtawesome=1.0.1=py_0 -qtconsole=4.7.7=py_0 -qtpy=1.9.0=py_0 -readline=8.0=h7b6447c_0 -regex=2020.10.15=py38h7b6447c_0 -requests=2.24.0=py_0 -ripgrep=12.1.1=0 -rope=0.18.0=py_0 -rtree=0.9.4=py38_1 -ruamel_yaml=0.15.87=py38h7b6447c_1 -scikit-image=0.17.2=py38hdf5156a_0 -scikit-learn=0.23.2=py38h0573a6f_0 -scipy=1.5.2=py38h0b6359f_0 -seaborn=0.11.0=py_0 -secretstorage=3.1.2=py38_0 -send2trash=1.5.0=py38_0 -setuptools=50.3.1=py38h06a4308_1 -simplegeneric=0.8.1=py38_2 -singledispatch=3.4.0.3=py_1001 -sip=4.19.13=py38he6710b0_0 -six=1.15.0=py38h06a4308_0 -snowballstemmer=2.0.0=py_0 -sortedcollections=1.2.1=py_0 -sortedcontainers=2.2.2=py_0 -soupsieve=2.0.1=py_0 -sphinx=3.2.1=py_0 -sphinxcontrib=1.0=py38_1 -sphinxcontrib-applehelp=1.0.2=py_0 -sphinxcontrib-devhelp=1.0.2=py_0 -sphinxcontrib-htmlhelp=1.0.3=py_0 -sphinxcontrib-jsmath=1.0.1=py_0 -sphinxcontrib-qthelp=1.0.3=py_0 -sphinxcontrib-serializinghtml=1.1.4=py_0 -sphinxcontrib-websupport=1.2.4=py_0 -spyder=4.1.5=py38_0 -spyder-kernels=1.9.4=py38_0 -sqlalchemy=1.3.20=py38h7b6447c_0 -sqlite=3.33.0=h62c20be_0 -statsmodels=0.12.0=py38h7b6447c_0 -sympy=1.6.2=py38h06a4308_1 -tbb=2020.3=hfd86e86_0 -tblib=1.7.0=py_0 -terminado=0.9.1=py38_0 -testpath=0.4.4=py_0 -threadpoolctl=2.1.0=pyh5ca1d4c_0 -tifffile=2020.10.1=py38hdd07704_2 -tk=8.6.10=hbc83047_0 -toml=0.10.1=py_0 -toolz=0.11.1=py_0 -torch=1.8.1=pypi_0 -torchvision=0.9.1=pypi_0 -tornado=6.0.4=py38h7b6447c_1 -tqdm=4.50.2=py_0 -traitlets=5.0.5=py_0 -typing_extensions=3.7.4.3=py_0 -ujson=4.0.1=py38he6710b0_0 -unicodecsv=0.14.1=py38_0 -unixodbc=2.3.9=h7b6447c_0 -urllib3=1.25.11=py_0 -watchdog=0.10.3=py38_0 -wcwidth=0.2.5=py_0 -webencodings=0.5.1=py38_1 -werkzeug=1.0.1=py_0 -wheel=0.35.1=py_0 -widgetsnbextension=3.5.1=py38_0 -wrapt=1.11.2=py38h7b6447c_0 -wurlitzer=2.0.1=py38_0 -xlrd=1.2.0=py_0 -xlsxwriter=1.3.7=py_0 -xlwt=1.3.0=py38_0 -xmltodict=0.12.0=py_0 -xz=5.2.5=h7b6447c_0 -yaml=0.2.5=h7b6447c_0 -yapf=0.30.0=py_0 -zeromq=4.3.3=he6710b0_3 -zict=2.0.0=py_0 -zipp=3.4.0=pyhd3eb1b0_0 -zlib=1.2.11=h7b6447c_3 -zope=1.0=py38_1 -zope.event=4.5.0=py38_0 -zope.interface=5.1.2=py38h7b6447c_0 -zstd=1.4.5=h9ceee32_0 diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc index 9f75a6bb9658ea632019f7c35cb9c2b299d1129d..2136593a4fb500f7a39ec0ece126a23cbe21508d 100644 GIT binary patch delta 31 kcmeyt^qYypn3tD}r^+U2BZoF4r<#6waRmrZwq&#i0E(Ijy#N3J delta 28 jcmey(^n;1Rn3tD}DQIfKMh