diff --git a/LICENSE b/LICENSE index ec0a6fd..e272c92 100644 --- a/LICENSE +++ b/LICENSE @@ -1,23 +1,23 @@ -Copyright (c) 2019, Cooperative Medianet Innovation Center, Shanghai Jiao Tong University -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Copyright (c) 2019, Cooperative Medianet Innovation Center, Shanghai Jiao Tong University +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index 66256bf..2f9fec4 100644 --- a/README.md +++ b/README.md @@ -1,64 +1,73 @@ -This repository contains the implementation of: -Actional-Structural Graph Convolutional Networks for Skeleton-based Action Recognition. [Paper](https://arxiv.org/pdf/1904.12659.pdf) - -![image](https://github.com/limaosen0/AS-GCN/blob/master/img/pipeline.png) - -Abstract: Action recognition with skeleton data has recently attracted much attention in computer vision. Previous studies are mostly based on fixed skeleton graphs, only capturing local physical dependencies among joints, which may miss implicit joint correlations. To capture richer dependencies, we introduce an encoder-decoder structure, called A-link inference module, to capture action-specific latent dependencies, i.e. actional links, directly from actions. We also extend the existing skeleton graphs to represent higherorder dependencies, i.e. structural links. Combing the two types of links into a generalized skeleton graph, we further propose the actional-structural graph convolution network (AS-GCN), which stacks actional-structural graph convolution and temporal convolution as a basic building block, to learn both spatial and temporal features for action recognition. A future pose prediction head is added in parallel to the recognition head to help capture more detailed action patterns through self-supervision. We validate AS-GCN in action recognition using two skeleton data sets, NTU-RGB+D and Kinetics. The proposed AS-GCN achieves consistently large improvement compared to the state-of-the-art methods. As a side product, AS-GCN also shows promising results for future pose prediction. - -In this repo, we show the example of model on NTU-RGB+D dataset. - -# Experiment Requirement -* Python 3.6 -* Pytorch 0.4.1 -* pyyaml -* argparse -* numpy - -# Environments -We use the similar input/output interface and system configuration like ST-GCN, where the torchlight module should be set up. - -Run -``` -cd torchlight, python setup.py, cd .. -``` - - -# Data Preparing -For NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp). And put the dataset in the file path: -``` -'./data/NTU-RGB+D/nturgb+d_skeletons/' -``` -Then, run the preprocessing program to generate the input data, which is very important. -``` -python ./data_gen/ntu_gen_preprocess.py -``` - -# Training and Testing -With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipleline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D, -``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml -Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml -``` - -For Cross-View, -``` -PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml -TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml -Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml -``` - -# Acknowledgement -Thanks for the framework provided by 'yysijie/st-gcn', which is source code of the published work [ST-GCN](https://aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17135) in AAAI-2018. The github repo is [ST-GCN code](https://github.com/yysijie/st-gcn). We borrow the framework and interface from the code. - -# Citation -If you use this code, please cite our paper: -``` -@InProceedings{Li_2019_CVPR, -author = {Li, Maosen and Chen, Siheng and Chen, Xu and Zhang, Ya and Wang, Yanfeng and Tian, Qi}, -title = {Actional-Structural Graph Convolutional Networks for Skeleton-Based Action Recognition}, -booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, -month = {June}, -year = {2019} -} -``` +This repository contains the implementation of: +Actional-Structural Graph Convolutional Networks for Skeleton-based Action Recognition. [Paper](https://arxiv.org/pdf/1904.12659.pdf) + +![image](https://github.com/limaosen0/AS-GCN/blob/master/img/pipeline.png) + +Abstract: Action recognition with skeleton data has recently attracted much attention in computer vision. Previous studies are mostly based on fixed skeleton graphs, only capturing local physical dependencies among joints, which may miss implicit joint correlations. To capture richer dependencies, we introduce an encoder-decoder structure, called A-link inference module, to capture action-specific latent dependencies, i.e. actional links, directly from actions. We also extend the existing skeleton graphs to represent higherorder dependencies, i.e. structural links. Combing the two types of links into a generalized skeleton graph, we further propose the actional-structural graph convolution network (AS-GCN), which stacks actional-structural graph convolution and temporal convolution as a basic building block, to learn both spatial and temporal features for action recognition. A future pose prediction head is added in parallel to the recognition head to help capture more detailed action patterns through self-supervision. We validate AS-GCN in action recognition using two skeleton data sets, NTU-RGB+D and Kinetics. The proposed AS-GCN achieves consistently large improvement compared to the state-of-the-art methods. As a side product, AS-GCN also shows promising results for future pose prediction. + +In this repo, we show the example of model on NTU-RGB+D dataset. + +# Experiment Requirement +* Python 3.6 +* Pytorch 0.4.1 +* pyyaml +* argparse +* numpy +* torch 1.7.1 + +# Environments +We use the similar input/output interface and system configuration like ST-GCN, where the torchlight module should be set up. +``` +cd torchlight +cp torchlight/torchlight/_init__.py gpu.py io.py ../ +``` +change all "from torchlight import ..." to +"from torchlight.io import ..." + +Run +``` +cd torchlight, python setup.py install, cd .. +``` + + +# Data Preparing +For NTU-RGB+D dataset, you can download it from [NTU-RGB+D](http://rose1.ntu.edu.sg/datasets/actionrecognition.asp). And put the dataset in the file path: +``` +'./data/NTU-RGB+D/nturgb+d_skeletons/' +``` +Then, run the preprocessing program to generate the input data, which is very important. +``` +cd data_gen +python ntu_gen_preprocess.py +``` + +# Training and Testing +With this repo, you can pretrain AIM and save the module at first; then run the code to train the main pipleline of AS-GCN. For the recommended benchmark of Cross-Subject in NTU-RGB+D, +``` +PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 +TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xsub/train.yaml --device 0 --batch_size 4 +# only can use one gpu otherwise got the error "Caught RuntimeError in replica 0 on device 0"" +Test: python main.py recognition -c config/as_gcn/ntu-xsub/test.yaml +``` + +For Cross-View, +``` +PretrainAIM: python main.py recognition -c config/as_gcn/ntu-xview/train_aim.yaml +TrainMainPipeline: python main.py recognition -c config/as_gcn/ntu-xview/train.yaml +Test: python main.py recognition -c config/as_gcn/ntu-xview/test.yaml +``` + +# Acknowledgement +Thanks for the framework provided by 'yysijie/st-gcn', which is source code of the published work [ST-GCN](https://aaai.org/ocs/index.php/AAAI/AAAI18/paper/view/17135) in AAAI-2018. The github repo is [ST-GCN code](https://github.com/yysijie/st-gcn). We borrow the framework and interface from the code. + +# Citation +If you use this code, please cite our paper: +``` +@InProceedings{Li_2019_CVPR, +author = {Li, Maosen and Chen, Siheng and Chen, Xu and Zhang, Ya and Wang, Yanfeng and Tian, Qi}, +title = {Actional-Structural Graph Convolutional Networks for Skeleton-Based Action Recognition}, +booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, +month = {June}, +year = {2019} +} +``` diff --git a/config/as_gcn/ntu-xsub/__init__.py b/config/as_gcn/ntu-xsub/__init__.py index 8b13789..d3f5a12 100644 --- a/config/as_gcn/ntu-xsub/__init__.py +++ b/config/as_gcn/ntu-xsub/__init__.py @@ -1 +1 @@ - + diff --git a/config/as_gcn/ntu-xsub/test.yaml b/config/as_gcn/ntu-xsub/test.yaml index 30fbde9..f81fd70 100644 --- a/config/as_gcn/ntu-xsub/test.yaml +++ b/config/as_gcn/ntu-xsub/test.yaml @@ -1,48 +1,48 @@ -work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN -weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt -weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 - -phase: test +work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN +weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt +weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 + +phase: test diff --git a/config/as_gcn/ntu-xsub/train.yaml b/config/as_gcn/ntu-xsub/train.yaml index 2d43f64..939f3a3 100644 --- a/config/as_gcn/ntu-xsub/train.yaml +++ b/config/as_gcn/ntu-xsub/train.yaml @@ -1,54 +1,54 @@ -work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN - -weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt -weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [50, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 10 -num_epoch: 100 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN + +weights1: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt +weights2: ./work_dir/recognition/ntu-xsub/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.0076 +base_lr2: 0.0005 +step: [50, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 10 +num_epoch: 100 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/config/as_gcn/ntu-xsub/train_aim.yaml b/config/as_gcn/ntu-xsub/train_aim.yaml index c4c54b7..d74e1cf 100644 --- a/config/as_gcn/ntu-xsub/train_aim.yaml +++ b/config/as_gcn/ntu-xsub/train_aim.yaml @@ -1,51 +1,51 @@ -work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xsub/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [50, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 0 -num_epoch: 10 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xsub/AS_GCN + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xsub/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xsub/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xsub/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.1 +base_lr2: 0.0005 +step: [50, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 0 +num_epoch: 10 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/config/as_gcn/ntu-xview/__init__.py b/config/as_gcn/ntu-xview/__init__.py index 8b13789..d3f5a12 100644 --- a/config/as_gcn/ntu-xview/__init__.py +++ b/config/as_gcn/ntu-xview/__init__.py @@ -1 +1 @@ - + diff --git a/config/as_gcn/ntu-xview/test.yaml b/config/as_gcn/ntu-xview/test.yaml index 50a1400..1496724 100644 --- a/config/as_gcn/ntu-xview/test.yaml +++ b/config/as_gcn/ntu-xview/test.yaml @@ -1,48 +1,48 @@ -work_dir: ./work_dir/recognition/ntu-xview/AS_GCN -weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt -weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 - -phase: test +work_dir: ./work_dir/recognition/ntu-xview/AS_GCN +weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model1.pt +weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch99_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 + +phase: test diff --git a/config/as_gcn/ntu-xview/train.yaml b/config/as_gcn/ntu-xview/train.yaml index da040ed..85344f7 100644 --- a/config/as_gcn/ntu-xview/train.yaml +++ b/config/as_gcn/ntu-xview/train.yaml @@ -1,54 +1,54 @@ -work_dir: ./work_dir/recognition/ntu-xview/AS_GCN - -weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt -weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [50, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 10 -num_epoch: 100 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xview/AS_GCN + +weights1: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model1.pt +weights2: ./work_dir/recognition/ntu-xview/AS_GCN/max_hop_4/lamda_05/epoch9_model2.pt + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.1 +base_lr2: 0.0005 +step: [50, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 10 +num_epoch: 100 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/config/as_gcn/ntu-xview/train_aim.yaml b/config/as_gcn/ntu-xview/train_aim.yaml index ac9e2ff..ec8aaf7 100644 --- a/config/as_gcn/ntu-xview/train_aim.yaml +++ b/config/as_gcn/ntu-xview/train_aim.yaml @@ -1,51 +1,51 @@ -work_dir: ./work_dir/recognition/ntu-xview/AS_GCN - -feeder: feeder.feeder.Feeder -train_feeder_args: - data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/train_label.pkl - random_move: True - repeat_pad: True - down_sample: True -test_feeder_args: - data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy - label_path: ./data/nturgb_d/xview/val_label.pkl - random_move: False - repeat_pad: True - down_sample: True - -model1: net.as_gcn.Model -model1_args: - in_channels: 3 - num_class: 60 - dropout: 0.5 - edge_importance_weighting: True - graph_args: - layout: 'ntu-rgb+d' - strategy: 'spatial' - max_hop: 4 - -model2: net.utils.adj_learn.AdjacencyLearn -model2_args: - n_in_enc: 150 - n_hid_enc: 128 - edge_types: 3 - n_in_dec: 3 - n_hid_dec: 128 - node_num: 25 - -weight_decay: 0.0001 -base_lr1: 0.1 -base_lr2: 0.0005 -step: [40, 70, 90] - -device: [0,1,2,3] -batch_size: 32 -test_batch_size: 32 -start_epoch: 0 -num_epoch: 10 -num_worker: 4 - -max_hop_dir: max_hop_4 -lamda_act_dir: lamda_05 -lamda_act: 0.5 +work_dir: ./work_dir/recognition/ntu-xview/AS_GCN + +feeder: feeder.feeder.Feeder +train_feeder_args: + data_path: ./data/nturgb_d/xview/train_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/train_label.pkl + random_move: True + repeat_pad: True + down_sample: True +test_feeder_args: + data_path: ./data/nturgb_d/xview/val_data_joint_pad.npy + label_path: ./data/nturgb_d/xview/val_label.pkl + random_move: False + repeat_pad: True + down_sample: True + +model1: net.as_gcn.Model +model1_args: + in_channels: 3 + num_class: 60 + dropout: 0.5 + edge_importance_weighting: True + graph_args: + layout: 'ntu-rgb+d' + strategy: 'spatial' + max_hop: 4 + +model2: net.utils.adj_learn.AdjacencyLearn +model2_args: + n_in_enc: 150 + n_hid_enc: 128 + edge_types: 3 + n_in_dec: 3 + n_hid_dec: 128 + node_num: 25 + +weight_decay: 0.0001 +base_lr1: 0.1 +base_lr2: 0.0005 +step: [40, 70, 90] + +device: [0,1,2,3] +batch_size: 32 +test_batch_size: 32 +start_epoch: 0 +num_epoch: 10 +num_worker: 4 + +max_hop_dir: max_hop_4 +lamda_act_dir: lamda_05 +lamda_act: 0.5 diff --git a/data/NTU-RGB+D/samples_with_missing_skeletons.txt b/data/NTU-RGB+D/samples_with_missing_skeletons.txt deleted file mode 100644 index 5ad472e..0000000 --- a/data/NTU-RGB+D/samples_with_missing_skeletons.txt +++ /dev/null @@ -1,302 +0,0 @@ -S001C002P005R002A008 -S001C002P006R001A008 -S001C003P002R001A055 -S001C003P002R002A012 -S001C003P005R002A004 -S001C003P005R002A005 -S001C003P005R002A006 -S001C003P006R002A008 -S002C002P011R002A030 -S002C003P008R001A020 -S002C003P010R002A010 -S002C003P011R002A007 -S002C003P011R002A011 -S002C003P014R002A007 -S003C001P019R001A055 -S003C002P002R002A055 -S003C002P018R002A055 -S003C003P002R001A055 -S003C003P016R001A055 -S003C003P018R002A024 -S004C002P003R001A013 -S004C002P008R001A009 -S004C002P020R001A003 -S004C002P020R001A004 -S004C002P020R001A012 -S004C002P020R001A020 -S004C002P020R001A021 -S004C002P020R001A036 -S005C002P004R001A001 -S005C002P004R001A003 -S005C002P010R001A016 -S005C002P010R001A017 -S005C002P010R001A048 -S005C002P010R001A049 -S005C002P016R001A009 -S005C002P016R001A010 -S005C002P018R001A003 -S005C002P018R001A028 -S005C002P018R001A029 -S005C003P016R002A009 -S005C003P018R002A013 -S005C003P021R002A057 -S006C001P001R002A055 -S006C002P007R001A005 -S006C002P007R001A006 -S006C002P016R001A043 -S006C002P016R001A051 -S006C002P016R001A052 -S006C002P022R001A012 -S006C002P023R001A020 -S006C002P023R001A021 -S006C002P023R001A022 -S006C002P023R001A023 -S006C002P024R001A018 -S006C002P024R001A019 -S006C003P001R002A013 -S006C003P007R002A009 -S006C003P007R002A010 -S006C003P007R002A025 -S006C003P016R001A060 -S006C003P017R001A055 -S006C003P017R002A013 -S006C003P017R002A014 -S006C003P017R002A015 -S006C003P022R002A013 -S007C001P018R002A050 -S007C001P025R002A051 -S007C001P028R001A050 -S007C001P028R001A051 -S007C001P028R001A052 -S007C002P008R002A008 -S007C002P015R002A055 -S007C002P026R001A008 -S007C002P026R001A009 -S007C002P026R001A010 -S007C002P026R001A011 -S007C002P026R001A012 -S007C002P026R001A050 -S007C002P027R001A011 -S007C002P027R001A013 -S007C002P028R002A055 -S007C003P007R001A002 -S007C003P007R001A004 -S007C003P019R001A060 -S007C003P027R002A001 -S007C003P027R002A002 -S007C003P027R002A003 -S007C003P027R002A004 -S007C003P027R002A005 -S007C003P027R002A006 -S007C003P027R002A007 -S007C003P027R002A008 -S007C003P027R002A009 -S007C003P027R002A010 -S007C003P027R002A011 -S007C003P027R002A012 -S007C003P027R002A013 -S008C002P001R001A009 -S008C002P001R001A010 -S008C002P001R001A014 -S008C002P001R001A015 -S008C002P001R001A016 -S008C002P001R001A018 -S008C002P001R001A019 -S008C002P008R002A059 -S008C002P025R001A060 -S008C002P029R001A004 -S008C002P031R001A005 -S008C002P031R001A006 -S008C002P032R001A018 -S008C002P034R001A018 -S008C002P034R001A019 -S008C002P035R001A059 -S008C002P035R002A002 -S008C002P035R002A005 -S008C003P007R001A009 -S008C003P007R001A016 -S008C003P007R001A017 -S008C003P007R001A018 -S008C003P007R001A019 -S008C003P007R001A020 -S008C003P007R001A021 -S008C003P007R001A022 -S008C003P007R001A023 -S008C003P007R001A025 -S008C003P007R001A026 -S008C003P007R001A028 -S008C003P007R001A029 -S008C003P007R002A003 -S008C003P008R002A050 -S008C003P025R002A002 -S008C003P025R002A011 -S008C003P025R002A012 -S008C003P025R002A016 -S008C003P025R002A020 -S008C003P025R002A022 -S008C003P025R002A023 -S008C003P025R002A030 -S008C003P025R002A031 -S008C003P025R002A032 -S008C003P025R002A033 -S008C003P025R002A049 -S008C003P025R002A060 -S008C003P031R001A001 -S008C003P031R002A004 -S008C003P031R002A014 -S008C003P031R002A015 -S008C003P031R002A016 -S008C003P031R002A017 -S008C003P032R002A013 -S008C003P033R002A001 -S008C003P033R002A011 -S008C003P033R002A012 -S008C003P034R002A001 -S008C003P034R002A012 -S008C003P034R002A022 -S008C003P034R002A023 -S008C003P034R002A024 -S008C003P034R002A044 -S008C003P034R002A045 -S008C003P035R002A016 -S008C003P035R002A017 -S008C003P035R002A018 -S008C003P035R002A019 -S008C003P035R002A020 -S008C003P035R002A021 -S009C002P007R001A001 -S009C002P007R001A003 -S009C002P007R001A014 -S009C002P008R001A014 -S009C002P015R002A050 -S009C002P016R001A002 -S009C002P017R001A028 -S009C002P017R001A029 -S009C003P017R002A030 -S009C003P025R002A054 -S010C001P007R002A020 -S010C002P016R002A055 -S010C002P017R001A005 -S010C002P017R001A018 -S010C002P017R001A019 -S010C002P019R001A001 -S010C002P025R001A012 -S010C003P007R002A043 -S010C003P008R002A003 -S010C003P016R001A055 -S010C003P017R002A055 -S011C001P002R001A008 -S011C001P018R002A050 -S011C002P008R002A059 -S011C002P016R002A055 -S011C002P017R001A020 -S011C002P017R001A021 -S011C002P018R002A055 -S011C002P027R001A009 -S011C002P027R001A010 -S011C002P027R001A037 -S011C003P001R001A055 -S011C003P002R001A055 -S011C003P008R002A012 -S011C003P015R001A055 -S011C003P016R001A055 -S011C003P019R001A055 -S011C003P025R001A055 -S011C003P028R002A055 -S012C001P019R001A060 -S012C001P019R002A060 -S012C002P015R001A055 -S012C002P017R002A012 -S012C002P025R001A060 -S012C003P008R001A057 -S012C003P015R001A055 -S012C003P015R002A055 -S012C003P016R001A055 -S012C003P017R002A055 -S012C003P018R001A055 -S012C003P018R001A057 -S012C003P019R002A011 -S012C003P019R002A012 -S012C003P025R001A055 -S012C003P027R001A055 -S012C003P027R002A009 -S012C003P028R001A035 -S012C003P028R002A055 -S013C001P015R001A054 -S013C001P017R002A054 -S013C001P018R001A016 -S013C001P028R001A040 -S013C002P015R001A054 -S013C002P017R002A054 -S013C002P028R001A040 -S013C003P008R002A059 -S013C003P015R001A054 -S013C003P017R002A054 -S013C003P025R002A022 -S013C003P027R001A055 -S013C003P028R001A040 -S014C001P027R002A040 -S014C002P015R001A003 -S014C002P019R001A029 -S014C002P025R002A059 -S014C002P027R002A040 -S014C002P039R001A050 -S014C003P007R002A059 -S014C003P015R002A055 -S014C003P019R002A055 -S014C003P025R001A048 -S014C003P027R002A040 -S015C001P008R002A040 -S015C001P016R001A055 -S015C001P017R001A055 -S015C001P017R002A055 -S015C002P007R001A059 -S015C002P008R001A003 -S015C002P008R001A004 -S015C002P008R002A040 -S015C002P015R001A002 -S015C002P016R001A001 -S015C002P016R002A055 -S015C003P008R002A007 -S015C003P008R002A011 -S015C003P008R002A012 -S015C003P008R002A028 -S015C003P008R002A040 -S015C003P025R002A012 -S015C003P025R002A017 -S015C003P025R002A020 -S015C003P025R002A021 -S015C003P025R002A030 -S015C003P025R002A033 -S015C003P025R002A034 -S015C003P025R002A036 -S015C003P025R002A037 -S015C003P025R002A044 -S016C001P019R002A040 -S016C001P025R001A011 -S016C001P025R001A012 -S016C001P025R001A060 -S016C001P040R001A055 -S016C001P040R002A055 -S016C002P008R001A011 -S016C002P019R002A040 -S016C002P025R002A012 -S016C003P008R001A011 -S016C003P008R002A002 -S016C003P008R002A003 -S016C003P008R002A004 -S016C003P008R002A006 -S016C003P008R002A009 -S016C003P019R002A040 -S016C003P039R002A016 -S017C001P016R002A031 -S017C002P007R001A013 -S017C002P008R001A009 -S017C002P015R001A042 -S017C002P016R002A031 -S017C002P016R002A055 -S017C003P007R002A013 -S017C003P008R001A059 -S017C003P016R002A031 -S017C003P017R001A055 -S017C003P020R001A059 diff --git a/data/readme.md b/data/readme.md deleted file mode 100644 index 777d39a..0000000 --- a/data/readme.md +++ /dev/null @@ -1 +0,0 @@ -The filepath of data (NTU-RGB+D) diff --git a/data_gen/__init__.py b/data_gen/__init__.py index 8b13789..d3f5a12 100644 --- a/data_gen/__init__.py +++ b/data_gen/__init__.py @@ -1 +1 @@ - + diff --git a/data_gen/__pycache__/__init__.cpython-36.pyc b/data_gen/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..600c7f3 Binary files /dev/null and b/data_gen/__pycache__/__init__.cpython-36.pyc differ diff --git a/data_gen/__pycache__/preprocess.cpython-36.pyc b/data_gen/__pycache__/preprocess.cpython-36.pyc new file mode 100644 index 0000000..c42e997 Binary files /dev/null and b/data_gen/__pycache__/preprocess.cpython-36.pyc differ diff --git a/data_gen/__pycache__/rotation.cpython-36.pyc b/data_gen/__pycache__/rotation.cpython-36.pyc new file mode 100644 index 0000000..722593f Binary files /dev/null and b/data_gen/__pycache__/rotation.cpython-36.pyc differ diff --git a/data_gen/gpu.py b/data_gen/gpu.py new file mode 100644 index 0000000..e086d4c --- /dev/null +++ b/data_gen/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/data_gen/io.py b/data_gen/io.py new file mode 100644 index 0000000..5b43720 --- /dev/null +++ b/data_gen/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/data_gen/ntu_gen_preprocess.py b/data_gen/ntu_gen_preprocess.py index 6323b30..99b27a4 100644 --- a/data_gen/ntu_gen_preprocess.py +++ b/data_gen/ntu_gen_preprocess.py @@ -1,143 +1,144 @@ -import argparse -import pickle -from tqdm import tqdm -import sys - -sys.path.extend(['../']) -from data_gen.preprocess import pre_normalization - -training_subjects = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38] -training_cameras = [2, 3] -max_body_true = 2 -max_body_kinect = 4 -num_joint = 25 -max_frame = 300 - -import numpy as np -import os - - -def read_skeleton_filter(file): - with open(file, 'r') as f: - skeleton_sequence = {} - skeleton_sequence['numFrame'] = int(f.readline()) - skeleton_sequence['frameInfo'] = [] - for t in range(skeleton_sequence['numFrame']): - frame_info = {} - frame_info['numBody'] = int(f.readline()) - frame_info['bodyInfo'] = [] - - for m in range(frame_info['numBody']): - body_info = {} - body_info_key = ['bodyID', 'clipedEdges', 'handLeftConfidence', - 'handLeftState', 'handRightConfidence', 'handRightState', - 'isResticted', 'leanX', 'leanY', 'trackingState'] - body_info = {k: float(v) for k, v in zip(body_info_key, f.readline().split())} - body_info['numJoint'] = int(f.readline()) - body_info['jointInfo'] = [] - for v in range(body_info['numJoint']): - joint_info_key = ['x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY', - 'orientationW', 'orientationX', 'orientationY', - 'orientationZ', 'trackingState'] - joint_info = {k: float(v) for k, v in zip(joint_info_key, f.readline().split())} - body_info['jointInfo'].append(joint_info) - frame_info['bodyInfo'].append(body_info) - skeleton_sequence['frameInfo'].append(frame_info) - - return skeleton_sequence - - -def get_nonzero_std(s): - index = s.sum(-1).sum(-1) != 0 - s = s[index] - if len(s) != 0: - s = s[:, :, 0].std() + s[:, :, 1].std() + s[:, :, 2].std() - else: - s = 0 - return s - - -def read_xyz(file, max_body=4, num_joint=25): - seq_info = read_skeleton_filter(file) - data = np.zeros((max_body, seq_info['numFrame'], num_joint, 3)) - for n, f in enumerate(seq_info['frameInfo']): - for m, b in enumerate(f['bodyInfo']): - for j, v in enumerate(b['jointInfo']): - if m < max_body and j < num_joint: - data[m, n, j, :] = [v['x'], v['y'], v['z']] - else: - pass - - energy = np.array([get_nonzero_std(x) for x in data]) - index = energy.argsort()[::-1][0:max_body_true] - data = data[index] - - data = data.transpose(3, 1, 2, 0) - return data - - -def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set_name='val'): - if ignored_sample_path != None: - with open(ignored_sample_path, 'r') as f: - ignored_samples = [line.strip() + '.skeleton' for line in f.readlines()] - else: - ignored_samples = [] - sample_name = [] - sample_label = [] - for filename in os.listdir(data_path): - if filename in ignored_samples: - continue - action_class = int(filename[filename.find('A') + 1:filename.find('A') + 4]) - subject_id = int(filename[filename.find('P') + 1:filename.find('P') + 4]) - camera_id = int(filename[filename.find('C') + 1:filename.find('C') + 4]) - - if benchmark == 'xview': - istraining = (camera_id in training_cameras) - elif benchmark == 'xsub': - istraining = (subject_id in training_subjects) - else: - raise ValueError() - - if set_name == 'train': - issample = istraining - elif set_name == 'val': - issample = not (istraining) - else: - raise ValueError() - - if issample: - sample_name.append(filename) - sample_label.append(action_class - 1) - print(len(sample_label)) - - with open('{}/{}_label.pkl'.format(out_path, set_name), 'wb') as f: - pickle.dump((sample_name, list(sample_label)), f) - - fp = np.zeros((len(sample_label), 3, max_frame, num_joint, max_body_true), dtype=np.float32) - - for i, s in enumerate(tqdm(sample_name)): - print(s) - data = read_xyz(os.path.join(data_path, s), max_body=max_body_kinect, num_joint=num_joint) - fp[i, :, 0:data.shape[1], :, :] = data - - fp = pre_normalization(fp) - np.save('{}/{}_data_joint_pad.npy'.format(out_path, set_name), fp) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') - parser.add_argument('--data_path', default='../data/NTU-RGB+D/nturgb+d_skeletons/') - parser.add_argument('--ignored_sample_path', default='../data/NTU-RGB+D/samples_with_missing_skeletons.txt') - parser.add_argument('--out_folder', default='../data/nturgb_d/') - - benchmark = ['xsub', 'xview'] - set_name = ['train', 'val'] - arg = parser.parse_args() - - for b in benchmark: - for sn in set_name: - out_path = os.path.join(arg.out_folder, b) - if not os.path.exists(out_path): - os.makedirs(out_path) - print(b, sn) - gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) \ No newline at end of file +import argparse +import pickle +from tqdm import tqdm +import sys + +sys.path.extend(['../']) +from data_gen.preprocess import pre_normalization + +training_subjects = [1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38] +training_cameras = [2, 3] +max_body_true = 2 +max_body_kinect = 4 +num_joint = 25 +max_frame = 300 + +import numpy as np +import os + + +def read_skeleton_filter(file): + with open(file, 'r') as f: + skeleton_sequence = {} + skeleton_sequence['numFrame'] = int(f.readline()) + skeleton_sequence['frameInfo'] = [] + for t in range(skeleton_sequence['numFrame']): + frame_info = {} + frame_info['numBody'] = int(f.readline()) + frame_info['bodyInfo'] = [] + + for m in range(frame_info['numBody']): + body_info = {} + body_info_key = ['bodyID', 'clipedEdges', 'handLeftConfidence', + 'handLeftState', 'handRightConfidence', 'handRightState', + 'isResticted', 'leanX', 'leanY', 'trackingState'] + body_info = {k: float(v) for k, v in zip(body_info_key, f.readline().split())} + body_info['numJoint'] = int(f.readline()) + body_info['jointInfo'] = [] + for v in range(body_info['numJoint']): + joint_info_key = ['x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY', + 'orientationW', 'orientationX', 'orientationY', + 'orientationZ', 'trackingState'] + joint_info = {k: float(v) for k, v in zip(joint_info_key, f.readline().split())} + body_info['jointInfo'].append(joint_info) + frame_info['bodyInfo'].append(body_info) + skeleton_sequence['frameInfo'].append(frame_info) + + return skeleton_sequence + + +def get_nonzero_std(s): + index = s.sum(-1).sum(-1) != 0 + s = s[index] + if len(s) != 0: + s = s[:, :, 0].std() + s[:, :, 1].std() + s[:, :, 2].std() + else: + s = 0 + return s + + +def read_xyz(file, max_body=4, num_joint=25): + seq_info = read_skeleton_filter(file) + data = np.zeros((max_body, seq_info['numFrame'], num_joint, 3)) + for n, f in enumerate(seq_info['frameInfo']): + for m, b in enumerate(f['bodyInfo']): + for j, v in enumerate(b['jointInfo']): + if m < max_body and j < num_joint: + data[m, n, j, :] = [v['x'], v['y'], v['z']] + else: + pass + + energy = np.array([get_nonzero_std(x) for x in data]) + index = energy.argsort()[::-1][0:max_body_true] + data = data[index] + + data = data.transpose(3, 1, 2, 0) + return data + + +def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xsub', set_name='val'): + if ignored_sample_path != None: + with open(ignored_sample_path, 'r') as f: + ignored_samples = [line.strip() + '.skeleton' for line in f.readlines()] + else: + ignored_samples = [] + sample_name = [] + sample_label = [] + for filename in os.listdir(data_path): + if filename in ignored_samples: + continue + action_class = int(filename[filename.find('A') + 1:filename.find('A') + 4]) + subject_id = int(filename[filename.find('P') + 1:filename.find('P') + 4]) + camera_id = int(filename[filename.find('C') + 1:filename.find('C') + 4]) + + if benchmark == 'xview': + istraining = (camera_id in training_cameras) + elif benchmark == 'xsub': + istraining = (subject_id in training_subjects) + else: + raise ValueError() + + if set_name == 'train': + issample = istraining + elif set_name == 'val': + issample = not (istraining) + else: + raise ValueError() + + if issample: + sample_name.append(filename) + sample_label.append(action_class - 1) + print(len(sample_label)) + + with open('{}/{}_label.pkl'.format(out_path, set_name), 'wb') as f: + pickle.dump((sample_name, list(sample_label)), f) + + fp = np.zeros((len(sample_label), 3, max_frame, num_joint, max_body_true), dtype=np.float32) + + for i, s in enumerate(tqdm(sample_name)): + print(s) + data = read_xyz(os.path.join(data_path, s), max_body=max_body_kinect, num_joint=num_joint) + fp[i, :, 0:data.shape[1], :, :] = data + + fp = pre_normalization(fp) + np.save('{}/{}_data_joint_pad.npy'.format(out_path, set_name), fp) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') + parser.add_argument('--data_path', default='../data/NTU-RGB+D/nturgb+d_skeletons/') + parser.add_argument('--ignored_sample_path', default='../data/NTU-RGB+D/samples_with_missing_skeletons.txt') + parser.add_argument('--out_folder', default='../data/nturgb_d/') + + benchmark = ['xsub', 'xview'] + set_name = ['train', 'val'] + arg = parser.parse_args() + + for b in benchmark: + for sn in set_name: + out_path = os.path.join(arg.out_folder, b) + if not os.path.exists(out_path): + os.makedirs(out_path) + print(b, sn) + #gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, part=sn) + gendata(arg.data_path, out_path, arg.ignored_sample_path, benchmark=b, set_name=sn) diff --git a/data_gen/preprocess.py b/data_gen/preprocess.py index 86810aa..4e6ed9f 100644 --- a/data_gen/preprocess.py +++ b/data_gen/preprocess.py @@ -1,68 +1,68 @@ -import sys - -sys.path.extend(['../']) -from data_gen.rotation import * -from tqdm import tqdm - - -def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]): - N, C, T, V, M = data.shape - s = np.transpose(data, [0, 4, 2, 3, 1]) - - print('sub the center joint') - for i_s, skeleton in enumerate(tqdm(s)): - if skeleton.sum() == 0: - continue - main_body_center = skeleton[0][:, 1:2, :].copy() - for i_p, person in enumerate(skeleton): - if person.sum() == 0: - continue - mask = (person.sum(-1) != 0).reshape(T, V, 1) - s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask - - - print('parallel the torso bone') - for i_s, skeleton in enumerate(tqdm(s)): - if skeleton.sum() == 0: - continue - joint_bottom = skeleton[0, 0, zaxis[0]] - joint_top = skeleton[0, 0, zaxis[1]] - axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) - angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) - matrix_z = rotation_matrix(axis, angle) - for i_p, person in enumerate(skeleton): - if person.sum() == 0: - continue - for i_f, frame in enumerate(person): - if frame.sum() == 0: - continue - for i_j, joint in enumerate(frame): - s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint) - - - print('parallel the shoulder bone') - for i_s, skeleton in enumerate(tqdm(s)): - if skeleton.sum() == 0: - continue - joint_rshoulder = skeleton[0, 0, xaxis[0]] - joint_lshoulder = skeleton[0, 0, xaxis[1]] - axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) - angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0]) - matrix_x = rotation_matrix(axis, angle) - for i_p, person in enumerate(skeleton): - if person.sum() == 0: - continue - for i_f, frame in enumerate(person): - if frame.sum() == 0: - continue - for i_j, joint in enumerate(frame): - s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint) - - data = np.transpose(s, [0, 4, 2, 3, 1]) - return data - - -if __name__ == '__main__': - data = np.load('../data/NTU-RGB+D/xsub/train_data.npy') - pre_normalization(data) +import sys + +sys.path.extend(['../']) +from data_gen.rotation import * +from tqdm import tqdm + + +def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]): + N, C, T, V, M = data.shape + s = np.transpose(data, [0, 4, 2, 3, 1]) + + print('sub the center joint') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + main_body_center = skeleton[0][:, 1:2, :].copy() + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + mask = (person.sum(-1) != 0).reshape(T, V, 1) + s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask + + + print('parallel the torso bone') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + joint_bottom = skeleton[0, 0, zaxis[0]] + joint_top = skeleton[0, 0, zaxis[1]] + axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) + angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) + matrix_z = rotation_matrix(axis, angle) + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + for i_f, frame in enumerate(person): + if frame.sum() == 0: + continue + for i_j, joint in enumerate(frame): + s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint) + + + print('parallel the shoulder bone') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + joint_rshoulder = skeleton[0, 0, xaxis[0]] + joint_lshoulder = skeleton[0, 0, xaxis[1]] + axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) + angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0]) + matrix_x = rotation_matrix(axis, angle) + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + for i_f, frame in enumerate(person): + if frame.sum() == 0: + continue + for i_j, joint in enumerate(frame): + s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint) + + data = np.transpose(s, [0, 4, 2, 3, 1]) + return data + + +if __name__ == '__main__': + data = np.load('../data/NTU-RGB+D/xsub/train_data.npy') + pre_normalization(data) np.save('../data/nturgb_d/xsub/data_train_pre.npy', data) \ No newline at end of file diff --git a/data_gen/rotation.py b/data_gen/rotation.py index f82e6b8..9da8497 100644 --- a/data_gen/rotation.py +++ b/data_gen/rotation.py @@ -1,43 +1,43 @@ -import numpy as np -import math - - -def rotation_matrix(axis, theta): - if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: - return np.eye(3) - axis = np.asarray(axis) - axis = axis / math.sqrt(np.dot(axis, axis)) - a = math.cos(theta / 2.0) - b, c, d = -axis * math.sin(theta / 2.0) - aa, bb, cc, dd = a * a, b * b, c * c, d * d - bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d - return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], - [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], - [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) - - -def unit_vector(vector): - return vector / np.linalg.norm(vector) - - -def angle_between(v1, v2): - if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: - return 0 - v1_u = unit_vector(v1) - v2_u = unit_vector(v2) - return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) - - -def x_rotation(vector, theta): - R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) - return np.dot(R, vector) - - -def y_rotation(vector, theta): - R = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) - return np.dot(R, vector) - - -def z_rotation(vector, theta): - R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) +import numpy as np +import math + + +def rotation_matrix(axis, theta): + if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: + return np.eye(3) + axis = np.asarray(axis) + axis = axis / math.sqrt(np.dot(axis, axis)) + a = math.cos(theta / 2.0) + b, c, d = -axis * math.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + + +def unit_vector(vector): + return vector / np.linalg.norm(vector) + + +def angle_between(v1, v2): + if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: + return 0 + v1_u = unit_vector(v1) + v2_u = unit_vector(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + +def x_rotation(vector, theta): + R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) + return np.dot(R, vector) + + +def y_rotation(vector, theta): + R = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) + return np.dot(R, vector) + + +def z_rotation(vector, theta): + R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) return np.dot(R, vector) \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..7e6873e --- /dev/null +++ b/environment.yml @@ -0,0 +1,88 @@ +name: asgcn +channels: + - pytorch + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main + - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ + - https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/ + - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - blas=1.0=mkl + - ca-certificates=2021.4.13=h06a4308_1 + - certifi=2020.12.5=py36h06a4308_0 + - cffi=1.14.5=py36h261ae71_0 + - cuda90=1.0=h6433d27_0 + - cudatoolkit=10.0.130=0 + - cudnn=7.6.5=cuda10.0_0 + - cycler=0.10.0=py36_0 + - dbus=1.13.18=hb2f20db_0 + - expat=2.3.0=h2531618_2 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.68.1=h36276a3_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - icu=58.2=he6710b0_3 + - intel-openmp=2019.4=243 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.1=py36h2531618_0 + - lcms2=2.11=h396b838_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.2.0=h3942068_0 + - libuuid=1.0.3=h1bed415_2 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.10=hb55368b_3 + - lz4-c=1.9.3=h2531618_0 + - matplotlib=3.3.2=h06a4308_0 + - matplotlib-base=3.3.2=py36h817c723_0 + - mkl=2018.0.3=1 + - mkl_fft=1.0.6=py36h7dd41cf_0 + - mkl_random=1.0.1=py36h4414c95_1 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py36hff7bd54_0 + - olefile=0.46=py36_0 + - openssl=1.1.1k=h27cfd23_0 + - pcre=8.44=he6710b0_0 + - pillow=8.1.2=py36he98fc37_0 + - pip=21.0.1=py36h06a4308_0 + - pycparser=2.20=py_2 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py36h05f1152_2 + - python=3.6.13=hdb3f193_0 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - qt=5.9.7=h5867ecd_1 + - readline=8.1=h27cfd23_0 + - setuptools=52.0.0=py36h06a4308_0 + - sip=4.19.8=py36hf484d3e_0 + - six=1.15.0=py36h06a4308_0 + - sqlite=3.35.1=hdfb4753_0 + - tbb=2021.2.0=hff7bd54_0 + - tbb4py=2021.2.0=py36hff7bd54_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py36h27cfd23_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - argparse==1.4.0 + - cached-property==1.5.2 + - dataclasses==0.8 + - h5py==3.1.0 + - imageio==2.9.0 + - numpy==1.19.5 + - opencv-python==4.5.1.48 + - pyyaml==5.4.1 + - scikit-video==1.1.11 + - scipy==1.5.4 + - torch==1.7.1 + - torchvision==0.9.0 + - tqdm==4.60.0 + - typing-extensions==3.7.4.3 diff --git a/epoch_loss_class_eval.png b/epoch_loss_class_eval.png new file mode 100644 index 0000000..246c771 Binary files /dev/null and b/epoch_loss_class_eval.png differ diff --git a/epoch_loss_class_train.png b/epoch_loss_class_train.png new file mode 100644 index 0000000..51e019e Binary files /dev/null and b/epoch_loss_class_train.png differ diff --git a/epoch_loss_class_train.txt b/epoch_loss_class_train.txt new file mode 100644 index 0000000..515fdfc --- /dev/null +++ b/epoch_loss_class_train.txt @@ -0,0 +1,90 @@ +3.7712276314242685 +3.0636168965986927 +2.801343619823456 +2.6331870392384205 +2.5135307890929712 +2.4291697614202716 +2.3555361779060604 +2.2984420201291202 +2.2522135552288742 +2.197252894587989 +2.097349743214381 +1.9542330193516024 +1.7952129174223848 +1.6156651931379018 +1.4707137380617539 +1.3638808261168214 +1.2811493240666694 +1.2145853494477985 +1.1464360410734533 +1.0886454173869087 +1.0415393155084143 +1.0067537130545055 +0.9712598563652918 +0.9459512437851009 +0.9096426171339108 +0.883714666565097 +0.8630674364239435 +0.8406433062055244 +0.820840925855913 +0.8079694209025758 +0.7841788579339185 +0.776262023339371 +0.75239310020289 +0.7407916599995307 +0.7314519268123639 +0.7205715374548911 +0.7119348809259342 +0.6838741163310618 +0.675083714229081 +0.6665253263740542 +0.40875943099878626 +0.33014846760375033 +0.301763350591267 +0.27762719021237026 +0.26088881696264316 +0.2430464035457363 +0.2282850744710909 +0.2212755358291129 +0.203683311874431 +0.1944100352668947 +0.18915298141420142 +0.18149443655029596 +0.17873722235375974 +0.171926496180016 +0.164108280483452 +0.1580632201443505 +0.1529976990462558 +0.14594485707745689 +0.1456607071607638 +0.14255504573110114 +0.11923092203425435 +0.10815973842863215 +0.1090422415932918 +0.10648260271793217 +0.09845620263738979 +0.100937618059459 +0.10059720761977203 +0.09929326052758904 +0.09756028030852115 +0.09850751669961134 +0.09668084698679134 +0.0974692665948768 +0.09672294503566119 +0.09527862614642296 +0.09309141301707675 +0.09378876050791575 +0.094307657253778 +0.09193214094837932 +0.09031117155016503 +0.08870918162732286 +0.08976529900305628 +0.08807121031712889 +0.08787393360298594 +0.09003039282064657 +0.08867947479501018 +0.08774756830658956 +0.08815590344746861 +0.08792258952912853 +0.09093503984912055 +0.08763042784054735 diff --git a/feeder/__init__.py b/feeder/__init__.py index 8b13789..d3f5a12 100644 --- a/feeder/__init__.py +++ b/feeder/__init__.py @@ -1 +1 @@ - + diff --git a/feeder/feeder.py b/feeder/feeder.py index dba96f3..59a6333 100644 --- a/feeder/feeder.py +++ b/feeder/feeder.py @@ -1,90 +1,90 @@ -import os -import sys -import numpy as np -import random -import pickle -import time -import copy - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from torchvision import datasets, transforms - -from . import tools - -class Feeder(torch.utils.data.Dataset): - - def __init__(self, - data_path, label_path, - repeat_pad=False, - random_choose=False, - random_move=False, - window_size=-1, - debug=False, - down_sample = False, - mmap=True): - self.debug = debug - self.data_path = data_path - self.label_path = label_path - self.repeat_pad = repeat_pad - self.random_choose = random_choose - self.random_move = random_move - self.window_size = window_size - self.down_sample = down_sample - - self.load_data(mmap) - - def load_data(self, mmap): - - with open(self.label_path, 'rb') as f: - self.sample_name, self.label = pickle.load(f) - - if mmap: - self.data = np.load(self.data_path, mmap_mode='r') - else: - self.data = np.load(self.data_path) - - if self.debug: - self.label = self.label[0:100] - self.data = self.data[0:100] - self.sample_name = self.sample_name[0:100] - - self.N, self.C, self.T, self.V, self.M = self.data.shape - - def __len__(self): - return len(self.label) - - def __getitem__(self, index): - - data_numpy = np.array(self.data[index]).astype(np.float32) - label = self.label[index] - - valid_frame = (data_numpy!=0).sum(axis=3).sum(axis=2).sum(axis=0)>0 - begin, end = valid_frame.argmax(), len(valid_frame)-valid_frame[::-1].argmax() - length = end-begin - - if self.repeat_pad: - data_numpy = tools.repeat_pading(data_numpy) - if self.random_choose: - data_numpy = tools.random_choose(data_numpy, self.window_size) - elif self.window_size > 0: - data_numpy = tools.auto_pading(data_numpy, self.window_size) - if self.random_move: - data_numpy = tools.random_move(data_numpy) - - data_last = copy.copy(data_numpy[:,-11:-10,:,:]) - target_data = copy.copy(data_numpy[:,-10:,:,:]) - input_data = copy.copy(data_numpy[:,:-10,:,:]) - - if self.down_sample: - if length<=60: - input_data_dnsp = input_data[:,:50,:,:] - else: - rs = int(np.random.uniform(low=0, high=np.ceil((length-10)/50))) - input_data_dnsp = [input_data[:,int(i)+rs,:,:] for i in [np.floor(j*((length-10)/50)) for j in range(50)]] - input_data_dnsp = np.array(input_data_dnsp).astype(np.float32) - input_data_dnsp = np.transpose(input_data_dnsp, axes=(1,0,2,3)) - +import os +import sys +import numpy as np +import random +import pickle +import time +import copy + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torchvision import datasets, transforms + +from . import tools + +class Feeder(torch.utils.data.Dataset): + + def __init__(self, + data_path, label_path, + repeat_pad=False, + random_choose=False, + random_move=False, + window_size=-1, + debug=False, + down_sample = False, + mmap=True): + self.debug = debug + self.data_path = data_path + self.label_path = label_path + self.repeat_pad = repeat_pad + self.random_choose = random_choose + self.random_move = random_move + self.window_size = window_size + self.down_sample = down_sample + + self.load_data(mmap) + + def load_data(self, mmap): + + with open(self.label_path, 'rb') as f: + self.sample_name, self.label = pickle.load(f) + + if mmap: + self.data = np.load(self.data_path, mmap_mode='r') + else: + self.data = np.load(self.data_path) + + if self.debug: + self.label = self.label[0:100] + self.data = self.data[0:100] + self.sample_name = self.sample_name[0:100] + + self.N, self.C, self.T, self.V, self.M = self.data.shape # (40091, 3, 300, 25, 2) + + def __len__(self): + return len(self.label) + + def __getitem__(self, index): + + data_numpy = np.array(self.data[index]).astype(np.float32) + label = self.label[index] + + valid_frame = (data_numpy!=0).sum(axis=3).sum(axis=2).sum(axis=0)>0 + begin, end = valid_frame.argmax(), len(valid_frame)-valid_frame[::-1].argmax() + length = end-begin + + if self.repeat_pad: + data_numpy = tools.repeat_pading(data_numpy) + if self.random_choose: + data_numpy = tools.random_choose(data_numpy, self.window_size) + elif self.window_size > 0: + data_numpy = tools.auto_pading(data_numpy, self.window_size) + if self.random_move: + data_numpy = tools.random_move(data_numpy) + + data_last = copy.copy(data_numpy[:,-11:-10,:,:]) + target_data = copy.copy(data_numpy[:,-10:,:,:]) + input_data = copy.copy(data_numpy[:,:-10,:,:]) + + if self.down_sample: + if length<=60: + input_data_dnsp = input_data[:,:50,:,:] + else: + rs = int(np.random.uniform(low=0, high=np.ceil((length-10)/50))) + input_data_dnsp = [input_data[:,int(i)+rs,:,:] for i in [np.floor(j*((length-10)/50)) for j in range(50)]] + input_data_dnsp = np.array(input_data_dnsp).astype(np.float32) + input_data_dnsp = np.transpose(input_data_dnsp, axes=(1,0,2,3)) + return input_data, input_data_dnsp, target_data, data_last, label \ No newline at end of file diff --git a/feeder/tools.py b/feeder/tools.py index 0233fc7..942cfd2 100644 --- a/feeder/tools.py +++ b/feeder/tools.py @@ -1,244 +1,244 @@ -import numpy as np -import random - - -def downsample(data_numpy, step, random_sample=True): - # input: C,T,V,M - begin = np.random.randint(step) if random_sample else 0 - return data_numpy[:, begin::step, :, :] - - -def temporal_slice(data_numpy, step): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - return data_numpy.reshape(C, T / step, step, V, M).transpose( - (0, 1, 3, 2, 4)).reshape(C, T / step, V, step * M) - - -def mean_subtractor(data_numpy, mean): - # input: C,T,V,M - # naive version - if mean == 0: - return - C, T, V, M = data_numpy.shape - valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 - begin = valid_frame.argmax() - end = len(valid_frame) - valid_frame[::-1].argmax() - data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean - return data_numpy - - -def auto_pading(data_numpy, size, random_pad=False): - C, T, V, M = data_numpy.shape - if T < size: - begin = random.randint(0, size - T) if random_pad else 0 - data_numpy_paded = np.zeros((C, size, V, M)) - data_numpy_paded[:, begin:begin + T, :, :] = data_numpy - return data_numpy_paded - else: - return data_numpy - - -def random_choose(data_numpy, size, auto_pad=True): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - if T == size: - return data_numpy - elif T < size: - if auto_pad: - return auto_pading(data_numpy, size, random_pad=True) - else: - return data_numpy - else: - begin = random.randint(0, T - size) - return data_numpy[:, begin:begin + size, :, :] - - -def random_move(data_numpy, - angle_candidate=[-10., -5., 0., 5., 10.], - scale_candidate=[0.9, 1.0, 1.1], - transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2], - move_time_candidate=[1]): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - move_time = random.choice(move_time_candidate) - node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) - node = np.append(node, T) - num_node = len(node) - - A = np.random.choice(angle_candidate, num_node) - S = np.random.choice(scale_candidate, num_node) - T_x = np.random.choice(transform_candidate, num_node) - T_y = np.random.choice(transform_candidate, num_node) - - a = np.zeros(T) - s = np.zeros(T) - t_x = np.zeros(T) - t_y = np.zeros(T) - - # linspace - for i in range(num_node - 1): - a[node[i]:node[i + 1]] = np.linspace( - A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 - s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], - node[i + 1] - node[i]) - t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], - node[i + 1] - node[i]) - t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], - node[i + 1] - node[i]) - - theta = np.array([[np.cos(a) * s, -np.sin(a) * s], - [np.sin(a) * s, np.cos(a) * s]]) - - # perform transformation - for i_frame in range(T): - xy = data_numpy[0:2, i_frame, :, :] - new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) - new_xy[0] += t_x[i_frame] - new_xy[1] += t_y[i_frame] - # print(new_xy.shape, data_numpy.shape) - # data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) - new_xy = new_xy.reshape(2, V, M) - data_numpy[0:2, i_frame, :, :] = new_xy - - return data_numpy - - -def rand_rotate(data_numpy,rand_rotate): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - - R = np.eye(3) - for i in range(3): - theta = (np.random.rand()*2 -1)*rand_rotate * np.pi - Ri = np.eye(3) - Ri[C - 1, C - 1] = 1 - Ri[0, 0] = np.cos(theta) - Ri[0, 1] = np.sin(theta) - Ri[1, 0] = -np.sin(theta) - Ri[1, 1] = np.cos(theta) - R = R * Ri - - data_numpy = np.matmul(R,data_numpy.reshape(C,T*V*M)).reshape(C,T,V,M).astype('float32') - return data_numpy - - -def random_shift(data_numpy): - # input: C,T,V,M - C, T, V, M = data_numpy.shape - data_shift = np.zeros(data_numpy.shape) - valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 - begin = valid_frame.argmax() - end = len(valid_frame) - valid_frame[::-1].argmax() - - size = end - begin - bias = random.randint(0, T - size) - data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :] - - return data_shift - - -def openpose_match(data_numpy): - C, T, V, M = data_numpy.shape - assert (C == 3) - score = data_numpy[2, :, :, :].sum(axis=1) - # the rank of body confidence in each frame (shape: T-1, M) - rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M) - - # data of frame 1 - xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1) - # data of frame 2 - xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M) - # square of distance between frame 1&2 (shape: T-1, M, M) - distance = ((xy2 - xy1)**2).sum(axis=2).sum(axis=0) - - # match pose - forward_map = np.zeros((T, M), dtype=int) - 1 - forward_map[0] = range(M) - for m in range(M): - choose = (rank == m) - forward = distance[choose].argmin(axis=1) - for t in range(T - 1): - distance[t, :, forward[t]] = np.inf - forward_map[1:][choose] = forward - assert (np.all(forward_map >= 0)) - - # string data - for t in range(T - 1): - forward_map[t + 1] = forward_map[t + 1][forward_map[t]] - - # generate data - new_data_numpy = np.zeros(data_numpy.shape) - for t in range(T): - new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[ - t]].transpose(1, 2, 0) - data_numpy = new_data_numpy - - # score sort - trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0) - rank = (-trace_score).argsort() - data_numpy = data_numpy[:, :, :, rank] - - return data_numpy - - -def top_k_by_category(label, score, top_k): - instance_num, class_num = score.shape - rank = score.argsort() - hit_top_k = [[] for i in range(class_num)] - for i in range(instance_num): - l = label[i] - hit_top_k[l].append(l in rank[i, -top_k:]) - - accuracy_list = [] - for hit_per_category in hit_top_k: - if hit_per_category: - accuracy_list.append(sum(hit_per_category) * 1.0 / len(hit_per_category)) - else: - accuracy_list.append(0.0) - return accuracy_list - - -def calculate_recall_precision(label, score): - instance_num, class_num = score.shape - rank = score.argsort() - confusion_matrix = np.zeros([class_num, class_num]) - - for i in range(instance_num): - true_l = label[i] - pred_l = rank[i, -1] - confusion_matrix[true_l][pred_l] += 1 - - precision = [] - recall = [] - - for i in range(class_num): - true_p = confusion_matrix[i][i] - false_n = sum(confusion_matrix[i, :]) - true_p - false_p = sum(confusion_matrix[:, i]) - true_p - precision.append(true_p * 1.0 / (true_p + false_p)) - recall.append(true_p * 1.0 / (true_p + false_n)) - - return precision, recall - - -def repeat_pading(data_numpy): - data_tmp = np.transpose(data_numpy, [3,1,2,0]) # [2,300,25,3] - for i_p, person in enumerate(data_tmp): - if person.sum()==0: - continue - if person[0].sum()==0: - index = (person.sum(-1).sum(-1)!=0) - tmp = person[index].copy() - person*=0 - person[:len(tmp)] = tmp - for i_f, frame in enumerate(person): - if frame.sum()==0: - if person[i_f:].sum()==0: - rest = len(person)-i_f - num = int(np.ceil(rest/i_f)) - pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] - data_tmp[i_p,i_f:] = pad - break - data_numpy = np.transpose(data_tmp, [3,1,2,0]) +import numpy as np +import random + + +def downsample(data_numpy, step, random_sample=True): + # input: C,T,V,M + begin = np.random.randint(step) if random_sample else 0 + return data_numpy[:, begin::step, :, :] + + +def temporal_slice(data_numpy, step): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + return data_numpy.reshape(C, T / step, step, V, M).transpose( + (0, 1, 3, 2, 4)).reshape(C, T / step, V, step * M) + + +def mean_subtractor(data_numpy, mean): + # input: C,T,V,M + # naive version + if mean == 0: + return + C, T, V, M = data_numpy.shape + valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 + begin = valid_frame.argmax() + end = len(valid_frame) - valid_frame[::-1].argmax() + data_numpy[:, :end, :, :] = data_numpy[:, :end, :, :] - mean + return data_numpy + + +def auto_pading(data_numpy, size, random_pad=False): + C, T, V, M = data_numpy.shape + if T < size: + begin = random.randint(0, size - T) if random_pad else 0 + data_numpy_paded = np.zeros((C, size, V, M)) + data_numpy_paded[:, begin:begin + T, :, :] = data_numpy + return data_numpy_paded + else: + return data_numpy + + +def random_choose(data_numpy, size, auto_pad=True): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + if T == size: + return data_numpy + elif T < size: + if auto_pad: + return auto_pading(data_numpy, size, random_pad=True) + else: + return data_numpy + else: + begin = random.randint(0, T - size) + return data_numpy[:, begin:begin + size, :, :] + + +def random_move(data_numpy, + angle_candidate=[-10., -5., 0., 5., 10.], + scale_candidate=[0.9, 1.0, 1.1], + transform_candidate=[-0.2, -0.1, 0.0, 0.1, 0.2], + move_time_candidate=[1]): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + move_time = random.choice(move_time_candidate) + node = np.arange(0, T, T * 1.0 / move_time).round().astype(int) + node = np.append(node, T) + num_node = len(node) + + A = np.random.choice(angle_candidate, num_node) + S = np.random.choice(scale_candidate, num_node) + T_x = np.random.choice(transform_candidate, num_node) + T_y = np.random.choice(transform_candidate, num_node) + + a = np.zeros(T) + s = np.zeros(T) + t_x = np.zeros(T) + t_y = np.zeros(T) + + # linspace + for i in range(num_node - 1): + a[node[i]:node[i + 1]] = np.linspace( + A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180 + s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1], + node[i + 1] - node[i]) + t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1], + node[i + 1] - node[i]) + t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1], + node[i + 1] - node[i]) + + theta = np.array([[np.cos(a) * s, -np.sin(a) * s], + [np.sin(a) * s, np.cos(a) * s]]) + + # perform transformation + for i_frame in range(T): + xy = data_numpy[0:2, i_frame, :, :] + new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1)) + new_xy[0] += t_x[i_frame] + new_xy[1] += t_y[i_frame] + # print(new_xy.shape, data_numpy.shape) + # data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M) + new_xy = new_xy.reshape(2, V, M) + data_numpy[0:2, i_frame, :, :] = new_xy + + return data_numpy + + +def rand_rotate(data_numpy,rand_rotate): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + + R = np.eye(3) + for i in range(3): + theta = (np.random.rand()*2 -1)*rand_rotate * np.pi + Ri = np.eye(3) + Ri[C - 1, C - 1] = 1 + Ri[0, 0] = np.cos(theta) + Ri[0, 1] = np.sin(theta) + Ri[1, 0] = -np.sin(theta) + Ri[1, 1] = np.cos(theta) + R = R * Ri + + data_numpy = np.matmul(R,data_numpy.reshape(C,T*V*M)).reshape(C,T,V,M).astype('float32') + return data_numpy + + +def random_shift(data_numpy): + # input: C,T,V,M + C, T, V, M = data_numpy.shape + data_shift = np.zeros(data_numpy.shape) + valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0 + begin = valid_frame.argmax() + end = len(valid_frame) - valid_frame[::-1].argmax() + + size = end - begin + bias = random.randint(0, T - size) + data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :] + + return data_shift + + +def openpose_match(data_numpy): + C, T, V, M = data_numpy.shape + assert (C == 3) + score = data_numpy[2, :, :, :].sum(axis=1) + # the rank of body confidence in each frame (shape: T-1, M) + rank = (-score[0:T - 1]).argsort(axis=1).reshape(T - 1, M) + + # data of frame 1 + xy1 = data_numpy[0:2, 0:T - 1, :, :].reshape(2, T - 1, V, M, 1) + # data of frame 2 + xy2 = data_numpy[0:2, 1:T, :, :].reshape(2, T - 1, V, 1, M) + # square of distance between frame 1&2 (shape: T-1, M, M) + distance = ((xy2 - xy1)**2).sum(axis=2).sum(axis=0) + + # match pose + forward_map = np.zeros((T, M), dtype=int) - 1 + forward_map[0] = range(M) + for m in range(M): + choose = (rank == m) + forward = distance[choose].argmin(axis=1) + for t in range(T - 1): + distance[t, :, forward[t]] = np.inf + forward_map[1:][choose] = forward + assert (np.all(forward_map >= 0)) + + # string data + for t in range(T - 1): + forward_map[t + 1] = forward_map[t + 1][forward_map[t]] + + # generate data + new_data_numpy = np.zeros(data_numpy.shape) + for t in range(T): + new_data_numpy[:, t, :, :] = data_numpy[:, t, :, forward_map[ + t]].transpose(1, 2, 0) + data_numpy = new_data_numpy + + # score sort + trace_score = data_numpy[2, :, :, :].sum(axis=1).sum(axis=0) + rank = (-trace_score).argsort() + data_numpy = data_numpy[:, :, :, rank] + + return data_numpy + + +def top_k_by_category(label, score, top_k): + instance_num, class_num = score.shape + rank = score.argsort() + hit_top_k = [[] for i in range(class_num)] + for i in range(instance_num): + l = label[i] + hit_top_k[l].append(l in rank[i, -top_k:]) + + accuracy_list = [] + for hit_per_category in hit_top_k: + if hit_per_category: + accuracy_list.append(sum(hit_per_category) * 1.0 / len(hit_per_category)) + else: + accuracy_list.append(0.0) + return accuracy_list + + +def calculate_recall_precision(label, score): + instance_num, class_num = score.shape + rank = score.argsort() + confusion_matrix = np.zeros([class_num, class_num]) + + for i in range(instance_num): + true_l = label[i] + pred_l = rank[i, -1] + confusion_matrix[true_l][pred_l] += 1 + + precision = [] + recall = [] + + for i in range(class_num): + true_p = confusion_matrix[i][i] + false_n = sum(confusion_matrix[i, :]) - true_p + false_p = sum(confusion_matrix[:, i]) - true_p + precision.append(true_p * 1.0 / (true_p + false_p)) + recall.append(true_p * 1.0 / (true_p + false_n)) + + return precision, recall + + +def repeat_pading(data_numpy): + data_tmp = np.transpose(data_numpy, [3,1,2,0]) # [2,300,25,3] + for i_p, person in enumerate(data_tmp): + if person.sum()==0: + continue + if person[0].sum()==0: + index = (person.sum(-1).sum(-1)!=0) + tmp = person[index].copy() + person*=0 + person[:len(tmp)] = tmp + for i_f, frame in enumerate(person): + if frame.sum()==0: + if person[i_f:].sum()==0: + rest = len(person)-i_f + num = int(np.ceil(rest/i_f)) + pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] + data_tmp[i_p,i_f:] = pad + break + data_numpy = np.transpose(data_tmp, [3,1,2,0]) return data_numpy \ No newline at end of file diff --git a/img/readme.md b/img/readme.md index 568dd30..d09b9d7 100644 --- a/img/readme.md +++ b/img/readme.md @@ -1 +1 @@ -Image used in the repository. +Image used in the repository. diff --git a/log/best_performance/epoch_loss_class_eval.png b/log/best_performance/epoch_loss_class_eval.png new file mode 100644 index 0000000..246c771 Binary files /dev/null and b/log/best_performance/epoch_loss_class_eval.png differ diff --git a/log/best_performance/epoch_loss_class_train.png b/log/best_performance/epoch_loss_class_train.png new file mode 100644 index 0000000..51e019e Binary files /dev/null and b/log/best_performance/epoch_loss_class_train.png differ diff --git a/log/best_performance/epoch_loss_class_train.txt b/log/best_performance/epoch_loss_class_train.txt new file mode 100644 index 0000000..515fdfc --- /dev/null +++ b/log/best_performance/epoch_loss_class_train.txt @@ -0,0 +1,90 @@ +3.7712276314242685 +3.0636168965986927 +2.801343619823456 +2.6331870392384205 +2.5135307890929712 +2.4291697614202716 +2.3555361779060604 +2.2984420201291202 +2.2522135552288742 +2.197252894587989 +2.097349743214381 +1.9542330193516024 +1.7952129174223848 +1.6156651931379018 +1.4707137380617539 +1.3638808261168214 +1.2811493240666694 +1.2145853494477985 +1.1464360410734533 +1.0886454173869087 +1.0415393155084143 +1.0067537130545055 +0.9712598563652918 +0.9459512437851009 +0.9096426171339108 +0.883714666565097 +0.8630674364239435 +0.8406433062055244 +0.820840925855913 +0.8079694209025758 +0.7841788579339185 +0.776262023339371 +0.75239310020289 +0.7407916599995307 +0.7314519268123639 +0.7205715374548911 +0.7119348809259342 +0.6838741163310618 +0.675083714229081 +0.6665253263740542 +0.40875943099878626 +0.33014846760375033 +0.301763350591267 +0.27762719021237026 +0.26088881696264316 +0.2430464035457363 +0.2282850744710909 +0.2212755358291129 +0.203683311874431 +0.1944100352668947 +0.18915298141420142 +0.18149443655029596 +0.17873722235375974 +0.171926496180016 +0.164108280483452 +0.1580632201443505 +0.1529976990462558 +0.14594485707745689 +0.1456607071607638 +0.14255504573110114 +0.11923092203425435 +0.10815973842863215 +0.1090422415932918 +0.10648260271793217 +0.09845620263738979 +0.100937618059459 +0.10059720761977203 +0.09929326052758904 +0.09756028030852115 +0.09850751669961134 +0.09668084698679134 +0.0974692665948768 +0.09672294503566119 +0.09527862614642296 +0.09309141301707675 +0.09378876050791575 +0.094307657253778 +0.09193214094837932 +0.09031117155016503 +0.08870918162732286 +0.08976529900305628 +0.08807121031712889 +0.08787393360298594 +0.09003039282064657 +0.08867947479501018 +0.08774756830658956 +0.08815590344746861 +0.08792258952912853 +0.09093503984912055 +0.08763042784054735 diff --git a/log/data_tree.log b/log/data_tree.log new file mode 100644 index 0000000..f98e62c --- /dev/null +++ b/log/data_tree.log @@ -0,0 +1,16 @@ +data +|-- NTU-RGB+D +| `-- samples_with_missing_skeletons.txt +`-- nturgb_d + |-- xsub + | |-- train_data_joint_pad.npy + | |-- train_label.pkl + | |-- val_data_joint_pad.npy + | `-- val_label.pkl + `-- xview + |-- train_data_joint_pad.npy + |-- train_label.pkl + |-- val_data_joint_pad.npy + `-- val_label.pkl + +4 directories, 9 files diff --git a/log/train_aim.log b/log/train_aim.log new file mode 100644 index 0000000..30a1033 --- /dev/null +++ b/log/train_aim.log @@ -0,0 +1,21 @@ +$python main.py recognition -c config/as_gcn/ntu-xsub/train_aim.yaml --device 0 1 2 + +/root/AS-GCN/processor/io.py:34: YAMLLoadWarning: calling yaml.load() without Loader=... is deprecated, as the default Loader is unsafe. Please read https://msg.pyyaml.org/load for full details. + default_arg = yaml.load(f) +/root/AS-GCN/net/utils/adj_learn.py:18: UserWarning: This overload of nonzero is deprecated: + nonzero() +Consider using one of the following signatures instead: + nonzero(*, bool as_tuple) (Triggered internally at /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.) + offdiag_indices = (ones - eye).nonzero().t() +/root/anaconda3/envs/stgcn/lib/python3.6/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported. + warnings.warn("Setting attributes on ParameterList is not supported.") +/root/AS-GCN/net/utils/adj_learn.py:11: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. + soft_max_1d = F.softmax(trans_input) + +[05.08.21|19:25:22] Parameters: +{'work_dir': './work_dir/recognition/ntu-xsub/AS_GCN', 'config': 'config/as_gcn/ntu-xsub/train_aim.yaml', 'phase': 'train', 'save_result': False, 'start_epoch': 0, 'num_epoch': 10, 'use_gpu': True, 'device': [0, 1, 2], 'log_interval': 100, 'save_interval': 1, 'eval_interval': 5, 'save_log': True, 'print_log': True, 'pavi_log': False, 'feeder': 'feeder.feeder.Feeder', 'num_worker': 4, 'train_feeder_args': {'data_path': './data/nturgb_d/xsub/train_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/train_label.pkl', 'random_move': True, 'repeat_pad': True, 'down_sample': True, 'debug': False}, 'test_feeder_args': {'data_path': './data/nturgb_d/xsub/val_data_joint_pad.npy', 'label_path': './data/nturgb_d/xsub/val_label.pkl', 'random_move': False, 'repeat_pad': True, 'down_sample': True}, 'batch_size': 32, 'test_batch_size': 32, 'debug': False, 'model1': 'net.as_gcn.Model', 'model2': 'net.utils.adj_learn.AdjacencyLearn', 'model1_args': {'in_channels': 3, 'num_class': 60, 'dropout': 0.5, 'edge_importance_weighting': True, 'graph_args': {'layout': 'ntu-rgb+d', 'strategy': 'spatial', 'max_hop': 4}}, 'model2_args': {'n_in_enc': 150, 'n_hid_enc': 128, 'edge_types': 3, 'n_in_dec': 3, 'n_hid_dec': 128, 'node_num': 25}, 'weights1': None, 'weights2': None, 'ignore_weights': [], 'show_topk': [1, 5], 'base_lr1': 0.1, 'base_lr2': 0.0005, 'step': [50, 70, 90], 'optimizer': 'SGD', 'nesterov': True, 'weight_decay': 0.0001, 'max_hop_dir': 'max_hop_4', 'lamda_act': 0.5, 'lamda_act_dir': 'lamda_05'} + +[05.08.21|19:25:22] Training epoch: 0 +[05.08.21|19:25:29] Iter 0 Done. | loss2: 876.8732 | loss_nll: 832.9621 | loss_kl: 43.9111 | lr: 0.000500 +[05.08.21|19:26:56] Iter 100 Done. | loss2: 118.9051 | loss_nll: 110.3876 | loss_kl: 8.5176 | lr: 0.000500 +[05.08.21|19:28:14] Iter 200 Done. | loss2: 76.1775 | loss_nll: 71.7404 | loss_kl: 4.4371 | lr: 0.000500 diff --git a/main.py b/main.py index ee1a0a2..c81bb7b 100644 --- a/main.py +++ b/main.py @@ -1,24 +1,24 @@ -import argparse -import sys -import torchlight -from torchlight import import_class - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description='Processor collection') - - processors = dict() - processors['recognition'] = import_class('processor.recognition.REC_Processor') - processors['demo'] = import_class('processor.demo.Demo') - - subparsers = parser.add_subparsers(dest='processor') - for k, p in processors.items(): - subparsers.add_parser(k, parents=[p.get_parser()]) - - arg = parser.parse_args() - - # start - Processor = processors[arg.processor] - p = Processor(sys.argv[2:]) - p.start() +import argparse +import sys +import torchlight +from torchlight.io import import_class + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser(description='Processor collection') + + processors = dict() + processors['recognition'] = import_class('processor.recognition.REC_Processor') + #processors['demo'] = import_class('processor.demo.Demo') + + subparsers = parser.add_subparsers(dest='processor') + for k, p in processors.items(): + subparsers.add_parser(k, parents=[p.get_parser()]) + + arg = parser.parse_args() + + # start + Processor = processors[arg.processor] + p = Processor(sys.argv[2:]) + p.start() diff --git a/net/__init__.py b/net/__init__.py index 8b13789..d3f5a12 100644 --- a/net/__init__.py +++ b/net/__init__.py @@ -1 +1 @@ - + diff --git a/net/as_gcn.py b/net/as_gcn.py index 7468be4..f49b08d 100644 --- a/net/as_gcn.py +++ b/net/as_gcn.py @@ -1,308 +1,307 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable - -from net.utils.graph import Graph - - -class Model(nn.Module): - - def __init__(self, in_channels, num_class, graph_args, - edge_importance_weighting, **kwargs): - super().__init__() - - self.graph = Graph(**graph_args) - A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) - self.register_buffer('A', A) - self.edge_type = 2 - - temporal_kernel_size = 9 - spatial_kernel_size = A.size(0) + self.edge_type - st_kernel_size = (temporal_kernel_size, spatial_kernel_size) - - self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) - - self.class_layer_0 = StgcnBlock(in_channels, 64, st_kernel_size, self.edge_type, stride=1, residual=False, **kwargs) - self.class_layer_1 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_2 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_3 = StgcnBlock(64, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.class_layer_4 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_5 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_6 = StgcnBlock(128, 256, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.class_layer_7 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.class_layer_8 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) - - self.recon_layer_0 = StgcnBlock(256, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) - self.recon_layer_1 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.recon_layer_2 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.recon_layer_3 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) - self.recon_layer_4 = StgcnBlock(128, 128, (3, spatial_kernel_size), self.edge_type, stride=2, **kwargs) - self.recon_layer_5 = StgcnBlock(128, 128, (5, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, **kwargs) - self.recon_layer_6 = StgcnReconBlock(128+3, 30, (1, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, activation=None, **kwargs) - - - if edge_importance_weighting: - self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)]) - self.edge_importance_recon = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)]) - else: - self.edge_importance = [1] * (len(self.st_gcn_networks)+len(self.st_gcn_recon)) - self.fcn = nn.Conv2d(256, num_class, kernel_size=1) - - def forward(self, x, x_target, x_last, A_act, lamda_act): - - N, C, T, V, M = x.size() - x_recon = x[:,:,:,:,0] # [2N, 3, 300, 25] - x = x.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 300] - x = x.view(N * M, V * C, T) # [2N, 75, 300] - - x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) - - x_bn = self.data_bn(x) - x_bn = x_bn.view(N, M, V, C, T) - x_bn = x_bn.permute(0, 1, 3, 4, 2).contiguous() - x_bn = x_bn.view(N * M, C, T, V) - - h0, _ = self.class_layer_0(x_bn, self.A * self.edge_importance[0], A_act, lamda_act) # [N, 64, 300, 25] - h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] - h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] - h2, _ = self.class_layer_2(h1, self.A * self.edge_importance[2], A_act, lamda_act) # [N, 64, 300, 25] - h3, _ = self.class_layer_3(h2, self.A * self.edge_importance[3], A_act, lamda_act) # [N, 128, 150, 25] - h4, _ = self.class_layer_4(h3, self.A * self.edge_importance[4], A_act, lamda_act) # [N, 128, 150, 25] - h5, _ = self.class_layer_5(h4, self.A * self.edge_importance[5], A_act, lamda_act) # [N, 128, 150, 25] - h6, _ = self.class_layer_6(h5, self.A * self.edge_importance[6], A_act, lamda_act) # [N, 256, 75, 25] - h7, _ = self.class_layer_7(h6, self.A * self.edge_importance[7], A_act, lamda_act) # [N, 256, 75, 25] - h8, _ = self.class_layer_8(h7, self.A * self.edge_importance[8], A_act, lamda_act) # [N, 256, 75, 25] - - x_class = F.avg_pool2d(h8, h8.size()[2:]) - x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) - x_class = self.fcn(x_class) - x_class = x_class.view(x_class.size(0), -1) - - r0, _ = self.recon_layer_0(h8, self.A*self.edge_importance_recon[0], A_act, lamda_act) # [N, 128, 75, 25] - r1, _ = self.recon_layer_1(r0, self.A*self.edge_importance_recon[1], A_act, lamda_act) # [N, 128, 38, 25] - r2, _ = self.recon_layer_2(r1, self.A*self.edge_importance_recon[2], A_act, lamda_act) # [N, 128, 19, 25] - r3, _ = self.recon_layer_3(r2, self.A*self.edge_importance_recon[3], A_act, lamda_act) # [N, 128, 10, 25] - r4, _ = self.recon_layer_4(r3, self.A*self.edge_importance_recon[4], A_act, lamda_act) # [N, 128, 5, 25] - r5, _ = self.recon_layer_5(r4, self.A*self.edge_importance_recon[5], A_act, lamda_act) # [N, 128, 1, 25] - r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act) # [N, 64, 1, 25] - pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze() # [N, 3, 25] - - pred = pred.contiguous().view(-1, 3, 10, 25) - x_target = x_target.permute(0,4,1,2,3).contiguous().view(-1,3,10,25) - - return x_class, pred[::2], x_target[::2] - - def extract_feature(self, x): - - N, C, T, V, M = x.size() - x = x.permute(0, 4, 3, 1, 2).contiguous() - x = x.view(N * M, V * C, T) - x = self.data_bn(x) - x = x.view(N, M, V, C, T) - x = x.permute(0, 1, 3, 4, 2).contiguous() - x = x.view(N * M, C, T, V) - - for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): - x, _ = gcn(x, self.A * importance) - - _, c, t, v = x.size() - feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) - - x = self.fcn(x) - output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1) - - return output, feature - - -class StgcnBlock(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - edge_type=2, - t_kernel_size=1, - stride=1, - padding=True, - dropout=0, - residual=True): - super().__init__() - - assert len(kernel_size) == 2 - assert kernel_size[0] % 2 == 1 - if padding == True: - padding = ((kernel_size[0] - 1) // 2, 0) - else: - padding = (0,0) - - self.gcn = SpatialGcn(in_channels=in_channels, - out_channels=out_channels, - k_num=kernel_size[1], - edge_type=edge_type, - t_kernel_size=t_kernel_size) - self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, - out_channels, - (kernel_size[0], 1), - (stride, 1), - padding), - nn.BatchNorm2d(out_channels), - nn.Dropout(dropout, inplace=True)) - - if not residual: - self.residual = lambda x: 0 - elif (in_channels == out_channels) and (stride == 1): - self.residual = lambda x: x - else: - self.residual = nn.Sequential(nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=(stride, 1)), - nn.BatchNorm2d(out_channels)) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x, A, B, lamda_act): - - res = self.residual(x) - x, A = self.gcn(x, A, B, lamda_act) - x = self.tcn(x) + res - - return self.relu(x), A - - -class StgcnReconBlock(nn.Module): - - def __init__(self, - in_channels, - out_channels, - kernel_size, - edge_type=2, - t_kernel_size=1, - stride=1, - padding=True, - dropout=0, - residual=True, - activation='relu'): - super().__init__() - - assert len(kernel_size) == 2 - assert kernel_size[0] % 2 == 1 - - if padding == True: - padding = ((kernel_size[0] - 1) // 2, 0) - else: - padding = (0,0) - - self.gcn_recon = SpatialGcnRecon(in_channels=in_channels, - out_channels=out_channels, - k_num=kernel_size[1], - edge_type=edge_type, - t_kernel_size=t_kernel_size) - self.tcn_recon = nn.Sequential(nn.BatchNorm2d(out_channels), - nn.ReLU(inplace=True), - nn.ConvTranspose2d(in_channels=out_channels, - out_channels=out_channels, - kernel_size=(kernel_size[0], 1), - stride=(stride, 1), - padding=padding, - output_padding=(stride-1,0)), - nn.BatchNorm2d(out_channels), - nn.Dropout(dropout, inplace=True)) - - if not residual: - self.residual = lambda x: 0 - elif (in_channels == out_channels) and (stride == 1): - self.residual = lambda x: x - else: - self.residual = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=(stride, 1), - output_padding=(stride-1,0)), - nn.BatchNorm2d(out_channels)) - self.relu = nn.ReLU(inplace=True) - self.activation = activation - - def forward(self, x, A, B, lamda_act): - - res = self.residual(x) - x, A = self.gcn_recon(x, A, B, lamda_act) - x = self.tcn_recon(x) + res - if self.activation == 'relu': - x = self.relu(x) - else: - x = x - - return x, A - - -class SpatialGcn(nn.Module): - - def __init__(self, - in_channels, - out_channels, - k_num, - edge_type=2, - t_kernel_size=1, - t_stride=1, - t_padding=0, - t_dilation=1, - bias=True): - super().__init__() - - self.k_num = k_num - self.edge_type = edge_type - self.conv = nn.Conv2d(in_channels=in_channels, - out_channels=out_channels*k_num, - kernel_size=(t_kernel_size, 1), - padding=(t_padding, 0), - stride=(t_stride, 1), - dilation=(t_dilation, 1), - bias=bias) - - def forward(self, x, A, B, lamda_act): - - x = self.conv(x) - n, kc, t, v = x.size() - x = x.view(n, self.k_num, kc//self.k_num, t, v) - x1 = x[:,:self.k_num-self.edge_type,:,:,:] - x2 = x[:,-self.edge_type:,:,:,:] - x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) - x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) - x_sum = x1+x2*lamda_act - - return x_sum.contiguous(), A - - -class SpatialGcnRecon(nn.Module): - - def __init__(self, in_channels, out_channels, k_num, edge_type=3, - t_kernel_size=1, t_stride=1, t_padding=0, t_outpadding=0, t_dilation=1, - bias=True): - super().__init__() - - self.k_num = k_num - self.edge_type = edge_type - self.deconv = nn.ConvTranspose2d(in_channels=in_channels, - out_channels=out_channels*k_num, - kernel_size=(t_kernel_size, 1), - padding=(t_padding, 0), - output_padding=(t_outpadding, 0), - stride=(t_stride, 1), - dilation=(t_dilation, 1), - bias=bias) - - def forward(self, x, A, B, lamda_act): - - x = self.deconv(x) - n, kc, t, v = x.size() - x = x.view(n, self.k_num, kc//self.k_num, t, v) - x1 = x[:,:self.k_num-self.edge_type,:,:,:] - x2 = x[:,-self.edge_type:,:,:,:] - x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) - x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) - x_sum = x1+x2*lamda_act - - return x_sum.contiguous(), A +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +from net.utils.graph import Graph + + +class Model(nn.Module): + + def __init__(self, in_channels, num_class, graph_args, + edge_importance_weighting, **kwargs): + super().__init__() + + self.graph = Graph(**graph_args) + A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False) + self.register_buffer('A', A) + self.edge_type = 2 + + temporal_kernel_size = 9 + spatial_kernel_size = A.size(0) + self.edge_type + st_kernel_size = (temporal_kernel_size, spatial_kernel_size) + + self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + + self.class_layer_0 = StgcnBlock(in_channels, 64, st_kernel_size, self.edge_type, stride=1, residual=False, **kwargs) + self.class_layer_1 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_2 = StgcnBlock(64, 64, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_3 = StgcnBlock(64, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.class_layer_4 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_5 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_6 = StgcnBlock(128, 256, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.class_layer_7 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.class_layer_8 = StgcnBlock(256, 256, st_kernel_size, self.edge_type, stride=1, **kwargs) + + self.recon_layer_0 = StgcnBlock(256, 128, st_kernel_size, self.edge_type, stride=1, **kwargs) + self.recon_layer_1 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.recon_layer_2 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.recon_layer_3 = StgcnBlock(128, 128, st_kernel_size, self.edge_type, stride=2, **kwargs) + self.recon_layer_4 = StgcnBlock(128, 128, (3, spatial_kernel_size), self.edge_type, stride=2, **kwargs) + self.recon_layer_5 = StgcnBlock(128, 128, (5, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, **kwargs) + self.recon_layer_6 = StgcnReconBlock(128+3, 30, (1, spatial_kernel_size), self.edge_type, stride=1, padding=False, residual=False, activation=None, **kwargs) + + + if edge_importance_weighting: + self.edge_importance = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)]) + self.edge_importance_recon = nn.ParameterList([nn.Parameter(torch.ones(self.A.size())) for i in range(9)]) + else: + self.edge_importance = [1] * (len(self.st_gcn_networks)+len(self.st_gcn_recon)) + self.fcn = nn.Conv2d(256, num_class, kernel_size=1) + + def forward(self, x, x_target, x_last, A_act, lamda_act): + N, C, T, V, M = x.size() + x_recon = x[:,:,:,:,0] # [2N, 3, 300, 25] wsx: x_recon(4,3,290,25) select the first person data? + x = x.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 300] wsx: x(4,2,25,3,290) + x = x.view(N * M, V * C, T) # [2N, 75, 300]m wsx: x(8,75,290) + + x_last = x_last.permute(0,4,1,2,3).contiguous().view(-1,3,1,25) #(2N,3,1,25) + + x_bn = self.data_bn(x) + x_bn = x_bn.view(N, M, V, C, T) + x_bn = x_bn.permute(0, 1, 3, 4, 2).contiguous() + x_bn = x_bn.view(N * M, C, T, V) #2N,3,290,25 + + h0, _ = self.class_layer_0(x_bn, self.A * self.edge_importance[0], A_act, lamda_act) # [N, 64, 300, 25] + h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] + h1, _ = self.class_layer_1(h0, self.A * self.edge_importance[1], A_act, lamda_act) # [N, 64, 300, 25] + h2, _ = self.class_layer_2(h1, self.A * self.edge_importance[2], A_act, lamda_act) # [N, 64, 300, 25] + h3, _ = self.class_layer_3(h2, self.A * self.edge_importance[3], A_act, lamda_act) # [N, 128, 150, 25] + h4, _ = self.class_layer_4(h3, self.A * self.edge_importance[4], A_act, lamda_act) # [N, 128, 150, 25] + h5, _ = self.class_layer_5(h4, self.A * self.edge_importance[5], A_act, lamda_act) # [N, 128, 150, 25] + h6, _ = self.class_layer_6(h5, self.A * self.edge_importance[6], A_act, lamda_act) # [N, 256, 75, 25] + h7, _ = self.class_layer_7(h6, self.A * self.edge_importance[7], A_act, lamda_act) # [N, 256, 75, 25] + h8, _ = self.class_layer_8(h7, self.A * self.edge_importance[8], A_act, lamda_act) # [N, 256, 75, 25] + + x_class = F.avg_pool2d(h8, h8.size()[2:]) #(8,256,1,1) + x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) #(4,256,1,1) + x_class = self.fcn(x_class) #(4,60,1,1) Conv2d(256, 60, kernel_size=(1, 1), stride=(1, 1)) + x_class = x_class.view(x_class.size(0), -1) #(4,60) + + r0, _ = self.recon_layer_0(h8, self.A*self.edge_importance_recon[0], A_act, lamda_act) # [N, 128, 75, 25] + r1, _ = self.recon_layer_1(r0, self.A*self.edge_importance_recon[1], A_act, lamda_act) # [N, 128, 38, 25] + r2, _ = self.recon_layer_2(r1, self.A*self.edge_importance_recon[2], A_act, lamda_act) # [N, 128, 19, 25] + r3, _ = self.recon_layer_3(r2, self.A*self.edge_importance_recon[3], A_act, lamda_act) # [N, 128, 10, 25] + r4, _ = self.recon_layer_4(r3, self.A*self.edge_importance_recon[4], A_act, lamda_act) # [N, 128, 5, 25] + r5, _ = self.recon_layer_5(r4, self.A*self.edge_importance_recon[5], A_act, lamda_act) # [N, 128, 1, 25] + r6, _ = self.recon_layer_6(torch.cat((r5, x_last),1), self.A*self.edge_importance_recon[6], A_act, lamda_act) # [N, 64, 1, 25] wsx:(8,30,1,25) + pred = x_last.squeeze().repeat(1,10,1) + r6.squeeze() # [N, 3, 25] wsx:(8,30,25) + + pred = pred.contiguous().view(-1, 3, 10, 25) + x_target = x_target.permute(0,4,1,2,3).contiguous().view(-1,3,10,25) + + return x_class, pred[::2], x_target[::2] + + def extract_feature(self, x): + + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + for gcn, importance in zip(self.st_gcn_networks, self.edge_importance): + x, _ = gcn(x, self.A * importance) + + _, c, t, v = x.size() + feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1) + + x = self.fcn(x) + output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1) + + return output, feature + + +class StgcnBlock(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + edge_type=2, + t_kernel_size=1, + stride=1, + padding=True, + dropout=0, + residual=True): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + if padding == True: + padding = ((kernel_size[0] - 1) // 2, 0) + else: + padding = (0,0) + + self.gcn = SpatialGcn(in_channels=in_channels, + out_channels=out_channels, + k_num=kernel_size[1], + edge_type=edge_type, + t_kernel_size=t_kernel_size) + self.tcn = nn.Sequential(nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, + out_channels, + (kernel_size[0], 1), + (stride, 1), + padding), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True)) + + if not residual: + self.residual = lambda x: 0 + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + else: + self.residual = nn.Sequential(nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=(stride, 1)), + nn.BatchNorm2d(out_channels)) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, A, B, lamda_act): + + res = self.residual(x) + x, A = self.gcn(x, A, B, lamda_act) + x = self.tcn(x) + res + + return self.relu(x), A + + +class StgcnReconBlock(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + edge_type=2, + t_kernel_size=1, + stride=1, + padding=True, + dropout=0, + residual=True, + activation='relu'): + super().__init__() + + assert len(kernel_size) == 2 + assert kernel_size[0] % 2 == 1 + + if padding == True: + padding = ((kernel_size[0] - 1) // 2, 0) + else: + padding = (0,0) + + self.gcn_recon = SpatialGcnRecon(in_channels=in_channels, + out_channels=out_channels, + k_num=kernel_size[1], + edge_type=edge_type, + t_kernel_size=t_kernel_size) + self.tcn_recon = nn.Sequential(nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(kernel_size[0], 1), + stride=(stride, 1), + padding=padding, + output_padding=(stride-1,0)), + nn.BatchNorm2d(out_channels), + nn.Dropout(dropout, inplace=True)) + + if not residual: + self.residual = lambda x: 0 + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + else: + self.residual = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=(stride, 1), + output_padding=(stride-1,0)), + nn.BatchNorm2d(out_channels)) + self.relu = nn.ReLU(inplace=True) + self.activation = activation + + def forward(self, x, A, B, lamda_act): + + res = self.residual(x) + x, A = self.gcn_recon(x, A, B, lamda_act) + x = self.tcn_recon(x) + res + if self.activation == 'relu': + x = self.relu(x) + else: + x = x + + return x, A + + +class SpatialGcn(nn.Module): + + def __init__(self, + in_channels, + out_channels, + k_num, + edge_type=2, + t_kernel_size=1, + t_stride=1, + t_padding=0, + t_dilation=1, + bias=True): + super().__init__() + + self.k_num = k_num + self.edge_type = edge_type + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels*k_num, + kernel_size=(t_kernel_size, 1), + padding=(t_padding, 0), + stride=(t_stride, 1), + dilation=(t_dilation, 1), + bias=bias) + + def forward(self, x, A, B, lamda_act): + + x = self.conv(x) + n, kc, t, v = x.size() + x = x.view(n, self.k_num, kc//self.k_num, t, v) + x1 = x[:,:self.k_num-self.edge_type,:,:,:] + x2 = x[:,-self.edge_type:,:,:,:] + x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) + x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) + x_sum = x1+x2*lamda_act + + return x_sum.contiguous(), A + + +class SpatialGcnRecon(nn.Module): + + def __init__(self, in_channels, out_channels, k_num, edge_type=3, + t_kernel_size=1, t_stride=1, t_padding=0, t_outpadding=0, t_dilation=1, + bias=True): + super().__init__() + + self.k_num = k_num + self.edge_type = edge_type + self.deconv = nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels*k_num, + kernel_size=(t_kernel_size, 1), + padding=(t_padding, 0), + output_padding=(t_outpadding, 0), + stride=(t_stride, 1), + dilation=(t_dilation, 1), + bias=bias) + + def forward(self, x, A, B, lamda_act): + + x = self.deconv(x) + n, kc, t, v = x.size() + x = x.view(n, self.k_num, kc//self.k_num, t, v) + x1 = x[:,:self.k_num-self.edge_type,:,:,:] + x2 = x[:,-self.edge_type:,:,:,:] + x1 = torch.einsum('nkctv,kvw->nctw', (x1, A)) + x2 = torch.einsum('nkctv,nkvw->nctw', (x2, B)) + x_sum = x1+x2*lamda_act + + return x_sum.contiguous(), A diff --git a/net/model_poseformer.py b/net/model_poseformer.py new file mode 100644 index 0000000..d702be4 --- /dev/null +++ b/net/model_poseformer.py @@ -0,0 +1,223 @@ +## Our PoseFormer model was revised from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +import math +import logging +from functools import partial +from collections import OrderedDict +from einops import rearrange, repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class PoseTransformer(nn.Module): + def __init__(self, num_frame=9, num_joints=25, in_chans=3, embed_dim_ratio: object = 32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, + num_class=60 + ): + """ ##########hybrid_backbone=None, representation_size=None, + Args: + num_frame (int, tuple): input frame number + num_joints (int, tuple): joints number + in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) + embed_dim_ratio (int): embedding dimension ratio + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + num_class (int): the pose action class amount 30 + """ + super().__init__() + + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio + out_dim = num_joints * 3 #### output dimension is num_joints * 3 + + ### spatial patch embedding + self.Spatial_patch_to_embedding = nn.Linear(3, 32) + self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.Spatial_blocks = nn.ModuleList([ + Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.Spatial_norm = norm_layer(embed_dim_ratio) + self.Temporal_norm = norm_layer(embed_dim) + + ####### A easy way to implement weighted mean + self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1) + + self.head = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim , out_dim), + ) + + # wsx aciton_class_head + self.action_class_head = nn.Conv2d(290, num_class, kernel_size=1) + + # self.data_bn = nn.BatchNorm1d(in_channels * A.size(1)) + self.data_bn = nn.BatchNorm1d(3 * 25) + + + + + + def Spatial_forward_features(self, x): + b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints + x = rearrange(x, 'b c f p -> (b f) p c', ) + + x = self.Spatial_patch_to_embedding(x) + x += self.Spatial_pos_embed + x = self.pos_drop(x) + + for blk in self.Spatial_blocks: + x = blk(x) + + x = self.Spatial_norm(x) + x = rearrange(x, '(b f) w c -> b f (w c)', f=f) + return x + + def forward_features(self, x): + b = x.shape[0] + x += self.Temporal_pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.Temporal_norm(x) + ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame + # x = self.weighted_mean(x) #wsx don't change all frame to one + # x = x.view(b, 1, -1) + return x + + def forward(self, x, x_target): + ''' + # x input shape [170, 81, 17, 2] + x = x.permute(0, 3, 1, 2) #[170, 2, 81, 17] + b, _, _, p = x.shape #[170, 2, 81, 17] b:batch_size p:joint_num + ### now x is [batch_size, 2 channels, receptive frames, joint_num], following image data + ''' + + N, C, T, V, M = x.size() + x = x.permute(0, 4, 3, 1, 2).contiguous() + x = x.view(N * M, V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T) + x = x.permute(0, 1, 3, 4, 2).contiguous() + x = x.view(N * M, C, T, V) + + x = self.Spatial_forward_features(x) + x = self.forward_features(x) # (2n, 290,800) + + # action_class_head + BatchN, FrameN, FutureN = x.size() + x = x.view(BatchN, FrameN, FutureN, 1) + x_class = F.avg_pool2d(x, x.size()[2:]) + x_class = x_class.view(N, M, -1, 1, 1).mean(dim=1) + x_class = self.action_class_head(x_class) + x_class = x_class.view(x_class.size(0), -1) + + + #action_class = x.permute(0,2,1) #[170, 544, 1] + #action_class = self.action_class_head(action_class) + #action_class = torch.squeeze(action_class) + #x = self.head(x) + #x = x.view(b, 1, p, -1) + + x_target = x_target.permute(0, 4, 1, 2, 3).contiguous().view(-1, 3, 10, 25) + + return x_class, x_target[::2] # [170,1,17,3] + diff --git a/net/utils/__init__.py b/net/utils/__init__.py index 8b13789..d3f5a12 100644 --- a/net/utils/__init__.py +++ b/net/utils/__init__.py @@ -1 +1 @@ - + diff --git a/net/utils/adj_learn.py b/net/utils/adj_learn.py index ea8e503..4580873 100644 --- a/net/utils/adj_learn.py +++ b/net/utils/adj_learn.py @@ -1,283 +1,284 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -import numpy as np -from torch.autograd import Variable - - -def my_softmax(input, axis=1): - trans_input = input.transpose(axis, 0).contiguous() - soft_max_1d = F.softmax(trans_input) - return soft_max_1d.transpose(axis, 0) - - -def get_offdiag_indices(num_nodes): - ones = torch.ones(num_nodes, num_nodes) - eye = torch.eye(num_nodes, num_nodes) - offdiag_indices = (ones - eye).nonzero().t() - offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1] - return offdiag_indices, offdiag_indices_ - - -def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): - y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) - if hard: - shape = logits.size() - _, k = y_soft.data.max(-1) - y_hard = torch.zeros(*shape) - if y_soft.is_cuda: - y_hard = y_hard.cuda() - y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) - y = Variable(y_hard - y_soft.data) + y_soft - else: - y = y_soft - return y - - -def gumbel_softmax_sample(logits, tau=1, eps=1e-10): - gumbel_noise = sample_gumbel(logits.size(), eps=eps) - if logits.is_cuda: - gumbel_noise = gumbel_noise.cuda() - y = logits + Variable(gumbel_noise) - return my_softmax(y / tau, axis=-1) - - -def sample_gumbel(shape, eps=1e-10): - uniform = torch.rand(shape).float() - return - torch.log(eps - torch.log(uniform + eps)) - - -def encode_onehot(labels): - classes = set(labels) - classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)} - labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32) - return labels_onehot - - -class MLP(nn.Module): - - def __init__(self, n_in, n_hid, n_out, do_prob=0.): - super().__init__() - - self.fc1 = nn.Linear(n_in, n_hid) - self.fc2 = nn.Linear(n_hid, n_out) - self.bn = nn.BatchNorm1d(n_out) - self.dropout = nn.Dropout(p=do_prob) - - self.init_weights() - - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight.data) - m.bias.data.fill_(0.1) - elif isinstance(m, nn.BatchNorm1d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def batch_norm(self, inputs): - x = inputs.view(inputs.size(0) * inputs.size(1), -1) - x = self.bn(x) - return x.view(inputs.size(0), inputs.size(1), -1) - - def forward(self, inputs): - x = F.elu(self.fc1(inputs)) - x = self.dropout(x) - x = F.elu(self.fc2(x)) - return self.batch_norm(x) - - -class InteractionNet(nn.Module): - - def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): - super().__init__() - - self.factor = factor - self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) - self.mlp2 = MLP(n_hid*2, n_hid, n_hid, do_prob) - self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) - self.mlp4 = MLP(n_hid*3, n_hid, n_hid, do_prob) if self.factor else MLP(n_hid*2, n_hid, n_hid, do_prob) - self.fc_out = nn.Linear(n_hid, n_out) - - self.init_weights() - - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_normal_(m.weight.data) - m.bias.data.fill_(0.1) - - def node2edge(self, x, rel_rec, rel_send): - receivers = torch.matmul(rel_rec, x) - senders = torch.matmul(rel_send, x) - edges = torch.cat([receivers, senders], dim=2) - return edges - - def edge2node(self, x, rel_rec, rel_send): - incoming = torch.matmul(rel_rec.t(), x) - nodes = incoming / incoming.size(1) - return nodes - - def forward(self, inputs, rel_rec, rel_send): # input: [N, v, t, c] = [N, 25, 50, 3] - x = inputs.contiguous() - x = x.view(inputs.size(0), inputs.size(1), -1) # [N, 25, 50, 3] -> [N, 25, 50*3=150] - x = self.mlp1(x) # [N, 25, 150] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] - x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] - x = self.mlp2(x) # [N, 600, 512] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] - x_skip = x - if self.factor: - x = self.edge2node(x, rel_rec, rel_send) # [N, 600, 256] -> [N, 25, 256] - x = self.mlp3(x) # [N, 25, 256] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] - x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] - x = torch.cat((x, x_skip), dim=2) # [N, 600, 512] -> [N, 600, 512]|[N, 600, 256]=[N, 600, 768] - x = self.mlp4(x) # [N, 600, 768] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] - else: - x = self.mlp3(x) - x = torch.cat((x, x_skip), dim=2) - x = self.mlp4(x) - return self.fc_out(x) # [N, 600, 256] -> [N, 600, 3] - - -class InteractionDecoderRecurrent(nn.Module): - - def __init__(self, n_in_node, edge_types, n_hid, do_prob=0., skip_first=True): - super().__init__() - - self.msg_fc1 = nn.ModuleList([nn.Linear(2 * n_hid, n_hid) for _ in range(edge_types)]) - self.msg_fc2 = nn.ModuleList([nn.Linear(n_hid, n_hid) for _ in range(edge_types)]) - self.msg_out_shape = n_hid - self.skip_first_edge_type = skip_first - - self.hidden_r = nn.Linear(n_hid, n_hid, bias=False) - self.hidden_i = nn.Linear(n_hid, n_hid, bias=False) - self.hidden_n = nn.Linear(n_hid, n_hid, bias=False) - - self.input_r = nn.Linear(n_in_node, n_hid, bias=True) # 3 x 256 - self.input_i = nn.Linear(n_in_node, n_hid, bias=True) - self.input_n = nn.Linear(n_in_node, n_hid, bias=True) - - self.out_fc1 = nn.Linear(n_hid, n_hid) - self.out_fc2 = nn.Linear(n_hid, n_hid) - self.out_fc3 = nn.Linear(n_hid, n_in_node) - - self.dropout1 = nn.Dropout(p=do_prob) - self.dropout2 = nn.Dropout(p=do_prob) - self.dropout3 = nn.Dropout(p=do_prob) - - def single_step_forward(self, inputs, rel_rec, rel_send, rel_type, hidden): - receivers = torch.matmul(rel_rec, hidden) - senders = torch.matmul(rel_send, hidden) - pre_msg = torch.cat([receivers, senders], dim=-1) - all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1), self.msg_out_shape) - gpu_id = rel_rec.get_device() - all_msgs = all_msgs.cuda(gpu_id) - if self.skip_first_edge_type: - start_idx = 1 - norm = float(len(self.msg_fc2)) - 1. - else: - start_idx = 0 - norm = float(len(self.msg_fc2)) - for k in range(start_idx, len(self.msg_fc2)): - msg = torch.tanh(self.msg_fc1[k](pre_msg)) - msg = self.dropout1(msg) - msg = torch.tanh(self.msg_fc2[k](msg)) - msg = msg * rel_type[:, :, k:k + 1] - all_msgs += msg / norm - agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) - agg_msgs = agg_msgs.contiguous()/inputs.size(2) - - r = torch.sigmoid(self.input_r(inputs) + self.hidden_r(agg_msgs)) - i = torch.sigmoid(self.input_i(inputs) + self.hidden_i(agg_msgs)) - n = torch.tanh(self.input_n(inputs) + r * self.hidden_n(agg_msgs)) - hidden = (1-i)*n + i*hidden - - pred = self.dropout2(F.relu(self.out_fc1(hidden))) - pred = self.dropout2(F.relu(self.out_fc2(pred))) - pred = self.out_fc3(pred) - pred = inputs + pred - - return pred, hidden - - def forward(self, data, rel_type, rel_rec, rel_send, pred_steps=1, - burn_in=False, burn_in_steps=1, dynamic_graph=False, - encoder=None, temp=None): - inputs = data.transpose(1, 2).contiguous() - time_steps = inputs.size(1) - hidden = torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape) - gpu_id = rel_rec.get_device() - hidden = hidden.cuda(gpu_id) - pred_all = [] - for step in range(0, inputs.size(1) - 1): - if not step % pred_steps: - ins = inputs[:, step, :, :] - else: - ins = pred_all[step - 1] - pred, hidden = self.single_step_forward(ins, rel_rec, rel_send, rel_type, hidden) - pred_all.append(pred) - preds = torch.stack(pred_all, dim=1) - return preds.transpose(1, 2).contiguous() - - -class AdjacencyLearn(nn.Module): - - def __init__(self, n_in_enc, n_hid_enc, edge_types, n_in_dec, n_hid_dec, node_num=25): - super().__init__() - - self.encoder = InteractionNet(n_in=n_in_enc, # 150 - n_hid=n_hid_enc, # 256 - n_out=edge_types, # 3 - do_prob=0.5, - factor=True) - self.decoder = InteractionDecoderRecurrent(n_in_node=n_in_dec, # 256 - edge_types=edge_types, # 3 - n_hid=n_hid_dec, # 256 - do_prob=0.5, - skip_first=True) - self.offdiag_indices, _ = get_offdiag_indices(node_num) - - off_diag = np.ones([node_num, node_num])-np.eye(node_num, node_num) - self.rel_rec = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)) - self.rel_send = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)) - self.dcy = 0.1 - - self.init_weights() - - def init_weights(self): - for m in self.modules(): - if isinstance(m, nn.BatchNorm1d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def forward(self, inputs): # [N, 3, 50, 25, 2] - - N, C, T, V, M = inputs.size() - x = inputs.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 50] - x = x.contiguous().view(N*M, V, C, T).permute(0,1,3,2) # [2N, 25, 50, 3] - - gpu_id = x.get_device() - rel_rec = self.rel_rec.cuda(gpu_id) - rel_send = self.rel_send.cuda(gpu_id) - - self.logits = self.encoder(x, rel_rec, rel_send) - self.N, self.v, self.c = self.logits.size() - self.edges = gumbel_softmax(self.logits, tau=0.5, hard=True) - self.prob = my_softmax(self.logits, -1) - self.outputs = self.decoder(x, self.edges, rel_rec, rel_send, burn_in=False, burn_in_steps=40) - self.offdiag_indices = self.offdiag_indices.cuda(gpu_id) - - A_batch = [] - for i in range(self.N): - A_types = [] - for j in range(1, self.c): - A = torch.sparse.FloatTensor(self.offdiag_indices, self.edges[i,:,j], torch.Size([25, 25])).to_dense().cuda(gpu_id) - A = A + torch.eye(25, 25).cuda(gpu_id) - D = torch.sum(A, dim=0).squeeze().pow(-1)+1e-10 - D = torch.diag(D) - A_ = torch.matmul(A, D)*self.dcy - A_types.append(A_) - A_types = torch.stack(A_types) - A_batch.append(A_types) - self.A_batch = torch.stack(A_batch).cuda(gpu_id) # [N, 2, 25, 25] - - return self.A_batch, self.prob, self.outputs, x +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import numpy as np +from torch.autograd import Variable + + +def my_softmax(input, axis=1): + trans_input = input.transpose(axis, 0).contiguous() + soft_max_1d = F.softmax(trans_input) + return soft_max_1d.transpose(axis, 0) + + +def get_offdiag_indices(num_nodes): + ones = torch.ones(num_nodes, num_nodes) + eye = torch.eye(num_nodes, num_nodes) + offdiag_indices = (ones - eye).nonzero().t() + offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1] + return offdiag_indices, offdiag_indices_ + + +def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): + y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) + if hard: + shape = logits.size() + _, k = y_soft.data.max(-1) + y_hard = torch.zeros(*shape) + if y_soft.is_cuda: + y_hard = y_hard.cuda() + y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) + y = Variable(y_hard - y_soft.data) + y_soft + else: + y = y_soft + return y + + +def gumbel_softmax_sample(logits, tau=1, eps=1e-10): + gumbel_noise = sample_gumbel(logits.size(), eps=eps) + if logits.is_cuda: + gumbel_noise = gumbel_noise.cuda() + y = logits + Variable(gumbel_noise) + return my_softmax(y / tau, axis=-1) + + +def sample_gumbel(shape, eps=1e-10): + uniform = torch.rand(shape).float() + return - torch.log(eps - torch.log(uniform + eps)) + + +def encode_onehot(labels): + classes = set(labels) + classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)} + labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32) + return labels_onehot + + +class MLP(nn.Module): + + def __init__(self, n_in, n_hid, n_out, do_prob=0.): + super().__init__() + + self.fc1 = nn.Linear(n_in, n_hid) + self.fc2 = nn.Linear(n_hid, n_out) + self.bn = nn.BatchNorm1d(n_out) + self.dropout = nn.Dropout(p=do_prob) + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight.data) + m.bias.data.fill_(0.1) + elif isinstance(m, nn.BatchNorm1d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def batch_norm(self, inputs): + x = inputs.view(inputs.size(0) * inputs.size(1), -1) + x = self.bn(x) + return x.view(inputs.size(0), inputs.size(1), -1) + + def forward(self, inputs): + x = F.elu(self.fc1(inputs)) + x = self.dropout(x) + x = F.elu(self.fc2(x)) + return self.batch_norm(x) + + +class InteractionNet(nn.Module): + + def __init__(self, n_in, n_hid, n_out, do_prob=0., factor=True): + super().__init__() + + self.factor = factor + self.mlp1 = MLP(n_in, n_hid, n_hid, do_prob) + self.mlp2 = MLP(n_hid*2, n_hid, n_hid, do_prob) + self.mlp3 = MLP(n_hid, n_hid, n_hid, do_prob) + self.mlp4 = MLP(n_hid*3, n_hid, n_hid, do_prob) if self.factor else MLP(n_hid*2, n_hid, n_hid, do_prob) + self.fc_out = nn.Linear(n_hid, n_out) + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight.data) + m.bias.data.fill_(0.1) + + def node2edge(self, x, rel_rec, rel_send): + receivers = torch.matmul(rel_rec, x) + senders = torch.matmul(rel_send, x) + edges = torch.cat([receivers, senders], dim=2) + return edges + + def edge2node(self, x, rel_rec, rel_send): + incoming = torch.matmul(rel_rec.t(), x) + nodes = incoming / incoming.size(1) + return nodes + + def forward(self, inputs, rel_rec, rel_send): # input: [N, v, t, c] = [N, 25, 50, 3] + x = inputs.contiguous() + x = x.view(inputs.size(0), inputs.size(1), -1) # [N, 25, 50, 3] -> [N, 25, 50*3=150] + x = self.mlp1(x) # [N, 25, 150] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] + x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] + x = self.mlp2(x) # [N, 600, 512] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] + x_skip = x + if self.factor: + x = self.edge2node(x, rel_rec, rel_send) # [N, 600, 256] -> [N, 25, 256] + x = self.mlp3(x) # [N, 25, 256] -> [N, 25, n_hid=256] -> [N, 25, n_out=256] + x = self.node2edge(x, rel_rec, rel_send) # [N, 25, 256] -> [N, 600, 256]|[N, 600, 256]=[N, 600, 512] + x = torch.cat((x, x_skip), dim=2) # [N, 600, 512] -> [N, 600, 512]|[N, 600, 256]=[N, 600, 768] + x = self.mlp4(x) # [N, 600, 768] -> [N, 600, n_hid=256] -> [N, 600, n_out=256] + else: + x = self.mlp3(x) + x = torch.cat((x, x_skip), dim=2) + x = self.mlp4(x) + return self.fc_out(x) # [N, 600, 256] -> [N, 600, 3] + + +class InteractionDecoderRecurrent(nn.Module): + + def __init__(self, n_in_node, edge_types, n_hid, do_prob=0., skip_first=True): + super().__init__() + + self.msg_fc1 = nn.ModuleList([nn.Linear(2 * n_hid, n_hid) for _ in range(edge_types)]) + self.msg_fc2 = nn.ModuleList([nn.Linear(n_hid, n_hid) for _ in range(edge_types)]) + self.msg_out_shape = n_hid + self.skip_first_edge_type = skip_first + + self.hidden_r = nn.Linear(n_hid, n_hid, bias=False) + self.hidden_i = nn.Linear(n_hid, n_hid, bias=False) + self.hidden_n = nn.Linear(n_hid, n_hid, bias=False) + + self.input_r = nn.Linear(n_in_node, n_hid, bias=True) # 3 x 256 + self.input_i = nn.Linear(n_in_node, n_hid, bias=True) + self.input_n = nn.Linear(n_in_node, n_hid, bias=True) + + self.out_fc1 = nn.Linear(n_hid, n_hid) + self.out_fc2 = nn.Linear(n_hid, n_hid) + self.out_fc3 = nn.Linear(n_hid, n_in_node) + + self.dropout1 = nn.Dropout(p=do_prob) + self.dropout2 = nn.Dropout(p=do_prob) + self.dropout3 = nn.Dropout(p=do_prob) + + def single_step_forward(self, inputs, rel_rec, rel_send, rel_type, hidden): + receivers = torch.matmul(rel_rec, hidden) + senders = torch.matmul(rel_send, hidden) + pre_msg = torch.cat([receivers, senders], dim=-1) + all_msgs = torch.zeros(pre_msg.size(0), pre_msg.size(1), self.msg_out_shape) + gpu_id = rel_rec.get_device() + all_msgs = all_msgs.cuda(gpu_id) + if self.skip_first_edge_type: + start_idx = 1 + norm = float(len(self.msg_fc2)) - 1. + else: + start_idx = 0 + norm = float(len(self.msg_fc2)) + for k in range(start_idx, len(self.msg_fc2)): + msg = torch.tanh(self.msg_fc1[k](pre_msg)) + msg = self.dropout1(msg) + msg = torch.tanh(self.msg_fc2[k](msg)) + msg = msg * rel_type[:, :, k:k + 1] + all_msgs += msg / norm + agg_msgs = all_msgs.transpose(-2, -1).matmul(rel_rec).transpose(-2, -1) + agg_msgs = agg_msgs.contiguous()/inputs.size(2) + + r = torch.sigmoid(self.input_r(inputs) + self.hidden_r(agg_msgs)) + i = torch.sigmoid(self.input_i(inputs) + self.hidden_i(agg_msgs)) + n = torch.tanh(self.input_n(inputs) + r * self.hidden_n(agg_msgs)) + hidden = (1-i)*n + i*hidden + + pred = self.dropout2(F.relu(self.out_fc1(hidden))) + pred = self.dropout2(F.relu(self.out_fc2(pred))) + pred = self.out_fc3(pred) + pred = inputs + pred + + return pred, hidden + + def forward(self, data, rel_type, rel_rec, rel_send, pred_steps=1, + burn_in=False, burn_in_steps=1, dynamic_graph=False, + encoder=None, temp=None): + inputs = data.transpose(1, 2).contiguous() + time_steps = inputs.size(1) + hidden = torch.zeros(inputs.size(0), inputs.size(2), self.msg_out_shape) + gpu_id = rel_rec.get_device() + hidden = hidden.cuda(gpu_id) + pred_all = [] + for step in range(0, inputs.size(1) - 1): + if not step % pred_steps: + ins = inputs[:, step, :, :] + else: + ins = pred_all[step - 1] + pred, hidden = self.single_step_forward(ins, rel_rec, rel_send, rel_type, hidden) + pred_all.append(pred) + preds = torch.stack(pred_all, dim=1) + return preds.transpose(1, 2).contiguous() + + +class AdjacencyLearn(nn.Module): + + def __init__(self, n_in_enc, n_hid_enc, edge_types, n_in_dec, n_hid_dec, node_num=25): + super().__init__() + + self.encoder = InteractionNet(n_in=n_in_enc, # 150 + n_hid=n_hid_enc, # 256 + n_out=edge_types, # 3 + do_prob=0.5, + factor=True) + self.decoder = InteractionDecoderRecurrent(n_in_node=n_in_dec, # 256 + edge_types=edge_types, # 3 + n_hid=n_hid_dec, # 256 + do_prob=0.5, + skip_first=True) + self.offdiag_indices, _ = get_offdiag_indices(node_num) + + off_diag = np.ones([node_num, node_num])-np.eye(node_num, node_num) + self.rel_rec = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)) + self.rel_send = torch.FloatTensor(np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)) + self.dcy = 0.1 + + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, inputs): # [N, 3, 50, 25, 2] + print("enter AdjacencyLearn") + + N, C, T, V, M = inputs.size() + x = inputs.permute(0, 4, 3, 1, 2).contiguous() # [N, 2, 25, 3, 50] + x = x.contiguous().view(N*M, V, C, T).permute(0,1,3,2) # [2N, 25, 50, 3] + + gpu_id = x.get_device() + rel_rec = self.rel_rec.cuda(gpu_id) + rel_send = self.rel_send.cuda(gpu_id) + + self.logits = self.encoder(x, rel_rec, rel_send) + self.N, self.v, self.c = self.logits.size() + self.edges = gumbel_softmax(self.logits, tau=0.5, hard=True) + self.prob = my_softmax(self.logits, -1) + self.outputs = self.decoder(x, self.edges, rel_rec, rel_send, burn_in=False, burn_in_steps=40) + self.offdiag_indices = self.offdiag_indices.cuda(gpu_id) + + A_batch = [] + for i in range(self.N): + A_types = [] + for j in range(1, self.c): + A = torch.sparse.FloatTensor(self.offdiag_indices, self.edges[i,:,j], torch.Size([25, 25])).to_dense().cuda(gpu_id) + A = A + torch.eye(25, 25).cuda(gpu_id) + D = torch.sum(A, dim=0).squeeze().pow(-1)+1e-10 + D = torch.diag(D) + A_ = torch.matmul(A, D)*self.dcy + A_types.append(A_) + A_types = torch.stack(A_types) + A_batch.append(A_types) + self.A_batch = torch.stack(A_batch).cuda(gpu_id) # [N, 2, 25, 25] + + return self.A_batch, self.prob, self.outputs, x diff --git a/net/utils/graph.py b/net/utils/graph.py index 52bb13e..708f0b8 100644 --- a/net/utils/graph.py +++ b/net/utils/graph.py @@ -1,129 +1,129 @@ -import numpy as np - -class Graph(): - - def __init__(self, - layout='openpose', - strategy='uniform', - max_hop=2, - dilation=1): - self.max_hop = max_hop - self.dilation = dilation - - self.get_edge(layout) - self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop) - self.get_adjacency(strategy) - - def __str__(self): - return self.A - - def get_edge(self, layout): - if layout == 'openpose': - self.num_node = 18 - self_link = [(i, i) for i in range(self.num_node)] - neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), - (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), - (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] - self.edge = self_link + neighbor_link - self.center = 1 - elif layout == 'ntu-rgb+d': - self.num_node = 25 - self_link = [(i, i) for i in range(self.num_node)] - neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), - (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), - (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), - (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), - (22, 23), (23, 8), (24, 25), (25, 12)] - neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] - self.edge = self_link + neighbor_link - self.center = 21 - 1 - elif layout == 'ntu_edge': - self.num_node = 24 - self_link = [(i, i) for i in range(self.num_node)] - neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), - (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), - (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), - (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), - (23, 24), (24, 12)] - neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] - self.edge = self_link + neighbor_link - self.center = 2 - else: - raise ValueError("Do Not Exist This Layout.") - - def get_adjacency(self, strategy): - valid_hop = range(0, self.max_hop + 1, self.dilation) - adjacency = np.zeros((self.num_node, self.num_node)) - for hop in valid_hop: - adjacency[self.hop_dis == hop] = 1 - normalize_adjacency = normalize_digraph(adjacency) - - if strategy == 'uniform': - A = np.zeros((1, self.num_node, self.num_node)) - A[0] = normalize_adjacency - self.A = A - elif strategy == 'distance': - A = np.zeros((len(valid_hop), self.num_node, self.num_node)) - for i, hop in enumerate(valid_hop): - A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] - self.A = A - elif strategy == 'spatial': - A = [] - for hop in valid_hop: - a_root = np.zeros((self.num_node, self.num_node)) - a_close = np.zeros((self.num_node, self.num_node)) - a_further = np.zeros((self.num_node, self.num_node)) - for i in range(self.num_node): - for j in range(self.num_node): - if self.hop_dis[j, i] == hop: - if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]: - a_root[j, i] = normalize_adjacency[j, i] - elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]: - a_close[j, i] = normalize_adjacency[j, i] - else: - a_further[j, i] = normalize_adjacency[j, i] - if hop == 0: - A.append(a_root) - else: - A.append(a_root + a_close) - A.append(a_further) - A = np.stack(A) - self.A = A - else: - raise ValueError("Do Not Exist This Strategy") - - -def get_hop_distance(num_node, edge, max_hop=1): - A = np.zeros((num_node, num_node)) - for i, j in edge: - A[j, i] = 1 - A[i, j] = 1 - - hop_dis = np.zeros((num_node, num_node)) + np.inf - transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] - arrive_mat = (np.stack(transfer_mat) > 0) - for d in range(max_hop, -1, -1): - hop_dis[arrive_mat[d]] = d - return hop_dis - - -def normalize_digraph(A): - Dl = np.sum(A, 0) - num_node = A.shape[0] - Dn = np.zeros((num_node, num_node)) - for i in range(num_node): - if Dl[i] > 0: - Dn[i, i] = Dl[i]**(-1) - AD = np.dot(A, Dn) - return AD - - -def normalize_undigraph(A): - Dl = np.sum(A, 0) - num_node = A.shape[0] - Dn = np.zeros((num_node, num_node)) - for i in range(num_node): - if Dl[i] > 0: - Dn[i, i] = Dl[i]**(-0.5) - DAD = np.dot(np.dot(Dn, A), Dn) +import numpy as np + +class Graph(): + + def __init__(self, + layout='openpose', + strategy='uniform', + max_hop=2, + dilation=1): + self.max_hop = max_hop + self.dilation = dilation + + self.get_edge(layout) + self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop) + self.get_adjacency(strategy) + + def __str__(self): + return self.A + + def get_edge(self, layout): + if layout == 'openpose': + self.num_node = 18 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12, 11), + (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1), + (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)] + self.edge = self_link + neighbor_link + self.center = 1 + elif layout == 'ntu-rgb+d': + self.num_node = 25 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), + (6, 5), (7, 6), (8, 7), (9, 21), (10, 9), + (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), + (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), + (22, 23), (23, 8), (24, 25), (25, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 21 - 1 + elif layout == 'ntu_edge': + self.num_node = 24 + self_link = [(i, i) for i in range(self.num_node)] + neighbor_1base = [(1, 2), (3, 2), (4, 3), (5, 2), (6, 5), (7, 6), + (8, 7), (9, 2), (10, 9), (11, 10), (12, 11), + (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 22), (22, 8), + (23, 24), (24, 12)] + neighbor_link = [(i - 1, j - 1) for (i, j) in neighbor_1base] + self.edge = self_link + neighbor_link + self.center = 2 + else: + raise ValueError("Do Not Exist This Layout.") + + def get_adjacency(self, strategy): + valid_hop = range(0, self.max_hop + 1, self.dilation) + adjacency = np.zeros((self.num_node, self.num_node)) + for hop in valid_hop: + adjacency[self.hop_dis == hop] = 1 + normalize_adjacency = normalize_digraph(adjacency) + + if strategy == 'uniform': + A = np.zeros((1, self.num_node, self.num_node)) + A[0] = normalize_adjacency + self.A = A + elif strategy == 'distance': + A = np.zeros((len(valid_hop), self.num_node, self.num_node)) + for i, hop in enumerate(valid_hop): + A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis == hop] + self.A = A + elif strategy == 'spatial': + A = [] + for hop in valid_hop: + a_root = np.zeros((self.num_node, self.num_node)) + a_close = np.zeros((self.num_node, self.num_node)) + a_further = np.zeros((self.num_node, self.num_node)) + for i in range(self.num_node): + for j in range(self.num_node): + if self.hop_dis[j, i] == hop: + if self.hop_dis[j, self.center] == self.hop_dis[i, self.center]: + a_root[j, i] = normalize_adjacency[j, i] + elif self.hop_dis[j, self.center] > self.hop_dis[i, self.center]: + a_close[j, i] = normalize_adjacency[j, i] + else: + a_further[j, i] = normalize_adjacency[j, i] + if hop == 0: + A.append(a_root) + else: + A.append(a_root + a_close) + A.append(a_further) + A = np.stack(A) + self.A = A + else: + raise ValueError("Do Not Exist This Strategy") + + +def get_hop_distance(num_node, edge, max_hop=1): + A = np.zeros((num_node, num_node)) + for i, j in edge: + A[j, i] = 1 + A[i, j] = 1 + + hop_dis = np.zeros((num_node, num_node)) + np.inf + transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)] + arrive_mat = (np.stack(transfer_mat) > 0) + for d in range(max_hop, -1, -1): + hop_dis[arrive_mat[d]] = d + return hop_dis + + +def normalize_digraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-1) + AD = np.dot(A, Dn) + return AD + + +def normalize_undigraph(A): + Dl = np.sum(A, 0) + num_node = A.shape[0] + Dn = np.zeros((num_node, num_node)) + for i in range(num_node): + if Dl[i] > 0: + Dn[i, i] = Dl[i]**(-0.5) + DAD = np.dot(np.dot(Dn, A), Dn) return DAD \ No newline at end of file diff --git a/net/utils/utils_adj.py b/net/utils/utils_adj.py index c6a0218..034b64c 100644 --- a/net/utils/utils_adj.py +++ b/net/utils/utils_adj.py @@ -1,48 +1,48 @@ -import os -import numpy as np -import torch -import torch.utils.data -import torch.nn.functional as F -from torch.autograd import Variable - - -def my_softmax(input, axis=1): - trans_input = input.transpose(axis, 0).contiguous() - soft_max_1d = F.softmax(trans_input) - return soft_max_1d.transpose(axis, 0) - - -def get_offdiag_indices(num_nodes): - ones = torch.ones(num_nodes, num_nodes) - eye = torch.eye(num_nodes, num_nodes) - offdiag_indices = (ones - eye).nonzero().t() - offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1] - return offdiag_indices, offdiag_indices_ - - -def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): - y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) - if hard: - shape = logits.size() - _, k = y_soft.data.max(-1) - y_hard = torch.zeros(*shape) - if y_soft.is_cuda: - y_hard = y_hard.cuda() - y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) - y = Variable(y_hard - y_soft.data) + y_soft - else: - y = y_soft - return y - - -def gumbel_softmax_sample(logits, tau=1, eps=1e-10): - gumbel_noise = sample_gumbel(logits.size(), eps=eps) - if logits.is_cuda: - gumbel_noise = gumbel_noise.cuda() - y = logits + Variable(gumbel_noise) - return my_softmax(y / tau, axis=-1) - - -def sample_gumbel(shape, eps=1e-10): - uniform = torch.rand(shape).float() +import os +import numpy as np +import torch +import torch.utils.data +import torch.nn.functional as F +from torch.autograd import Variable + + +def my_softmax(input, axis=1): + trans_input = input.transpose(axis, 0).contiguous() + soft_max_1d = F.softmax(trans_input) + return soft_max_1d.transpose(axis, 0) + + +def get_offdiag_indices(num_nodes): + ones = torch.ones(num_nodes, num_nodes) + eye = torch.eye(num_nodes, num_nodes) + offdiag_indices = (ones - eye).nonzero().t() + offdiag_indices_ = offdiag_indices[0] * num_nodes + offdiag_indices[1] + return offdiag_indices, offdiag_indices_ + + +def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10): + y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps) + if hard: + shape = logits.size() + _, k = y_soft.data.max(-1) + y_hard = torch.zeros(*shape) + if y_soft.is_cuda: + y_hard = y_hard.cuda() + y_hard = y_hard.zero_().scatter_(-1, k.view(shape[:-1] + (1,)), 1.0) + y = Variable(y_hard - y_soft.data) + y_soft + else: + y = y_soft + return y + + +def gumbel_softmax_sample(logits, tau=1, eps=1e-10): + gumbel_noise = sample_gumbel(logits.size(), eps=eps) + if logits.is_cuda: + gumbel_noise = gumbel_noise.cuda() + y = logits + Variable(gumbel_noise) + return my_softmax(y / tau, axis=-1) + + +def sample_gumbel(shape, eps=1e-10): + uniform = torch.rand(shape).float() return - torch.log(eps - torch.log(U + eps)) \ No newline at end of file diff --git a/pip_req.txt b/pip_req.txt new file mode 100644 index 0000000..aae02c0 --- /dev/null +++ b/pip_req.txt @@ -0,0 +1,15 @@ +argparse==1.4.0 +cached-property==1.5.2 +dataclasses==0.8 +h5py==3.1.0 +imageio==2.9.0 +numpy==1.19.5 +opencv-python==4.5.1.48 +pyyaml==5.4.1 +scikit-video==1.1.11 +scipy==1.5.4 +torch==1.7.1 +torchvision==0.9.0 +tqdm==4.60.0 +typing-extensions==3.7.4.3 + diff --git a/processor/__init__.py b/processor/__init__.py index 8b13789..d3f5a12 100644 --- a/processor/__init__.py +++ b/processor/__init__.py @@ -1 +1 @@ - + diff --git a/processor/gpu.py b/processor/gpu.py new file mode 100644 index 0000000..e086d4c --- /dev/null +++ b/processor/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/processor/io.py b/processor/io.py index fb9e4f8..7e7e57c 100644 --- a/processor/io.py +++ b/processor/io.py @@ -1,116 +1,118 @@ -import sys -import os -import argparse -import yaml -import numpy as np - -import torch -import torch.nn as nn - -import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class - - -class IO(): - - def __init__(self, argv=None): - - self.load_arg(argv) - self.init_environment() - self.load_model() - self.load_weights() - self.gpu() - - def load_arg(self, argv=None): - parser = self.get_parser() - - # load arg form config file - p = parser.parse_args(argv) - if p.config is not None: - # load config file - with open(p.config, 'r') as f: - default_arg = yaml.load(f) - - # update parser from config file - key = vars(p).keys() - for k in default_arg.keys(): - if k not in key: - print('Unknown Arguments: {}'.format(k)) - assert k in key - - parser.set_defaults(**default_arg) - - self.arg = parser.parse_args(argv) - - def init_environment(self): - self.save_dir = os.path.join(self.arg.work_dir, - self.arg.max_hop_dir, - self.arg.lamda_act_dir) - self.io = torchlight.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) - self.io.save_arg(self.arg) - - # gpu - if self.arg.use_gpu: - gpus = torchlight.visible_gpu(self.arg.device) - torchlight.occupy_gpu(gpus) - self.gpus = gpus - self.dev = "cuda:0" - else: - self.dev = "cpu" - - def load_model(self): - self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) - self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) - - def load_weights(self): - if self.arg.weights1: - self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) - self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) - - def gpu(self): - # move modules to gpu - self.model1 = self.model1.to(self.dev) - self.model2 = self.model2.to(self.dev) - for name, value in vars(self).items(): - cls_name = str(value.__class__) - if cls_name.find('torch.nn.modules') != -1: - setattr(self, name, value.to(self.dev)) - - # model parallel - if self.arg.use_gpu and len(self.gpus) > 1: - self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) - self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) - - def start(self): - self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) - - @staticmethod - def get_parser(add_help=False): - - #region arguments yapf: disable - # parameter priority: command line > config > default - parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') - - parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') - parser.add_argument('-c', '--config', default=None, help='path to the configuration file') - - # processor - parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') - parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') - - # visulize and debug - parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') - parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') - - # model - parser.add_argument('--model1', default=None, help='the model will be used') - parser.add_argument('--model2', default=None, help='the model will be used') - parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--weights', default=None, help='the weights for network initialization') - parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') - #endregion yapf: enable - - return parser +import sys +import os +import argparse +import yaml +import numpy as np + +import torch +import torch.nn as nn + +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + + +class IO(): + + def __init__(self, argv=None): + + self.load_arg(argv) + self.init_environment() + self.load_model() + self.load_weights() + self.gpu() + + def load_arg(self, argv=None): + parser = self.get_parser() + + # load arg form config file + p = parser.parse_args(argv) + if p.config is not None: + # load config file + with open(p.config, 'r') as f: + default_arg = yaml.load(f) + + # update parser from config file + key = vars(p).keys() + for k in default_arg.keys(): + if k not in key: + print('Unknown Arguments: {}'.format(k)) + assert k in key + + parser.set_defaults(**default_arg) + + self.arg = parser.parse_args(argv) + + def init_environment(self): + self.save_dir = os.path.join(self.arg.work_dir, + self.arg.max_hop_dir, + self.arg.lamda_act_dir) + self.io = torchlight.io.IO(self.save_dir, save_log=self.arg.save_log, print_log=self.arg.print_log) + self.io.save_arg(self.arg) + + # gpu + if self.arg.use_gpu: + gpus = torchlight.gpu.visible_gpu(self.arg.device) + #torchlight.occupy_gpu(gpus) + self.gpus = gpus + self.dev = "cuda:0" + else: + self.dev = "cpu" + + def load_model(self): + self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) + self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) + + def load_weights(self): + if self.arg.weights1: + self.model1 = self.io.load_weights(self.model1, self.arg.weights1, self.arg.ignore_weights) + self.model2 = self.io.load_weights(self.model2, self.arg.weights2, self.arg.ignore_weights) + #self.model3 = self.io.load_weights(self.model3, self.arg.weights3, self.arg.ignore_weights) + + def gpu(self): + # move modules to gpu + self.model1 = self.model1.to(self.dev) + self.model2 = self.model2.to(self.dev) + self.model3 = self.model3.to(self.dev) + for name, value in vars(self).items(): + cls_name = str(value.__class__) + if cls_name.find('torch.nn.modules') != -1: + setattr(self, name, value.to(self.dev)) + + # model parallel + if self.arg.use_gpu and len(self.gpus) > 1: + self.model1 = nn.DataParallel(self.model1, device_ids=self.gpus) + self.model2 = nn.DataParallel(self.model2, device_ids=self.gpus) + + def start(self): + self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) + + @staticmethod + def get_parser(add_help=False): + + #region arguments yapf: disable + # parameter priority: command line > config > default + parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor') + + parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') + parser.add_argument('-c', '--config', default=None, help='path to the configuration file') + + # processor + parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') + parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') + + # visulize and debug + parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') + parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') + + # model + parser.add_argument('--model1', default=None, help='the model will be used') + parser.add_argument('--model2', default=None, help='the model will be used') + parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--weights', default=None, help='the weights for network initialization') + parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') + #endregion yapf: enable + + return parser diff --git a/processor/processor.py b/processor/processor.py index 03fc2cf..045c764 100644 --- a/processor/processor.py +++ b/processor/processor.py @@ -1,186 +1,212 @@ -import sys -import argparse -import yaml -import numpy as np - -import torch -import torch.nn as nn -import torch.optim as optim - -import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class - -from .io import IO - - -class Processor(IO): - - def __init__(self, argv=None): - - self.load_arg(argv) - self.init_environment() - self.load_model() - self.load_weights() - self.gpu() - self.load_data() - - def init_environment(self): - - super().init_environment() - self.result = dict() - self.iter_info = dict() - self.epoch_info = dict() - self.meta_info = dict(epoch=0, iter=0) - - - def load_data(self): - Feeder = import_class(self.arg.feeder) - if 'debug' not in self.arg.train_feeder_args: - self.arg.train_feeder_args['debug'] = self.arg.debug - self.data_loader = dict() - if self.arg.phase == 'train': - self.data_loader['train'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.train_feeder_args), - batch_size=self.arg.batch_size, - shuffle=True, - num_workers=self.arg.num_worker, - drop_last=True) - if self.arg.test_feeder_args: - self.data_loader['test'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.test_feeder_args), - batch_size=self.arg.test_batch_size, - shuffle=False, - num_workers=self.arg.num_worker) - - def show_epoch_info(self): - for k, v in self.epoch_info.items(): - self.io.print_log('\t{}: {}'.format(k, v)) - if self.arg.pavi_log: - self.io.log('train', self.meta_info['iter'], self.epoch_info) - - def show_iter_info(self): - if self.meta_info['iter'] % self.arg.log_interval == 0: - info ='\tIter {} Done.'.format(self.meta_info['iter']) - for k, v in self.iter_info.items(): - if isinstance(v, float): - info = info + ' | {}: {:.4f}'.format(k, v) - else: - info = info + ' | {}: {}'.format(k, v) - - self.io.print_log(info) - - if self.arg.pavi_log: - self.io.log('train', self.meta_info['iter'], self.iter_info) - - def train(self): - for _ in range(100): - self.iter_info['loss'] = 0 - self.iter_info['loss_class'] = 0 - self.iter_info['loss_recon'] = 0 - self.show_iter_info() - self.meta_info['iter'] += 1 - self.epoch_info['mean_loss'] = 0 - self.epoch_info['mean_loss_class'] = 0 - self.epoch_info['mean_loss_recon'] = 0 - self.show_epoch_info() - - def test(self): - for _ in range(100): - self.iter_info['loss'] = 1 - self.iter_info['loss_class'] = 1 - self.iter_info['loss_recon'] = 1 - self.show_iter_info() - self.epoch_info['mean_loss'] = 1 - self.epoch_info['mean_loss_class'] = 1 - self.epoch_info['mean_loss_recon'] = 1 - self.show_epoch_info() - - def start(self): - self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) - - if self.arg.phase == 'train': - for epoch in range(self.arg.start_epoch, self.arg.num_epoch): - self.meta_info['epoch'] = epoch - - if epoch < 10: - self.io.print_log('Training epoch: {}'.format(epoch)) - self.train(training_A=True) - self.io.print_log('Done.') - else: - self.io.print_log('Training epoch: {}'.format(epoch)) - self.train(training_A=False) - self.io.print_log('Done.') - - # save model - if ((epoch + 1) % self.arg.save_interval == 0) or (epoch + 1 == self.arg.num_epoch): - filename1 = 'epoch{}_model1.pt'.format(epoch) - self.io.save_model(self.model1, filename1) - filename2 = 'epoch{}_model2.pt'.format(epoch) - self.io.save_model(self.model2, filename2) - - # evaluation - if ((epoch + 1) % self.arg.eval_interval == 0) or (epoch + 1 == self.arg.num_epoch): - self.io.print_log('Eval epoch: {}'.format(epoch)) - if epoch <= 10: - self.test(testing_A=True) - else: - self.test(testing_A=False) - self.io.print_log('Done.') - - - elif self.arg.phase == 'test': - if self.arg.weights2 is None: - raise ValueError('Please appoint --weights.') - self.io.print_log('Model: {}.'.format(self.arg.model2)) - self.io.print_log('Weights: {}.'.format(self.arg.weights2)) - - self.io.print_log('Evaluation Start:') - self.test(testing_A=False, save_feature=True) - self.io.print_log('Done.\n') - - if self.arg.save_result: - result_dict = dict( - zip(self.data_loader['test'].dataset.sample_name, - self.result)) - self.io.save_pkl(result_dict, 'test_result.pkl') - - - @staticmethod - def get_parser(add_help=False): - - parser = argparse.ArgumentParser( add_help=add_help, description='Base Processor') - - parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') - parser.add_argument('-c', '--config', default=None, help='path to the configuration file') - - parser.add_argument('--phase', default='train', help='must be train or test') - parser.add_argument('--save_result', type=str2bool, default=False, help='if ture, the output of the model will be stored') - parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch') - parser.add_argument('--num_epoch', type=int, default=80, help='stop training in which epoch') - parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') - parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') - - parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)') - parser.add_argument('--save_interval', type=int, default=1, help='the interval for storing models (#iteration)') - parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#iteration)') - parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') - parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') - parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not') - - parser.add_argument('--feeder', default='feeder.feeder', help='data loader will be used') - parser.add_argument('--num_worker', type=int, default=4, help='the number of worker per gpu for data loader') - parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for training') - parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for test') - parser.add_argument('--batch_size', type=int, default=256, help='training batch size') - parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size') - parser.add_argument('--debug', action="store_true", help='less data, faster loading') - - parser.add_argument('--model1', default=None, help='the model will be used') - parser.add_argument('--model2', default=None, help='the model will be used') - parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') - parser.add_argument('--weights1', default=None, help='the weights for network initialization') - parser.add_argument('--weights2', default=None, help='the weights for network initialization') - parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') - - return parser +import sys +import argparse +import yaml +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +import matplotlib.pyplot as plt +import pickle + +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + +from .io import IO + + +class Processor(IO): + + def __init__(self, argv=None): + + self.load_arg(argv) + self.init_environment() + self.load_model() + #self.load_weights() + self.gpu() + self.load_data() + + def init_environment(self): + + super().init_environment() + self.result = dict() + self.iter_info = dict() + self.epoch_info = dict() + self.meta_info = dict(epoch=0, iter=0) + self.epoch_loss_class_train = [] + self.epoch_loss_class_test = [] + + + def load_data(self): + Feeder = import_class(self.arg.feeder) + if 'debug' not in self.arg.train_feeder_args: + self.arg.train_feeder_args['debug'] = self.arg.debug + self.data_loader = dict() + if self.arg.phase == 'train': + self.data_loader['train'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.train_feeder_args), + batch_size=self.arg.batch_size, + shuffle=True, + num_workers=self.arg.num_worker, + drop_last=True) + if self.arg.test_feeder_args: + self.data_loader['test'] = torch.utils.data.DataLoader(dataset=Feeder(**self.arg.test_feeder_args), + batch_size=self.arg.test_batch_size, + shuffle=False, + num_workers=self.arg.num_worker) + + def show_epoch_info(self): + for k, v in self.epoch_info.items(): + self.io.print_log('\t{}: {}'.format(k, v)) + if self.arg.pavi_log: + self.io.log('train', self.meta_info['iter'], self.epoch_info) + + + def show_epoch_curl(self, epoch): + plt.figure() + epoch_x = np.arange(10, len(loss_class_value)) + 1 + plt.plot(epoch_x, loss_class_value[3:], '--', color='C0') + plt.legend(['action_class_train']) + plt.ylabel('action CrossEntropyLoss') + plt.xlabel('Epoch') + plt.xlim((3, epoch)) + plt.savefig(os.path.join('loss_action_class_task.png')) + plt.close() + + def show_iter_info(self): + if self.meta_info['iter'] % self.arg.log_interval == 0: + info ='\tIter {} Done.'.format(self.meta_info['iter']) + for k, v in self.iter_info.items(): + if isinstance(v, float): + info = info + ' | {}: {:.4f}'.format(k, v) + else: + info = info + ' | {}: {}'.format(k, v) + + self.io.print_log(info) + + if self.arg.pavi_log: + self.io.log('train', self.meta_info['iter'], self.iter_info) + + def train(self): + for _ in range(300): + self.iter_info['loss'] = 0 + self.iter_info['loss_class'] = 0 + self.iter_info['loss_recon'] = 0 + self.show_iter_info() + self.meta_info['iter'] += 1 + self.epoch_info['mean_loss'] = 0 + self.epoch_info['mean_loss_class'] = 0 + self.epoch_info['mean_loss_recon'] = 0 + self.show_epoch_info() + + def test(self): + for _ in range(100): + self.iter_info['loss'] = 1 + self.iter_info['loss_class'] = 1 + self.iter_info['loss_recon'] = 1 + self.show_iter_info() + self.epoch_info['mean_loss'] = 1 + self.epoch_info['mean_loss_class'] = 1 + self.epoch_info['mean_loss_recon'] = 1 + self.show_epoch_info() + + def start(self): + self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg)))) + + if self.arg.phase == 'train': + for epoch in range(self.arg.start_epoch, self.arg.num_epoch): + self.meta_info['epoch'] = epoch + + if epoch < 10: + self.io.print_log('Training epoch: {}'.format(epoch)) + self.train(training_A=True) + self.io.print_log('Done.') + else: + self.io.print_log('Training epoch: {}'.format(epoch)) + self.train(training_A=False) + self.io.print_log('Done.') + + # save model + if ((epoch + 1) % self.arg.save_interval == 0) or (epoch + 1 == self.arg.num_epoch): + """ + filename1 = 'epoch{}_model1.pt'.format(epoch) + self.io.save_model(self.model1, filename1) + filename2 = 'epoch{}_model2.pt'.format(epoch) + self.io.save_model(self.model2, filename2) + """ + filename3 = 'epoch{}_model3.pt'.format(epoch) + self.io.save_model(self.model3, filename3) + + with open("epoch_loss_class_train.txt", "w") as outfile: + for item in self.epoch_loss_class_train: + outfile.write("{}: {}\n".format(self.epoch_info, item)) + + # evaluation + if ((epoch + 1) % self.arg.eval_interval == 0) or (epoch + 1 == self.arg.num_epoch): + self.io.print_log('Eval epoch: {}'.format(epoch)) + self.test(testing_A=False) + self.io.print_log('Done.') + + with open("epoch_loss_class_test.txt", "w") as outfile: + for item in self.epoch_loss_class_test: + outfile.write("{}: {}\n".format(self.epoch_info, item)) + + + + elif self.arg.phase == 'test': + if self.arg.weights2 is None: + raise ValueError('Please appoint --weights.') + self.io.print_log('Model: {}.'.format(self.arg.model2)) + self.io.print_log('Weights: {}.'.format(self.arg.weights2)) + + self.io.print_log('Evaluation Start:') + self.test(testing_A=False, save_feature=True) + self.io.print_log('Done.\n') + + if self.arg.save_result: + result_dict = dict( + zip(self.data_loader['test'].dataset.sample_name, + self.result)) + self.io.save_pkl(result_dict, 'test_result.pkl') + + + @staticmethod + def get_parser(add_help=False): + + parser = argparse.ArgumentParser( add_help=add_help, description='Base Processor') + + parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results') + parser.add_argument('-c', '--config', default=None, help='path to the configuration file') + + parser.add_argument('--phase', default='train', help='must be train or test') + parser.add_argument('--save_result', type=str2bool, default=False, help='if ture, the output of the model will be stored') + parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch') + parser.add_argument('--num_epoch', type=int, default=80, help='stop training in which epoch') + parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not') + parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing') + + parser.add_argument('--log_interval', type=int, default=100, help='the interval for printing messages (#iteration)') + parser.add_argument('--save_interval', type=int, default=1, help='the interval for storing models (#iteration)') + parser.add_argument('--eval_interval', type=int, default=5, help='the interval for evaluating models (#iteration)') + parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not') + parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not') + parser.add_argument('--pavi_log', type=str2bool, default=False, help='logging on pavi or not') + + parser.add_argument('--feeder', default='feeder.feeder', help='data loader will be used') + parser.add_argument('--num_worker', type=int, default=4, help='the number of worker per gpu for data loader') + parser.add_argument('--train_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for training') + parser.add_argument('--test_feeder_args', action=DictAction, default=dict(), help='the arguments of data loader for test') + parser.add_argument('--batch_size', type=int, default=256, help='training batch size') + parser.add_argument('--test_batch_size', type=int, default=256, help='test batch size') + parser.add_argument('--debug', action="store_true", help='less data, faster loading') + + parser.add_argument('--model1', default=None, help='the model will be used') + parser.add_argument('--model2', default=None, help='the model will be used') + parser.add_argument('--model1_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--model2_args', action=DictAction, default=dict(), help='the arguments of model') + parser.add_argument('--weights1', default=None, help='the weights for network initialization') + parser.add_argument('--weights2', default=None, help='the weights for network initialization') + parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization') + + return parser diff --git a/processor/recognition.py b/processor/recognition.py index d3905af..a17fe90 100644 --- a/processor/recognition.py +++ b/processor/recognition.py @@ -1,315 +1,376 @@ -import sys -import os -import argparse -import yaml -import numpy as np - -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt - -import torch -import torch.nn as nn -import torch.optim as optim - -import torchlight -from torchlight import str2bool -from torchlight import DictAction -from torchlight import import_class - -from .processor import Processor - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv1d') != -1: - m.weight.data.normal_(0.0, 0.02) - if m.bias is not None: - m.bias.data.fill_(0) - elif classname.find('Conv2d') != -1: - m.weight.data.normal_(0.0, 0.02) - if m.bias is not None: - m.bias.data.fill_(0) - elif classname.find('BatchNorm') != -1: - m.weight.data.normal_(1.0, 0.02) - m.bias.data.fill_(0) - - -class REC_Processor(Processor): - - def load_model(self): - self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) - self.model1.apply(weights_init) - self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) - - self.loss_class = nn.CrossEntropyLoss() - self.loss_pred = nn.MSELoss() - self.w_pred = 0.01 - - prior = np.array([0.95, 0.05/2, 0.05/2]) - self.log_prior = torch.FloatTensor(np.log(prior)) - self.log_prior = torch.unsqueeze(torch.unsqueeze(self.log_prior, 0), 0) - - self.load_optimizer() - - def load_optimizer(self): - if self.arg.optimizer == 'SGD': - self.optimizer1 = optim.SGD(params=self.model1.parameters(), - lr=self.arg.base_lr1, - momentum=0.9, - nesterov=self.arg.nesterov, - weight_decay=self.arg.weight_decay) - elif self.arg.optimizer == 'Adam': - self.optimizer1 = optim.Adam(params=self.model1.parameters(), - lr=self.arg.base_lr1, - weight_decay=self.arg.weight_decay) - else: - raise ValueError() - self.optimizer2 = optim.Adam(params=self.model2.parameters(), - lr=self.arg.base_lr2) - - def adjust_lr(self): - if self.arg.optimizer == 'SGD' and self.arg.step: - lr = self.arg.base_lr1 * (0.1**np.sum(self.meta_info['epoch']>= np.array(self.arg.step))) - for param_group in self.optimizer1.param_groups: - param_group['lr'] = lr - self.lr = lr - else: - self.lr = self.arg.base_lr1 - self.lr2 = self.arg.base_lr2 - - def nll_gaussian(self, preds, target, variance, add_const=False): - neg_log_p = ((preds-target)**2/(2*variance)) - if add_const: - const = 0.5*np.log(2*np.pi*variance) - neg_log_p += const - return neg_log_p.sum() / (target.size(0) * target.size(1)) - - def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): - kl_div = preds*(torch.log(preds+eps)-log_prior) - return kl_div.sum()/(num_node*preds.size(0)) - - - def train(self, training_A=False): - - self.model1.train() - self.model2.train() - self.adjust_lr() - loader = self.data_loader['train'] - loss1_value = [] - loss_class_value = [] - loss_recon_value = [] - loss2_value = [] - loss_nll_value = [] - loss_kl_value = [] - - if training_A: - for param1 in self.model1.parameters(): - param1.requires_grad = False - for param2 in self.model2.parameters(): - param2.requires_grad = True - self.iter_info.clear() - self.epoch_info.clear() - - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - label = label.long().to(self.dev) - - gpu_id = data.get_device() - self.log_prior = self.log_prior.cuda(gpu_id) - A_batch, prob, outputs, data_target = self.model2(data_downsample) - loss_nll = self.nll_gaussian(outputs, data_target[:,:,1:,:], variance=5e-4) - loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) - loss2 = loss_nll + loss_kl - - self.optimizer2.zero_grad() - loss2.backward() - self.optimizer2.step() - - self.iter_info['loss2'] = loss2.data.item() - self.iter_info['loss_nll'] = loss_nll.data.item() - self.iter_info['loss_kl'] = loss_kl.data.item() - self.iter_info['lr'] = '{:.6f}'.format(self.lr2) - - loss2_value.append(self.iter_info['loss2']) - loss_nll_value.append(self.iter_info['loss_nll']) - loss_kl_value.append(self.iter_info['loss_kl']) - self.show_iter_info() - self.meta_info['iter'] += 1 - self.epoch_info['mean_loss2'] = np.mean(loss2_value) - self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) - self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) - self.show_epoch_info() - self.io.print_timer() - - else: - for param1 in self.model1.parameters(): - param1.requires_grad = True - for param2 in self.model2.parameters(): - param2.requires_grad = True - self.iter_info.clear() - self.epoch_info.clear() - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - target_data = target_data.float().to(self.dev) - data_last = data_last.float().to(self.dev) - label = label.long().to(self.dev) - - A_batch, prob, outputs, _ = self.model2(data_downsample) - x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) - loss_class = self.loss_class(x_class, label) - loss_recon = self.loss_pred(pred, target) - loss1 = loss_class + self.w_pred*loss_recon - - self.optimizer1.zero_grad() - loss1.backward() - self.optimizer1.step() - - self.iter_info['loss1'] = loss1.data.item() - self.iter_info['loss_class'] = loss_class.data.item() - self.iter_info['loss_recon'] = loss_recon.data.item()*self.w_pred - self.iter_info['lr'] = '{:.6f}'.format(self.lr) - - loss1_value.append(self.iter_info['loss1']) - loss_class_value.append(self.iter_info['loss_class']) - loss_recon_value.append(self.iter_info['loss_recon']) - self.show_iter_info() - self.meta_info['iter'] += 1 - - self.epoch_info['mean_loss1']= np.mean(loss1_value) - self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) - self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) - self.show_epoch_info() - self.io.print_timer() - - - def test(self, evaluation=True, testing_A=False, save=False, save_feature=False): - - self.model1.eval() - self.model2.eval() - loader = self.data_loader['test'] - loss1_value = [] - loss_class_value = [] - loss_recon_value = [] - loss2_value = [] - loss_nll_value = [] - loss_kl_value = [] - result_frag = [] - label_frag = [] - - if testing_A: - A_all = [] - self.epoch_info.clear() - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - label = label.long().to(self.dev) - - with torch.no_grad(): - A_batch, prob, outputs, data_bn = self.model2(data_downsample) - - if save: - n = A_batch.size(0) - a = A_batch[:int(n/2),:,:,:].cpu().numpy() - A_all.extend(a) - - if evaluation: - gpu_id = data.get_device() - self.log_prior = self.log_prior.cuda(gpu_id) - loss_nll = self.nll_gaussian(outputs, data_bn[:,:,1:,:], variance=5e-4) - loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) - loss2 = loss_nll + loss_kl - - loss2_value.append(loss2.item()) - loss_nll_value.append(loss_nll.item()) - loss_kl_value.append(loss_kl.item()) - - if save: - A_all = np.array(A_all) - np.save(os.path.join(self.arg.work_dir, 'test_adj.npy'), A_all) - - if evaluation: - self.epoch_info['mean_loss2'] = np.mean(loss2_value) - self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) - self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) - self.show_epoch_info() - - else: - recon_data = [] - feature_map = [] - self.epoch_info.clear() - for data, data_downsample, target_data, data_last, label in loader: - data = data.float().to(self.dev) - data_downsample = data_downsample.float().to(self.dev) - target_data = target_data.float().to(self.dev) - data_last = data_last.float().to(self.dev) - label = label.long().to(self.dev) - - with torch.no_grad(): - A_batch, prob, outputs, _ = self.model2(data_downsample) - x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) - result_frag.append(x_class.data.cpu().numpy()) - - if save: - n = pred.size(0) - p = pred[::2,:,:,:].cpu().numpy() - recon_data.extend(p) - - if evaluation: - loss_class = self.loss_class(x_class, label) - loss_recon = self.loss_pred(pred, target) - loss1 = loss_class + self.w_pred*loss_recon - - loss1_value.append(loss1.item()) - loss_class_value.append(loss_class.item()) - loss_recon_value.append(loss_recon.item()) - label_frag.append(label.data.cpu().numpy()) - - if save: - recon_data = np.array(recon_data) - np.save(os.path.join(self.arg.work_dir, 'recon_data.npy'), recon_data) - - - self.result = np.concatenate(result_frag) - if evaluation: - self.label = np.concatenate(label_frag) - self.epoch_info['mean_loss1'] = np.mean(loss1_value) - self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) - self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) - self.show_epoch_info() - - for k in self.arg.show_topk: - hit_top_k = [] - rank = self.result.argsort() - for i,l in enumerate(self.label): - hit_top_k.append(l in rank[i, -k:]) - self.io.print_log('\n') - accuracy = sum(hit_top_k)*1.0/len(hit_top_k) - self.io.print_log('\tTop{}: {:.2f}%'.format(k, 100 * accuracy)) - - - - @staticmethod - def get_parser(add_help=False): - - parent_parser = Processor.get_parser(add_help=False) - parser = argparse.ArgumentParser( - add_help=add_help, - parents=[parent_parser], - description='Spatial Temporal Graph Convolution Network') - - parser.add_argument('--show_topk', type=int, default=[1, 5], nargs='+', help='which Top K accuracy will be shown') - parser.add_argument('--base_lr1', type=float, default=0.1, help='initial learning rate') - parser.add_argument('--base_lr2', type=float, default=0.0005, help='initial learning rate') - parser.add_argument('--step', type=int, default=[], nargs='+', help='the epoch where optimizer reduce the learning rate') - parser.add_argument('--optimizer', default='SGD', help='type of optimizer') - parser.add_argument('--nesterov', type=str2bool, default=True, help='use nesterov or not') - parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay for optimizer') - - parser.add_argument('--max_hop_dir', type=str, default='max_hop_4') - parser.add_argument('--lamda_act', type=float, default=0.5) - parser.add_argument('--lamda_act_dir', type=str, default='lamda_05') - - return parser +import sys +import os +import argparse +import yaml +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim + +import torchlight +from torchlight.io import str2bool +from torchlight.io import DictAction +from torchlight.io import import_class + +from .processor import Processor + +from net.model_poseformer import * + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv1d') != -1: + m.weight.data.normal_(0.0, 0.02) + if m.bias is not None: + m.bias.data.fill_(0) + elif classname.find('Conv2d') != -1: + m.weight.data.normal_(0.0, 0.02) + if m.bias is not None: + m.bias.data.fill_(0) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +class REC_Processor(Processor): + + def load_model(self): + self.model1 = self.io.load_model(self.arg.model1, **(self.arg.model1_args)) + self.model1.apply(weights_init) + self.model2 = self.io.load_model(self.arg.model2, **(self.arg.model2_args)) + self.model3 = PoseTransformer(num_frame=290, num_joints=25, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0) + + self.loss_class = nn.CrossEntropyLoss() + self.loss_pred = nn.MSELoss() + self.w_pred = 0.01 + + prior = np.array([0.95, 0.05/2, 0.05/2]) + self.log_prior = torch.FloatTensor(np.log(prior)) + self.log_prior = torch.unsqueeze(torch.unsqueeze(self.log_prior, 0), 0) + + self.load_optimizer() + + def load_optimizer(self): + if self.arg.optimizer == 'SGD': + self.optimizer1 = optim.SGD(params=self.model3.parameters(), + lr=self.arg.base_lr1, + momentum=0.9, + nesterov=self.arg.nesterov, + weight_decay=self.arg.weight_decay) + elif self.arg.optimizer == 'Adam': + self.optimizer1 = optim.Adam(params=self.model3.parameters(), + lr=self.arg.base_lr1, + weight_decay=self.arg.weight_decay) + + def adjust_lr(self): + if self.arg.optimizer == 'SGD' and self.arg.step: + lr = self.arg.base_lr1 * (0.1**np.sum(self.meta_info['epoch']>= np.array(self.arg.step))) + for param_group in self.optimizer1.param_groups: + param_group['lr'] = lr + self.lr = lr + else: + self.lr = self.arg.base_lr1 + self.lr2 = self.arg.base_lr2 + + def nll_gaussian(self, preds, target, variance, add_const=False): + neg_log_p = ((preds-target)**2/(2*variance)) + if add_const: + const = 0.5*np.log(2*np.pi*variance) + neg_log_p += const + return neg_log_p.sum() / (target.size(0) * target.size(1)) + + def kl_categorical(self, preds, log_prior, num_node, eps=1e-16): + kl_div = preds*(torch.log(preds+eps)-log_prior) + return kl_div.sum()/(num_node*preds.size(0)) + + + def train(self, training_A=False): + self.model3.train() + self.adjust_lr() + loader = self.data_loader['train'] + loss_class_value = [] + loss2_value = [] + loss_nll_value = [] + loss_kl_value = [] + + if training_A: + for param1 in self.model1.parameters(): + param1.requires_grad = False + for param2 in self.model2.parameters(): + param2.requires_grad = True + self.iter_info.clear() + self.epoch_info.clear() + + for data, data_downsample, target_data, data_last, label in loader: + # data: (32,3,290,25,2) data_downsample:(32,3,50,25,2) target_data:(32,3,10,25,2) data_last:(32,3,1,25,2) label:(32) + data = data.float().to(self.dev) + data_downsample = data_downsample.float().to(self.dev) + label = label.long().to(self.dev) + + gpu_id = data.get_device() + self.log_prior = self.log_prior.cuda(gpu_id) + A_batch, prob, outputs, data_target = self.model2(data_downsample) + loss_nll = self.nll_gaussian(outputs, data_target[:,:,1:,:], variance=5e-4) + loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) + loss2 = loss_nll + loss_kl + + self.optimizer2.zero_grad() + loss2.backward() + self.optimizer2.step() + + self.iter_info['loss2'] = loss2.data.item() + self.iter_info['loss_nll'] = loss_nll.data.item() + self.iter_info['loss_kl'] = loss_kl.data.item() + self.iter_info['lr'] = '{:.6f}'.format(self.lr2) + + loss2_value.append(self.iter_info['loss2']) + loss_nll_value.append(self.iter_info['loss_nll']) + loss_kl_value.append(self.iter_info['loss_kl']) + self.show_iter_info() + self.meta_info['iter'] += 1 + break + self.epoch_info['mean_loss2'] = np.mean(loss2_value) + self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) + self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) + self.show_epoch_info() + self.io.print_timer() + + else: + ''' + for param1 in self.model1.parameters(): + param1.requires_grad = True + for param2 in self.model2.parameters(): + param2.requires_grad = True + ''' + for param3 in self.model3.parameters(): + param3.requires_grad = True + + self.iter_info.clear() + self.epoch_info.clear() + for data, data_downsample, target_data, data_last, label in loader: + data = data.float().to(self.dev) + #data_downsample = data_downsample.float().to(self.dev) + target_data = target_data.float().to(self.dev) + #data_last = data_last.float().to(self.dev) + label = label.long().to(self.dev) + + #A_batch, prob, outputs, _ = self.model2(data_downsample) + + # wsx model2 viz + ''' + import tensorwatch as tw + import torchvision.models + alexnet_model = torchvision.models.alexnet() + img = tw.draw_model(alexnet_model, [1, 3, 1024, 1024]) + img.save(r'img.jpg') + + from torchviz import make_dot + make_dot((A_batch), params=dict(list(self.model2.named_parameters()))).render("modle2", format="png") + + import hiddenlayer as h + vis_graph = h.build_graph(self.model2, torch.zeros([4,3,50,25,2])) + vis_graph.theme = h.graph.THEMES["blue"].copy() + vis_graph.save("./hl_model2.png") + + viz = self.model2.to(self.dev) + import torch + torch.save(viz, './model2.pth') + ''' + + # workingwsx + #x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) + x_class, target = self.model3(data, target_data) + loss_class = self.loss_class(x_class, label) + #loss_recon = self.loss_pred(pred, target) + #loss1 = loss_class + + self.optimizer1.zero_grad() + loss_class.backward() + self.optimizer1.step() + + #self.iter_info['loss1'] = loss1.data.item() + self.iter_info['loss_class'] = loss_class.data.item() + #self.iter_info['loss_recon'] = loss_recon.data.item()*self.w_pred + self.iter_info['lr'] = '{:.6f}'.format(self.lr) + + #loss1_value.append(self.iter_info['loss1']) + loss_class_value.append(self.iter_info['loss_class']) + #loss_recon_value.append(self.iter_info['loss_recon']) + self.show_iter_info() + self.meta_info['iter'] += 1 + #break # breakwsx + + #self.epoch_info['mean_loss1']= np.mean(loss1_value) + self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) + self.epoch_loss_class_train.append(np.mean(loss_class_value)) + #self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) + self.show_epoch_info() + self.io.print_timer() + + # show train curve + plt.figure() + epoch_x = np.arange(0, len(self.epoch_loss_class_train)) + 1 + plt.plot(epoch_x, self.epoch_loss_class_train[0:], '--', color='C0') + plt.legend(['epoch_loss_class_train']) + plt.ylabel('action CrossEntropyLoss') + plt.xlabel('Epoch') + plt.xlim(0, self.meta_info['epoch']) + plt.savefig('epoch_loss_class_train.png') + plt.close() + + + + + + def test(self, evaluation=True, testing_A=False, save=False, save_feature=False): + + #self.model1.eval() + #self.model2.eval() + self.model3.eval() + loader = self.data_loader['test'] + loss1_value = [] + loss_class_value = [] + loss_recon_value = [] + loss2_value = [] + loss_nll_value = [] + loss_kl_value = [] + result_frag = [] + label_frag = [] + + if testing_A: + A_all = [] + self.epoch_info.clear() + for data, data_downsample, target_data, data_last, label in loader: + data = data.float().to(self.dev) + data_downsample = data_downsample.float().to(self.dev) + label = label.long().to(self.dev) + + with torch.no_grad(): + A_batch, prob, outputs, data_bn = self.model2(data_downsample) + + if save: + n = A_batch.size(0) + a = A_batch[:int(n/2),:,:,:].cpu().numpy() + A_all.extend(a) + + if evaluation: + gpu_id = data.get_device() + self.log_prior = self.log_prior.cuda(gpu_id) + loss_nll = self.nll_gaussian(outputs, data_bn[:,:,1:,:], variance=5e-4) + loss_kl = self.kl_categorical(prob, self.log_prior, num_node=25) + loss2 = loss_nll + loss_kl + + loss2_value.append(loss2.item()) + loss_nll_value.append(loss_nll.item()) + loss_kl_value.append(loss_kl.item()) + + break + + if save: + A_all = np.array(A_all) + np.save(os.path.join(self.arg.work_dir, 'test_adj.npy'), A_all) + + if evaluation: + self.epoch_info['mean_loss2'] = np.mean(loss2_value) + self.epoch_info['mean_loss_nll'] = np.mean(loss_nll_value) + self.epoch_info['mean_loss_kl'] = np.mean(loss_kl_value) + self.show_epoch_info() + + else: + recon_data = [] + feature_map = [] + self.epoch_info.clear() + for data, data_downsample, target_data, data_last, label in loader: + data = data.float().to(self.dev) + #data_downsample = data_downsample.float().to(self.dev) + target_data = target_data.float().to(self.dev) + #data_last = data_last.float().to(self.dev) + label = label.long().to(self.dev) + + with torch.no_grad(): + #A_batch, prob, outputs, _ = self.model2(data_downsample) + #x_class, pred, target = self.model1(data, target_data, data_last, A_batch, self.arg.lamda_act) + x_class, target = self.model3(data, target_data) + result_frag.append(x_class.data.cpu().numpy()) + + """ + if save: + n = pred.size(0) + p = pred[::2,:,:,:].cpu().numpy() + recon_data.extend(p) + """ + + if evaluation: + loss_class = self.loss_class(x_class, label) + #loss_recon = self.loss_pred(pred, target) + #loss1 = loss_class + self.w_pred*loss_recon + + #loss1_value.append(loss1.item()) + loss_class_value.append(loss_class.item()) + #loss_recon_value.append(loss_recon.item()) + label_frag.append(label.data.cpu().numpy()) + #break #breakwsx + """ + if save: + recon_data = np.array(recon_data) + np.save(os.path.join(self.arg.work_dir, 'recon_data.npy'), recon_data) + """ + + self.result = np.concatenate(result_frag) + + if evaluation: + self.label = np.concatenate(label_frag) + #self.epoch_info['mean_loss1'] = np.mean(loss1_value) + self.epoch_info['mean_loss_class'] = np.mean(loss_class_value) + self.epoch_loss_class_test.append(np.mean(loss_class_value)) + #self.epoch_info['mean_loss_recon'] = np.mean(loss_recon_value) + self.show_epoch_info() + + for k in self.arg.show_topk: + hit_top_k = [] + rank = self.result.argsort() + for i,l in enumerate(self.label): + hit_top_k.append(l in rank[i, -k:]) + self.io.print_log('\n') + accuracy = sum(hit_top_k)*1.0/len(hit_top_k) + self.io.print_log('\tTop{}: {:.2f}%'.format(k, 100 * accuracy)) + + # wsx test curve + plt.figure() + epoch_x = np.arange(0, len(self.epoch_loss_class_test)) + 1 + plt.plot(epoch_x, self.epoch_loss_class_test[0:], '--', color='C1') + plt.legend(['epoch_loss_class_eval']) + plt.ylabel('action CrossEntropyLoss') + plt.xlabel('Epoch') + plt.xlim(0, len(self.epoch_loss_class_test)) + plt.savefig('epoch_loss_class_eval.png') + plt.close() + + + @staticmethod + def get_parser(add_help=False): + + parent_parser = Processor.get_parser(add_help=False) + parser = argparse.ArgumentParser( + add_help=add_help, + parents=[parent_parser], + description='Spatial Temporal Graph Convolution Network') + + parser.add_argument('--show_topk', type=int, default=[1, 5], nargs='+', help='which Top K accuracy will be shown') + parser.add_argument('--base_lr1', type=float, default=0.1, help='initial learning rate') + parser.add_argument('--base_lr2', type=float, default=0.0005, help='initial learning rate') + parser.add_argument('--step', type=int, default=[], nargs='+', help='the epoch where optimizer reduce the learning rate') + parser.add_argument('--optimizer', default='SGD', help='type of optimizer') + parser.add_argument('--nesterov', type=str2bool, default=True, help='use nesterov or not') + parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay for optimizer') + + parser.add_argument('--max_hop_dir', type=str, default='max_hop_4') + parser.add_argument('--lamda_act', type=float, default=0.5) + parser.add_argument('--lamda_act_dir', type=str, default='lamda_05') + + return parser diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..4dd2d1d --- /dev/null +++ b/requirement.txt @@ -0,0 +1,327 @@ +# This file may be used to create an environment using: +# $ conda create --name --file +# platform: linux-64 +_ipyw_jlab_nb_ext_conf=0.1.0=py38_0 +_libgcc_mutex=0.1=main +alabaster=0.7.12=py_0 +anaconda=2020.11=py38_0 +anaconda-client=1.7.2=py38_0 +anaconda-navigator=1.10.0=py38_0 +anaconda-project=0.8.4=py_0 +argh=0.26.2=py38_0 +argon2-cffi=20.1.0=py38h7b6447c_1 +asn1crypto=1.4.0=py_0 +astroid=2.4.2=py38_0 +astropy=4.0.2=py38h7b6447c_0 +async_generator=1.10=py_0 +atomicwrites=1.4.0=py_0 +attrs=20.3.0=pyhd3eb1b0_0 +autopep8=1.5.4=py_0 +babel=2.8.1=pyhd3eb1b0_0 +backcall=0.2.0=py_0 +backports=1.0=py_2 +backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 +backports.shutil_get_terminal_size=1.0.0=py38_2 +backports.tempfile=1.0=pyhd3eb1b0_1 +backports.weakref=1.0.post1=py_1 +beautifulsoup4=4.9.3=pyhb0f4dca_0 +bitarray=1.6.1=py38h27cfd23_0 +bkcharts=0.2=py38_0 +blas=1.0=mkl +bleach=3.2.1=py_0 +blosc=1.20.1=hd408876_0 +bokeh=2.2.3=py38_0 +boto=2.49.0=py38_0 +bottleneck=1.3.2=py38heb32a55_1 +brotlipy=0.7.0=py38h7b6447c_1000 +bzip2=1.0.8=h7b6447c_0 +ca-certificates=2020.10.14=0 +cairo=1.14.12=h8948797_3 +certifi=2020.6.20=pyhd3eb1b0_3 +cffi=1.14.3=py38he30daa8_0 +chardet=3.0.4=py38_1003 +click=7.1.2=py_0 +cloudpickle=1.6.0=py_0 +clyent=1.2.2=py38_1 +colorama=0.4.4=py_0 +conda=4.10.1=py38h06a4308_1 +conda-build=3.20.5=py38_1 +conda-env=2.6.0=1 +conda-package-handling=1.7.3=py38h27cfd23_1 +conda-verify=3.4.2=py_1 +contextlib2=0.6.0.post1=py_0 +cryptography=3.1.1=py38h1ba5d50_0 +curl=7.71.1=hbc83047_1 +cycler=0.10.0=py38_0 +cython=0.29.21=py38he6710b0_0 +cytoolz=0.11.0=py38h7b6447c_0 +dask=2.30.0=py_0 +dask-core=2.30.0=py_0 +dbus=1.13.18=hb2f20db_0 +decorator=4.4.2=py_0 +defusedxml=0.6.0=py_0 +diff-match-patch=20200713=py_0 +distributed=2.30.1=py38h06a4308_0 +docutils=0.16=py38_1 +entrypoints=0.3=py38_0 +et_xmlfile=1.0.1=py_1001 +expat=2.2.10=he6710b0_2 +fastcache=1.1.0=py38h7b6447c_0 +filelock=3.0.12=py_0 +flake8=3.8.4=py_0 +flask=1.1.2=py_0 +fontconfig=2.13.0=h9420a91_0 +freetype=2.10.4=h5ab3b9f_0 +fribidi=1.0.10=h7b6447c_0 +fsspec=0.8.3=py_0 +future=0.18.2=py38_1 +get_terminal_size=1.0.0=haa9412d_0 +gevent=20.9.0=py38h7b6447c_0 +glib=2.66.1=h92f7085_0 +glob2=0.7=py_0 +gmp=6.1.2=h6c8ec71_1 +gmpy2=2.0.8=py38hd5f6e3b_3 +graphite2=1.3.14=h23475e2_0 +greenlet=0.4.17=py38h7b6447c_0 +gst-plugins-base=1.14.0=hbbd80ab_1 +gstreamer=1.14.0=hb31296c_0 +h5py=2.10.0=py38h7918eee_0 +harfbuzz=2.4.0=hca77d97_1 +hdf5=1.10.4=hb1b8bf9_0 +heapdict=1.0.1=py_0 +html5lib=1.1=py_0 +icu=58.2=he6710b0_3 +idna=2.10=py_0 +imageio=2.9.0=py_0 +imagesize=1.2.0=py_0 +importlib-metadata=2.0.0=py_1 +importlib_metadata=2.0.0=1 +iniconfig=1.1.1=py_0 +intel-openmp=2020.2=254 +intervaltree=3.1.0=py_0 +ipykernel=5.3.4=py38h5ca1d4c_0 +ipython=7.19.0=py38hb070fc8_0 +ipython_genutils=0.2.0=py38_0 +ipywidgets=7.5.1=py_1 +isort=5.6.4=py_0 +itsdangerous=1.1.0=py_0 +jbig=2.1=hdba287a_0 +jdcal=1.4.1=py_0 +jedi=0.17.1=py38_0 +jeepney=0.5.0=pyhd3eb1b0_0 +jinja2=2.11.2=py_0 +joblib=0.17.0=py_0 +jpeg=9b=h024ee3a_2 +json5=0.9.5=py_0 +jsonschema=3.2.0=py_2 +jupyter=1.0.0=py38_7 +jupyter_client=6.1.7=py_0 +jupyter_console=6.2.0=py_0 +jupyter_core=4.6.3=py38_0 +jupyterlab=2.2.6=py_0 +jupyterlab_pygments=0.1.2=py_0 +jupyterlab_server=1.2.0=py_0 +keyring=21.4.0=py38_1 +kiwisolver=1.3.0=py38h2531618_0 +krb5=1.18.2=h173b8e3_0 +lazy-object-proxy=1.4.3=py38h7b6447c_0 +lcms2=2.11=h396b838_0 +ld_impl_linux-64=2.33.1=h53a641e_7 +libarchive=3.4.2=h62408e4_0 +libcurl=7.71.1=h20c2e04_1 +libedit=3.1.20191231=h14c3975_1 +libffi=3.3=he6710b0_2 +libgcc-ng=9.1.0=hdf63c60_0 +libgfortran-ng=7.3.0=hdf63c60_0 +liblief=0.10.1=he6710b0_0 +libllvm10=10.0.1=hbcb73fb_5 +libpng=1.6.37=hbc83047_0 +libsodium=1.0.18=h7b6447c_0 +libspatialindex=1.9.3=he6710b0_0 +libssh2=1.9.0=h1ba5d50_1 +libstdcxx-ng=9.1.0=hdf63c60_0 +libtiff=4.1.0=h2733197_1 +libtool=2.4.6=h7b6447c_1005 +libuuid=1.0.3=h1bed415_2 +libxcb=1.14=h7b6447c_0 +libxml2=2.9.10=hb55368b_3 +libxslt=1.1.34=hc22bd24_0 +llvmlite=0.34.0=py38h269e1b5_4 +locket=0.2.0=py38_1 +lxml=4.6.1=py38hefd8a0e_0 +lz4-c=1.9.2=heb0550a_3 +lzo=2.10=h7b6447c_2 +markupsafe=1.1.1=py38h7b6447c_0 +matplotlib=3.3.2=0 +matplotlib-base=3.3.2=py38h817c723_0 +mccabe=0.6.1=py38_1 +mistune=0.8.4=py38h7b6447c_1000 +mkl=2020.2=256 +mkl-service=2.3.0=py38he904b0f_0 +mkl_fft=1.2.0=py38h23d657b_0 +mkl_random=1.1.1=py38h0573a6f_0 +mock=4.0.2=py_0 +more-itertools=8.6.0=pyhd3eb1b0_0 +mpc=1.1.0=h10f8cd9_1 +mpfr=4.0.2=hb69a4c5_1 +mpmath=1.1.0=py38_0 +msgpack-python=1.0.0=py38hfd86e86_1 +multipledispatch=0.6.0=py38_0 +navigator-updater=0.2.1=py38_0 +nbclient=0.5.1=py_0 +nbconvert=6.0.7=py38_0 +nbformat=5.0.8=py_0 +ncurses=6.2=he6710b0_1 +nest-asyncio=1.4.2=pyhd3eb1b0_0 +networkx=2.5=py_0 +nltk=3.5=py_0 +nose=1.3.7=py38_2 +notebook=6.1.4=py38_0 +numba=0.51.2=py38h0573a6f_1 +numexpr=2.7.1=py38h423224d_0 +numpy=1.19.2=py38h54aff64_0 +numpy-base=1.19.2=py38hfa32c7d_0 +numpydoc=1.1.0=pyhd3eb1b0_1 +olefile=0.46=py_0 +openpyxl=3.0.5=py_0 +openssl=1.1.1h=h7b6447c_0 +packaging=20.4=py_0 +pandas=1.1.3=py38he6710b0_0 +pandoc=2.11=hb0f4dca_0 +pandocfilters=1.4.3=py38h06a4308_1 +pango=1.45.3=hd140c19_0 +parso=0.7.0=py_0 +partd=1.1.0=py_0 +patchelf=0.12=he6710b0_0 +path=15.0.0=py38_0 +path.py=12.5.0=0 +pathlib2=2.3.5=py38_0 +pathtools=0.1.2=py_1 +patsy=0.5.1=py38_0 +pcre=8.44=he6710b0_0 +pep8=1.7.1=py38_0 +pexpect=4.8.0=py38_0 +pickleshare=0.7.5=py38_1000 +pillow=8.0.1=py38he98fc37_0 +pip=21.1=pypi_0 +pixman=0.40.0=h7b6447c_0 +pkginfo=1.6.1=py38h06a4308_0 +pluggy=0.13.1=py38_0 +ply=3.11=py38_0 +prometheus_client=0.8.0=py_0 +prompt-toolkit=3.0.8=py_0 +prompt_toolkit=3.0.8=0 +psutil=5.7.2=py38h7b6447c_0 +ptyprocess=0.6.0=py38_0 +py=1.9.0=py_0 +py-lief=0.10.1=py38h403a769_0 +pycodestyle=2.6.0=py_0 +pycosat=0.6.3=py38h7b6447c_1 +pycparser=2.20=py_2 +pycurl=7.43.0.6=py38h1ba5d50_0 +pydocstyle=5.1.1=py_0 +pyflakes=2.2.0=py_0 +pygments=2.7.2=pyhd3eb1b0_0 +pylint=2.6.0=py38_0 +pyodbc=4.0.30=py38he6710b0_0 +pyopenssl=19.1.0=py_1 +pyparsing=2.4.7=py_0 +pyqt=5.9.2=py38h05f1152_4 +pyrsistent=0.17.3=py38h7b6447c_0 +pysocks=1.7.1=py38_0 +pytables=3.6.1=py38h9fd0a39_0 +pytest=6.1.1=py38_0 +python=3.8.5=h7579374_1 +python-dateutil=2.8.1=py_0 +python-jsonrpc-server=0.4.0=py_0 +python-language-server=0.35.1=py_0 +python-libarchive-c=2.9=py_0 +pytz=2020.1=py_0 +pywavelets=1.1.1=py38h7b6447c_2 +pyxdg=0.27=pyhd3eb1b0_0 +pyyaml=5.3.1=py38h7b6447c_1 +pyzmq=19.0.2=py38he6710b0_1 +qdarkstyle=2.8.1=py_0 +qt=5.9.7=h5867ecd_1 +qtawesome=1.0.1=py_0 +qtconsole=4.7.7=py_0 +qtpy=1.9.0=py_0 +readline=8.0=h7b6447c_0 +regex=2020.10.15=py38h7b6447c_0 +requests=2.24.0=py_0 +ripgrep=12.1.1=0 +rope=0.18.0=py_0 +rtree=0.9.4=py38_1 +ruamel_yaml=0.15.87=py38h7b6447c_1 +scikit-image=0.17.2=py38hdf5156a_0 +scikit-learn=0.23.2=py38h0573a6f_0 +scipy=1.5.2=py38h0b6359f_0 +seaborn=0.11.0=py_0 +secretstorage=3.1.2=py38_0 +send2trash=1.5.0=py38_0 +setuptools=50.3.1=py38h06a4308_1 +simplegeneric=0.8.1=py38_2 +singledispatch=3.4.0.3=py_1001 +sip=4.19.13=py38he6710b0_0 +six=1.15.0=py38h06a4308_0 +snowballstemmer=2.0.0=py_0 +sortedcollections=1.2.1=py_0 +sortedcontainers=2.2.2=py_0 +soupsieve=2.0.1=py_0 +sphinx=3.2.1=py_0 +sphinxcontrib=1.0=py38_1 +sphinxcontrib-applehelp=1.0.2=py_0 +sphinxcontrib-devhelp=1.0.2=py_0 +sphinxcontrib-htmlhelp=1.0.3=py_0 +sphinxcontrib-jsmath=1.0.1=py_0 +sphinxcontrib-qthelp=1.0.3=py_0 +sphinxcontrib-serializinghtml=1.1.4=py_0 +sphinxcontrib-websupport=1.2.4=py_0 +spyder=4.1.5=py38_0 +spyder-kernels=1.9.4=py38_0 +sqlalchemy=1.3.20=py38h7b6447c_0 +sqlite=3.33.0=h62c20be_0 +statsmodels=0.12.0=py38h7b6447c_0 +sympy=1.6.2=py38h06a4308_1 +tbb=2020.3=hfd86e86_0 +tblib=1.7.0=py_0 +terminado=0.9.1=py38_0 +testpath=0.4.4=py_0 +threadpoolctl=2.1.0=pyh5ca1d4c_0 +tifffile=2020.10.1=py38hdd07704_2 +tk=8.6.10=hbc83047_0 +toml=0.10.1=py_0 +toolz=0.11.1=py_0 +torch=1.8.1=pypi_0 +torchvision=0.9.1=pypi_0 +tornado=6.0.4=py38h7b6447c_1 +tqdm=4.50.2=py_0 +traitlets=5.0.5=py_0 +typing_extensions=3.7.4.3=py_0 +ujson=4.0.1=py38he6710b0_0 +unicodecsv=0.14.1=py38_0 +unixodbc=2.3.9=h7b6447c_0 +urllib3=1.25.11=py_0 +watchdog=0.10.3=py38_0 +wcwidth=0.2.5=py_0 +webencodings=0.5.1=py38_1 +werkzeug=1.0.1=py_0 +wheel=0.35.1=py_0 +widgetsnbextension=3.5.1=py38_0 +wrapt=1.11.2=py38h7b6447c_0 +wurlitzer=2.0.1=py38_0 +xlrd=1.2.0=py_0 +xlsxwriter=1.3.7=py_0 +xlwt=1.3.0=py38_0 +xmltodict=0.12.0=py_0 +xz=5.2.5=h7b6447c_0 +yaml=0.2.5=h7b6447c_0 +yapf=0.30.0=py_0 +zeromq=4.3.3=he6710b0_3 +zict=2.0.0=py_0 +zipp=3.4.0=pyhd3eb1b0_0 +zlib=1.2.11=h7b6447c_3 +zope=1.0=py38_1 +zope.event=4.5.0=py38_0 +zope.interface=5.1.2=py38h7b6447c_0 +zstd=1.4.5=h9ceee32_0 diff --git a/torchlight/__init__.py b/torchlight/__init__.py new file mode 100644 index 0000000..5e0e7b9 --- /dev/null +++ b/torchlight/__init__.py @@ -0,0 +1,8 @@ +from .io import IO +from .io import str2bool +from .io import str2dict +from .io import DictAction +from .io import import_class +from .gpu import visible_gpu +from .gpu import occupy_gpu +from .gpu import ngpu diff --git a/torchlight/__pycache__/__init__.cpython-36.pyc b/torchlight/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..2765d62 Binary files /dev/null and b/torchlight/__pycache__/__init__.cpython-36.pyc differ diff --git a/torchlight/__pycache__/io.cpython-36.pyc b/torchlight/__pycache__/io.cpython-36.pyc new file mode 100644 index 0000000..70b62d7 Binary files /dev/null and b/torchlight/__pycache__/io.cpython-36.pyc differ diff --git a/torchlight/build/lib/torchlight/__init__.py b/torchlight/build/lib/torchlight/__init__.py new file mode 100644 index 0000000..d3f5a12 --- /dev/null +++ b/torchlight/build/lib/torchlight/__init__.py @@ -0,0 +1 @@ + diff --git a/torchlight/build/lib/torchlight/gpu.py b/torchlight/build/lib/torchlight/gpu.py new file mode 100644 index 0000000..e086d4c --- /dev/null +++ b/torchlight/build/lib/torchlight/gpu.py @@ -0,0 +1,35 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/build/lib/torchlight/io.py b/torchlight/build/lib/torchlight/io.py new file mode 100644 index 0000000..5b43720 --- /dev/null +++ b/torchlight/build/lib/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/dist/torchlight-1.0-py3.6.egg b/torchlight/dist/torchlight-1.0-py3.6.egg new file mode 100644 index 0000000..46e5716 Binary files /dev/null and b/torchlight/dist/torchlight-1.0-py3.6.egg differ diff --git a/torchlight/gpu.py b/torchlight/gpu.py new file mode 100644 index 0000000..76566a9 --- /dev/null +++ b/torchlight/gpu.py @@ -0,0 +1,36 @@ +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + #os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2" + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/io.py b/torchlight/io.py new file mode 100644 index 0000000..5b43720 --- /dev/null +++ b/torchlight/io.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict) diff --git a/torchlight/setup.py b/torchlight/setup.py index 95bbce2..a8a1647 100644 --- a/torchlight/setup.py +++ b/torchlight/setup.py @@ -1,8 +1,8 @@ -from setuptools import find_packages, setup - -setup( - name='torchlight', - version='1.0', - description='A mini framework for pytorch', - packages=find_packages(), - install_requires=[]) +from setuptools import find_packages, setup + +setup( + name='torchlight', + version='1.0', + description='A mini framework for pytorch', + packages=find_packages(), + install_requires=[]) diff --git a/torchlight/torchlight.egg-info/PKG-INFO b/torchlight/torchlight.egg-info/PKG-INFO new file mode 100644 index 0000000..4020517 --- /dev/null +++ b/torchlight/torchlight.egg-info/PKG-INFO @@ -0,0 +1,10 @@ +Metadata-Version: 1.0 +Name: torchlight +Version: 1.0 +Summary: A mini framework for pytorch +Home-page: UNKNOWN +Author: UNKNOWN +Author-email: UNKNOWN +License: UNKNOWN +Description: UNKNOWN +Platform: UNKNOWN diff --git a/torchlight/torchlight.egg-info/SOURCES.txt b/torchlight/torchlight.egg-info/SOURCES.txt new file mode 100644 index 0000000..4c2ca9d --- /dev/null +++ b/torchlight/torchlight.egg-info/SOURCES.txt @@ -0,0 +1,8 @@ +setup.py +torchlight/__init__.py +torchlight/gpu.py +torchlight/io.py +torchlight.egg-info/PKG-INFO +torchlight.egg-info/SOURCES.txt +torchlight.egg-info/dependency_links.txt +torchlight.egg-info/top_level.txt \ No newline at end of file diff --git a/torchlight/torchlight.egg-info/dependency_links.txt b/torchlight/torchlight.egg-info/dependency_links.txt new file mode 100644 index 0000000..d3f5a12 --- /dev/null +++ b/torchlight/torchlight.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/torchlight/torchlight.egg-info/top_level.txt b/torchlight/torchlight.egg-info/top_level.txt new file mode 100644 index 0000000..09e8d4c --- /dev/null +++ b/torchlight/torchlight.egg-info/top_level.txt @@ -0,0 +1 @@ +torchlight diff --git a/torchlight/torchlight/__init__.py b/torchlight/torchlight/__init__.py index 8b13789..d3f5a12 100644 --- a/torchlight/torchlight/__init__.py +++ b/torchlight/torchlight/__init__.py @@ -1 +1 @@ - + diff --git a/torchlight/torchlight/gpu.py b/torchlight/torchlight/gpu.py index 306c391..e086d4c 100644 --- a/torchlight/torchlight/gpu.py +++ b/torchlight/torchlight/gpu.py @@ -1,35 +1,35 @@ -import os -import torch - - -def visible_gpu(gpus): - """ - set visible gpu. - - can be a single id, or a list - - return a list of new gpus ids - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) - return list(range(len(gpus))) - - -def ngpu(gpus): - """ - count how many gpus used. - """ - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - return len(gpus) - - -def occupy_gpu(gpus=None): - """ - make program appear on nvidia-smi. - """ - if gpus is None: - torch.zeros(1).cuda() - else: - gpus = [gpus] if isinstance(gpus, int) else list(gpus) - for g in gpus: - torch.zeros(1).cuda(g) +import os +import torch + + +def visible_gpu(gpus): + """ + set visible gpu. + + can be a single id, or a list + + return a list of new gpus ids + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus))) + return list(range(len(gpus))) + + +def ngpu(gpus): + """ + count how many gpus used. + """ + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + return len(gpus) + + +def occupy_gpu(gpus=None): + """ + make program appear on nvidia-smi. + """ + if gpus is None: + torch.zeros(1).cuda() + else: + gpus = [gpus] if isinstance(gpus, int) else list(gpus) + for g in gpus: + torch.zeros(1).cuda(g) diff --git a/torchlight/torchlight/io.py b/torchlight/torchlight/io.py index c753ca1..5b43720 100644 --- a/torchlight/torchlight/io.py +++ b/torchlight/torchlight/io.py @@ -1,203 +1,203 @@ -#!/usr/bin/env python -import argparse -import os -import sys -import traceback -import time -import warnings -import pickle -from collections import OrderedDict -import yaml -import numpy as np -# torch -import torch -import torch.nn as nn -import torch.optim as optim -from torch.autograd import Variable - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore",category=FutureWarning) - import h5py - -class IO(): - def __init__(self, work_dir, save_log=True, print_log=True): - self.work_dir = work_dir - self.save_log = save_log - self.print_to_screen = print_log - self.cur_time = time.time() - self.split_timer = {} - self.pavi_logger = None - self.session_file = None - self.model_text = '' - - # PaviLogger is removed in this version - def log(self, *args, **kwargs): - pass - # try: - # if self.pavi_logger is None: - # from torchpack.runner.hooks import PaviLogger - # url = 'http://pavi.parrotsdnn.org/log' - # with open(self.session_file, 'r') as f: - # info = dict( - # session_file=self.session_file, - # session_text=f.read(), - # model_text=self.model_text) - # self.pavi_logger = PaviLogger(url) - # self.pavi_logger.connect(self.work_dir, info=info) - # self.pavi_logger.log(*args, **kwargs) - # except: #pylint: disable=W0702 - # pass - - def load_model(self, model, **model_args): - Model = import_class(model) - model = Model(**model_args) - self.model_text += '\n\n' + str(model) - return model - - def load_weights(self, model, weights_path, ignore_weights=None): - if ignore_weights is None: - ignore_weights = [] - if isinstance(ignore_weights, str): - ignore_weights = [ignore_weights] - - self.print_log('Load weights from {}.'.format(weights_path)) - weights = torch.load(weights_path) - weights = OrderedDict([[k.split('module.')[-1], - v.cpu()] for k, v in weights.items()]) - - # filter weights - for i in ignore_weights: - ignore_name = list() - for w in weights: - if w.find(i) == 0: - ignore_name.append(w) - for n in ignore_name: - weights.pop(n) - self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) - - for w in weights: - self.print_log('Load weights [{}].'.format(w)) - - try: - model.load_state_dict(weights) - except (KeyError, RuntimeError): - state = model.state_dict() - diff = list(set(state.keys()).difference(set(weights.keys()))) - for d in diff: - self.print_log('Can not find weights [{}].'.format(d)) - state.update(weights) - model.load_state_dict(state) - return model - - def save_pkl(self, result, filename): - with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: - pickle.dump(result, f) - - def save_h5(self, result, filename): - with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: - for k in result.keys(): - f[k] = result[k] - - def save_model(self, model, name): - model_path = '{}/{}'.format(self.work_dir, name) - state_dict = model.state_dict() - weights = OrderedDict([[''.join(k.split('module.')), - v.cpu()] for k, v in state_dict.items()]) - torch.save(weights, model_path) - self.print_log('The model has been saved as {}.'.format(model_path)) - - def save_arg(self, arg): - - self.session_file = '{}/config.yaml'.format(self.work_dir) - - # save arg - arg_dict = vars(arg) - if not os.path.exists(self.work_dir): - os.makedirs(self.work_dir) - with open(self.session_file, 'w') as f: - f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) - yaml.dump(arg_dict, f, default_flow_style=False, indent=4) - - def print_log(self, str, print_time=True): - if print_time: - # localtime = time.asctime(time.localtime(time.time())) - str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str - - if self.print_to_screen: - print(str) - if self.save_log: - with open('{}/log.txt'.format(self.work_dir), 'a') as f: - print(str, file=f) - - def init_timer(self, *name): - self.record_time() - self.split_timer = {k: 0.0000001 for k in name} - - def check_time(self, name): - self.split_timer[name] += self.split_time() - - def record_time(self): - self.cur_time = time.time() - return self.cur_time - - def split_time(self): - split_time = time.time() - self.cur_time - self.record_time() - return split_time - - def print_timer(self): - proportion = { - k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) - for k, v in self.split_timer.items() - } - self.print_log('Time consumption:') - for k in proportion: - self.print_log( - '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) - ) - - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - - -def str2dict(v): - return eval('dict({})'.format(v)) #pylint: disable=W0123 - - -def _import_class_0(name): - components = name.split('.') - mod = __import__(components[0]) - for comp in components[1:]: - mod = getattr(mod, comp) - return mod - - -def import_class(import_str): - mod_str, _sep, class_str = import_str.rpartition('.') - __import__(mod_str) - try: - return getattr(sys.modules[mod_str], class_str) - except AttributeError: - raise ImportError('Class %s cannot be found (%s)' % - (class_str, - traceback.format_exception(*sys.exc_info()))) - - -class DictAction(argparse.Action): - def __init__(self, option_strings, dest, nargs=None, **kwargs): - if nargs is not None: - raise ValueError("nargs not allowed") - super(DictAction, self).__init__(option_strings, dest, **kwargs) - - def __call__(self, parser, namespace, values, option_string=None): - input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 - output_dict = getattr(namespace, self.dest) - for k in input_dict: - output_dict[k] = input_dict[k] - setattr(namespace, self.dest, output_dict) +#!/usr/bin/env python +import argparse +import os +import sys +import traceback +import time +import warnings +import pickle +from collections import OrderedDict +import yaml +import numpy as np +# torch +import torch +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable + +with warnings.catch_warnings(): + warnings.filterwarnings("ignore",category=FutureWarning) + import h5py + +class IO(): + def __init__(self, work_dir, save_log=True, print_log=True): + self.work_dir = work_dir + self.save_log = save_log + self.print_to_screen = print_log + self.cur_time = time.time() + self.split_timer = {} + self.pavi_logger = None + self.session_file = None + self.model_text = '' + + # PaviLogger is removed in this version + def log(self, *args, **kwargs): + pass + # try: + # if self.pavi_logger is None: + # from torchpack.runner.hooks import PaviLogger + # url = 'http://pavi.parrotsdnn.org/log' + # with open(self.session_file, 'r') as f: + # info = dict( + # session_file=self.session_file, + # session_text=f.read(), + # model_text=self.model_text) + # self.pavi_logger = PaviLogger(url) + # self.pavi_logger.connect(self.work_dir, info=info) + # self.pavi_logger.log(*args, **kwargs) + # except: #pylint: disable=W0702 + # pass + + def load_model(self, model, **model_args): + Model = import_class(model) + model = Model(**model_args) + self.model_text += '\n\n' + str(model) + return model + + def load_weights(self, model, weights_path, ignore_weights=None): + if ignore_weights is None: + ignore_weights = [] + if isinstance(ignore_weights, str): + ignore_weights = [ignore_weights] + + self.print_log('Load weights from {}.'.format(weights_path)) + weights = torch.load(weights_path) + weights = OrderedDict([[k.split('module.')[-1], + v.cpu()] for k, v in weights.items()]) + + # filter weights + for i in ignore_weights: + ignore_name = list() + for w in weights: + if w.find(i) == 0: + ignore_name.append(w) + for n in ignore_name: + weights.pop(n) + self.print_log('Filter [{}] remove weights [{}].'.format(i,n)) + + for w in weights: + self.print_log('Load weights [{}].'.format(w)) + + try: + model.load_state_dict(weights) + except (KeyError, RuntimeError): + state = model.state_dict() + diff = list(set(state.keys()).difference(set(weights.keys()))) + for d in diff: + self.print_log('Can not find weights [{}].'.format(d)) + state.update(weights) + model.load_state_dict(state) + return model + + def save_pkl(self, result, filename): + with open('{}/{}'.format(self.work_dir, filename), 'wb') as f: + pickle.dump(result, f) + + def save_h5(self, result, filename): + with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f: + for k in result.keys(): + f[k] = result[k] + + def save_model(self, model, name): + model_path = '{}/{}'.format(self.work_dir, name) + state_dict = model.state_dict() + weights = OrderedDict([[''.join(k.split('module.')), + v.cpu()] for k, v in state_dict.items()]) + torch.save(weights, model_path) + self.print_log('The model has been saved as {}.'.format(model_path)) + + def save_arg(self, arg): + + self.session_file = '{}/config.yaml'.format(self.work_dir) + + # save arg + arg_dict = vars(arg) + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + with open(self.session_file, 'w') as f: + f.write('# command line: {}\n\n'.format(' '.join(sys.argv))) + yaml.dump(arg_dict, f, default_flow_style=False, indent=4) + + def print_log(self, str, print_time=True): + if print_time: + # localtime = time.asctime(time.localtime(time.time())) + str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str + + if self.print_to_screen: + print(str) + if self.save_log: + with open('{}/log.txt'.format(self.work_dir), 'a') as f: + print(str, file=f) + + def init_timer(self, *name): + self.record_time() + self.split_timer = {k: 0.0000001 for k in name} + + def check_time(self, name): + self.split_timer[name] += self.split_time() + + def record_time(self): + self.cur_time = time.time() + return self.cur_time + + def split_time(self): + split_time = time.time() - self.cur_time + self.record_time() + return split_time + + def print_timer(self): + proportion = { + k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values())))) + for k, v in self.split_timer.items() + } + self.print_log('Time consumption:') + for k in proportion: + self.print_log( + '\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k]) + ) + + +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2dict(v): + return eval('dict({})'.format(v)) #pylint: disable=W0123 + + +def _import_class_0(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def import_class(import_str): + mod_str, _sep, class_str = import_str.rpartition('.') + __import__(mod_str) + try: + return getattr(sys.modules[mod_str], class_str) + except AttributeError: + raise ImportError('Class %s cannot be found (%s)' % + (class_str, + traceback.format_exception(*sys.exc_info()))) + + +class DictAction(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super(DictAction, self).__init__(option_strings, dest, **kwargs) + + def __call__(self, parser, namespace, values, option_string=None): + input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123 + output_dict = getattr(namespace, self.dest) + for k in input_dict: + output_dict[k] = input_dict[k] + setattr(namespace, self.dest, output_dict)