diff --git a/PyTorch/build-in/Classification/Sequencer2D/model/__init__.py b/PyTorch/build-in/Classification/Sequencer2D/model/__init__.py new file mode 100644 index 000000000..251f6696f --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/model/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + +from .vanilla_sequencer import * +from .two_dim_sequencer import * diff --git a/PyTorch/build-in/Classification/Sequencer2D/model/layers.py b/PyTorch/build-in/Classification/Sequencer2D/model/layers.py new file mode 100644 index 000000000..6abb1abbf --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/model/layers.py @@ -0,0 +1,253 @@ +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + +from functools import partial +from typing import Tuple + +import torch +from timm.models.layers import DropPath, Mlp, PatchEmbed as TimmPatchEmbed + +from torch import nn, _assert, Tensor + +from utils.helpers import to_2tuple + + +class RNNIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super(RNNIdentity, self).__init__() + + def forward(self, x: Tensor) -> Tuple[Tensor, None]: + return x, None + + +class RNNBase(nn.Module): + + def __init__(self, input_size, hidden_size=None, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True): + super().__init__() + self.rnn = RNNIdentity() + + def forward(self, x): + B, H, W, C = x.shape + x, _ = self.rnn(x.view(B, -1, C)) + return x.view(B, H, W, -1) + + +class RNN(RNNBase): + + def __init__(self, input_size, hidden_size=None, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + nonlinearity="tanh"): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional) + self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, + bias=bias, bidirectional=bidirectional, nonlinearity=nonlinearity) + + +class GRU(RNNBase): + + def __init__(self, input_size, hidden_size=None, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional) + self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, + bias=bias, bidirectional=bidirectional) + + +class LSTM(RNNBase): + + def __init__(self, input_size, hidden_size=None, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional) + self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, + bias=bias, bidirectional=bidirectional) + + +class RNN2DBase(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = 2 * hidden_size if bidirectional else hidden_size + self.union = union + + self.with_vertical = True + self.with_horizontal = True + self.with_fc = with_fc + + if with_fc: + if union == "cat": + self.fc = nn.Linear(2 * self.output_size, input_size) + elif union == "add": + self.fc = nn.Linear(self.output_size, input_size) + elif union == "vertical": + self.fc = nn.Linear(self.output_size, input_size) + self.with_horizontal = False + elif union == "horizontal": + self.fc = nn.Linear(self.output_size, input_size) + self.with_vertical = False + else: + raise ValueError("Unrecognized union: " + union) + elif union == "cat": + pass + if 2 * self.output_size != input_size: + raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.") + elif union == "add": + pass + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + elif union == "vertical": + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + self.with_horizontal = False + elif union == "horizontal": + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + self.with_vertical = False + else: + raise ValueError("Unrecognized union: " + union) + + self.rnn_v = RNNIdentity() + self.rnn_h = RNNIdentity() + + def forward(self, x): + B, H, W, C = x.shape + + if self.with_vertical: + v = x.permute(0, 2, 1, 3) + v = v.reshape(-1, H, C) + v, _ = self.rnn_v(v) + v = v.reshape(B, W, H, -1) + v = v.permute(0, 2, 1, 3) + + if self.with_horizontal: + h = x.reshape(-1, W, C) + h, _ = self.rnn_h(h) + h = h.reshape(B, H, W, -1) + + if self.with_vertical and self.with_horizontal: + if self.union == "cat": + x = torch.cat([v, h], dim=-1) + else: + x = v + h + elif self.with_vertical: + x = v + elif self.with_horizontal: + x = h + + if self.with_fc: + x = self.fc(x) + + return x + + +class RNN2D(RNN2DBase): + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True, nonlinearity="tanh"): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) + if self.with_vertical: + self.rnn_v = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional, nonlinearity=nonlinearity) + if self.with_horizontal: + self.rnn_h = nn.RNN(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional, nonlinearity=nonlinearity) + + +class LSTM2D(RNN2DBase): + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) + if self.with_vertical: + self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + if self.with_horizontal: + self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + + +class GRU2D(RNN2DBase): + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) + if self.with_vertical: + self.rnn_v = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + if self.with_horizontal: + self.rnn_h = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + + +class VanillaSequencerBlock(nn.Module): + def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM, mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, + num_layers=1, bidirectional=True, drop=0., drop_path=0.): + super().__init__() + channels_dim = int(mlp_ratio * dim) + self.norm1 = norm_layer(dim) + self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.rnn_tokens(self.norm1(x))) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class Sequencer2DBlock(nn.Module): + def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, + num_layers=1, bidirectional=True, union="cat", with_fc=True, + drop=0., drop_path=0.): + super().__init__() + channels_dim = int(mlp_ratio * dim) + self.norm1 = norm_layer(dim) + self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional, + union=union, with_fc=with_fc) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.rnn_tokens(self.norm1(x))) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class PatchEmbed(TimmPatchEmbed): + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + else: + x = x.permute(0, 2, 3, 1) # BCHW -> BHWC + x = self.norm(x) + return x + + +class Shuffle(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if self.training: + B, H, W, C = x.shape + r = torch.randperm(H * W) + x = x.reshape(B, -1, C) + x = x[:, r, :].reshape(B, H, W, -1) + return x + + +class Downsample2D(nn.Module): + def __init__(self, input_dim, output_dim, patch_size): + super().__init__() + self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.down(x) + x = x.permute(0, 2, 3, 1) + return x diff --git a/PyTorch/build-in/Classification/Sequencer2D/model/two_dim_sequencer.py b/PyTorch/build-in/Classification/Sequencer2D/model/two_dim_sequencer.py new file mode 100644 index 000000000..70411fd84 --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/model/two_dim_sequencer.py @@ -0,0 +1,447 @@ +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + +import math +from functools import partial +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.models.layers import lecun_normal_, Mlp +from timm.models.helpers import build_model_with_cfg, named_apply +from timm.models.registry import register_model +from torch import nn + +from model.layers import Sequencer2DBlock, PatchEmbed, LSTM2D, GRU2D, RNN2D, Downsample2D + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.proj', 'classifier': 'head', + **kwargs + } + + +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)): + stdv = 1.0 / math.sqrt(module.hidden_size) + for weight in module.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, + norm_layer, act_layer, num_layers, bidirectional, union, + with_fc, drop=0., drop_path_rate=0., **kwargs): + assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) + blocks = [] + for block_idx in range(layers[index]): + drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, + act_layer=act_layer, num_layers=num_layers, + bidirectional=bidirectional, union=union, with_fc=with_fc, + drop=drop, drop_path=drop_path)) + + if index < len(embed_dims) - 1: + blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1])) + + blocks = nn.Sequential(*blocks) + return blocks + + +class Sequencer2D(nn.Module): + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + block_layer=Sequencer2DBlock, + rnn_layer=LSTM2D, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + num_rnn_layers=1, + bidirectional=True, + union="cat", + with_fc=True, + drop_rate=0., + drop_path_rate=0., + nlhb=False, + stem_norm=False, + **kwargs + ): + super().__init__() + self.num_classes = num_classes + self.num_features = embed_dims[0] # num_features for consistency with other models + self.embed_dims = embed_dims + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, + embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None, + flatten=False) + + self.blocks = nn.Sequential(*[ + get_stage( + i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer, + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, + num_layers=num_rnn_layers, bidirectional=bidirectional, + union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate, + ) + for i, _ in enumerate(embed_dims)]) + + self.norm = norm_layer(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(nlhb=nlhb) + + def init_weights(self, nlhb=False): + head_bias = -math.log(self.num_classes) if nlhb else 0. + named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.norm(x) + x = x.mean(dim=(1, 2)) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + return state_dict + + +default_cfgs = dict( + sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"), + sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"), + sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"), + sequencer2d_l_d4_3x=_cfg(), + sequencer2d_s_unidirectional=_cfg(), + sequencer2d_s_add=_cfg(), + sequencer2d_s_h2x=_cfg(), + sequencer2d_s_without_fc=_cfg(), + sequencer2d_vertical=_cfg(), + sequencer2d_s_horizontal=_cfg(), + gru_sequencer2d_s=_cfg(), + rnn_sequencer2d_s=_cfg(), +) + + +def _create_sequencer2d(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Sequencer2D models.') + + model = build_model_with_cfg( + Sequencer2D, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +# main + +@register_model +def sequencer2d_s(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_m(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 14, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_l(pretrained=False, **kwargs): + model_args = dict( + layers=[8, 8, 16, 4], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args) + return model + + +# high resolution + +@register_model +def sequencer2d_s_392(pretrained=False, **kwargs): + model_args = dict( + img_size=392, + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_m_392(pretrained=False, **kwargs): + model_args = dict( + img_size=392, + layers=[4, 3, 14, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_l_392(pretrained=False, **kwargs): + model_args = dict( + img_size=392, + layers=[8, 8, 16, 4], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args) + return model + + +# ablation + +@register_model +def sequencer2d_s_unidirectional(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=False, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s_unidirectional', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_s_add(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="add", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s_add', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_s_h2x(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[96, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s_h2x', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_s_without_fc(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=False, + **kwargs) + model = _create_sequencer2d('sequencer2d_s_without_fc', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_vertical(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[96, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="vertical", + with_fc=False, + **kwargs) + model = _create_sequencer2d('sequencer2d_vertical', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_s_horizontal(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[96, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="horizontal", + with_fc=False, + **kwargs) + model = _create_sequencer2d('sequencer2d_s_horizontal', pretrained=pretrained, **model_args) + return model + + +# option + +@register_model +def gru_sequencer2d_s(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=GRU2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('gru_sequencer2d_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def rnn_sequencer2d_s(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=RNN2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('rnn_sequencer2d_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_l_d4_3x(pretrained=False, **kwargs): + model_args = dict( + layers=[8, 8, 16, 4], + patch_sizes=[7, 2, 1, 1], + embed_dims=[256, 512, 512, 512], + hidden_sizes=[64, 128, 128, 128], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_l_d4_3x', pretrained=pretrained, **model_args) + return model diff --git a/PyTorch/build-in/Classification/Sequencer2D/model/vanilla_sequencer.py b/PyTorch/build-in/Classification/Sequencer2D/model/vanilla_sequencer.py new file mode 100644 index 000000000..a9de5ea1f --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/model/vanilla_sequencer.py @@ -0,0 +1,236 @@ +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + +import math +from functools import partial + +import torch +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.models.layers import Mlp, lecun_normal_, trunc_normal_ +from timm.models.helpers import build_model_with_cfg, named_apply +from timm.models.registry import register_model +from torch import nn + +from model.layers import LSTM, VanillaSequencerBlock, PatchEmbed, Downsample2D, Shuffle + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.proj', 'classifier': 'head', + **kwargs + } + + +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)): + stdv = 1.0 / math.sqrt(module.hidden_size) + for weight in module.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, + norm_layer, act_layer, num_layers, bidirectional, drop=0., drop_path_rate=0., **kwargs): + assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) + blocks = [] + for block_idx in range(layers[index]): + drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, + act_layer=act_layer, num_layers=num_layers, + bidirectional=bidirectional, drop=drop, drop_path=drop_path)) + + if index < len(embed_dims) - 1: + blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1])) + + blocks = nn.Sequential(*blocks) + return blocks + + +class VanillaSequencer(nn.Module): + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + layers=[4, 3, 8, 3], + patch_sizes=[14, 1, 1, 1], + embed_dims=[384, 384, 384, 384], + hidden_sizes=[192, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + block_layer=VanillaSequencerBlock, + rnn_layer=LSTM, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + num_rnn_layers=1, + bidirectional=True, + shuffle=False, + ape=False, + drop_rate=0., + drop_path_rate=0., + nlhb=False, + stem_norm=False, + **kwargs + ): + super().__init__() + self.num_classes = num_classes + self.num_features = embed_dims[0] # num_features for consistency with other models + self.embed_dims = embed_dims + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, + embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None, + flatten=False) + self.shuffle = shuffle + + if self.shuffle: + self.shuffle_patches = Shuffle() + + # absolute position embedding + self.ape = ape + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, self.stem.grid_size[0], self.stem.grid_size[1], embed_dims[0])) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.blocks = nn.Sequential(*[ + get_stage( + i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer, + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, + num_layers=num_rnn_layers, bidirectional=bidirectional, drop=drop_rate, drop_path_rate=drop_path_rate, + ) + for i, _ in enumerate(embed_dims)]) + + self.norm = norm_layer(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(nlhb=nlhb) + + def init_weights(self, nlhb=False): + head_bias = -math.log(self.num_classes) if nlhb else 0. + named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + if self.ape: + x = x + self.absolute_pos_embed + if self.shuffle: + x = self.shuffle_patches(x) + x = self.blocks(x) + x = self.norm(x) + x = x.mean(dim=(1, 2)) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + return state_dict + + +default_cfgs = dict( + v_sequencer_s=_cfg(), + v_sequencer_s_h=_cfg(), + v_sequencer_s_pe=_cfg(), +) + + +def _create_vanilla_sequencer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for VanillaSequencer models.') + + model = build_model_with_cfg( + VanillaSequencer, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def v_sequencer_s(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[14, 1, 1, 1], + embed_dims=[384, 384, 384, 384], + hidden_sizes=[192, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM, + bidirectional=True, + shuffle=False, + ape=False, + **kwargs) + model = _create_vanilla_sequencer('v_sequencer_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def v_sequencer_s_h(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[96, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM, + bidirectional=True, + shuffle=False, + ape=False, + **kwargs) + model = _create_vanilla_sequencer('v_sequencer_s_h', pretrained=pretrained, **model_args) + return model + + +@register_model +def v_sequencer_s_pe(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[14, 1, 1, 1], + embed_dims=[384, 384, 384, 384], + hidden_sizes=[192, 192, 192, 192], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM, + bidirectional=True, + shuffle=False, + ape=True, + **kwargs) + model = _create_vanilla_sequencer('v_sequencer_s_pe', pretrained=pretrained, **model_args) + return model diff --git a/PyTorch/build-in/Classification/Sequencer2D/readme.md b/PyTorch/build-in/Classification/Sequencer2D/readme.md new file mode 100644 index 000000000..eb7772f82 --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/readme.md @@ -0,0 +1,65 @@ +```markdown +## 1. 模型链接 +- 原始仓库链接: +https://github.com/huggingface/pytorch-image-models?tab=readme-ov-file#models + +## 2. 快速开始 + +使用本模型执行训练的主要流程如下: + +1. **基础环境安装**:介绍训练前需要完成的基础环境检查和安装。 +2. **获取数据集**:介绍如何获取训练所需的数据集。 +3. **构建环境**:介绍如何构建模型运行所需要的环境。 +4. **启动训练**:介绍如何运行训练。 + +### 2.1 基础环境安装 + +请参考主仓库的基础环境安装章节,完成训练前的基础环境检查和安装(如驱动、固件等)。 + +### 2.2 准备数据集 + +#### 2.2.1 获取数据集 + +训练使用 **CIFAR-100** 数据集。该数据集为开源数据集,包含 100 个类别的 60000 张彩色图像。 + +#### 2.2.2 处理数据集 + +请确保数据集已下载并解压。根据训练脚本的默认配置,建议将数据集存放在模型目录的上级 `data` 目录中(即 `../data`),或者根据实际路径修改训练命令中的 `--datapath` 参数。 + +### 2.3 构建环境 + +所使用的环境下需包含 PyTorch 框架虚拟环境。 + +1. 执行以下命令,启动虚拟环境(根据实际环境名称修改): + + ```bash + conda activate torch_env_py310 + +``` + +2. 安装 Python 依赖。确保已安装项目所需的依赖包: +```bash +pip install -r requirements_exact.txt + +``` + + + +### 2.4 启动训练 + +1. 在构建好的环境中,进入模型训练脚本所在目录。 + +2. 运行训练。该模型支持单机单卡训练。 +执行以下命令启动训练(使用 CIFAR-100 数据集,Batch Size 为 128): +```bash +python weloTrainStep.py \ + --name train \ + --arch sequencer2D \ + --print_freq 1 \ + --steps 100 \ + --dataset cifar100 \ + --datapath ../data \ + --batch_size 32 \ + --epochs 100 + +``` diff --git a/PyTorch/build-in/Classification/Sequencer2D/requirements.txt b/PyTorch/build-in/Classification/Sequencer2D/requirements.txt new file mode 100644 index 000000000..7394b3319 --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/requirements.txt @@ -0,0 +1,89 @@ +addict==2.4.0 +aliyun-python-sdk-core==2.16.0 +aliyun-python-sdk-kms==2.16.5 +anyio==4.11.0 +astunparse==1.6.3 +certifi==2024.12.14 +cffi==2.0.0 +charset-normalizer==3.4.1 +click==8.3.1 +colorama==0.4.6 +contourpy==1.3.2 +crcmod==1.7 +cryptography==46.0.3 +cycler==0.12.1 +einops==0.8.1 +exceptiongroup==1.3.1 +filelock==3.14.0 +fonttools==4.60.1 +fsspec==2024.12.0 +future @ file:///croot/future_1730902796226/work +git-filter-repo==2.47.0 +h11==0.16.0 +hf-xet==1.2.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==1.1.5 +idna==3.10 +inplace-abn @ git+https://github.com/mapillary/inplace_abn.git@b50bfe9c7cd7116a3ab091a352b48d6ba5ee701c +Jinja2==3.1.5 +jmespath==0.10.0 +joblib==1.5.2 +kiwisolver==1.4.9 +Markdown==3.10 +markdown-it-py==4.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.7 +mdurl==0.1.2 +mmdet==3.3.0 +mmengine==0.10.7 +model-index==0.1.11 +mpmath==1.3.0 +networkx==3.4.2 +numpy==1.23.5 +opencv-python==4.12.0.88 +opendatalab==0.0.10 +openmim==0.3.9 +openxlab==0.1.3 +ordered-set==4.1.0 +oss2==2.17.0 +packaging @ file:///croot/packaging_1734472117206/work +pandas==2.3.3 +pillow==11.1.0 +platformdirs==4.5.1 +pycocotools==2.0.11 +pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work +pycryptodome==3.23.0 +Pygments==2.19.2 +pyparsing==3.2.5 +python-dateutil==2.9.0.post0 +pytz==2023.4 +PyYAML @ file:///croot/pyyaml_1728657952215/work +requests==2.28.2 +rich==13.4.2 +safetensors==0.7.0 +scikit-learn==1.7.2 +scipy==1.15.3 +shapely==2.1.2 +shellingham==1.5.4 +six @ file:///tmp/build/80754af9/six_1644875935023/work +sniffio==1.3.1 +sympy==1.13.3 +tabulate==0.9.0 +termcolor==3.2.0 +terminaltables==3.1.10 +threadpoolctl==3.6.0 +timm==1.0.22 +tomli==2.3.0 +torch @ file:///apps/torch-2.4.0a0%2Bgit4451b0e-cp310-cp310-linux_x86_64.whl#sha256=2e472c916044cac5a1a0e0d8b0e12bb943d8522b24ff826c8014dd444dccd378 +torch_sdaa @ file:///apps/torch_sdaa-2.0.0-cp310-cp310-linux_x86_64.whl#sha256=5aa57889b002e1231fbf806642e1353bfa016297bc25178396e89adc2b1f92e7 +torchaudio @ file:///apps/torchaudio-2.0.2%2Bda3eb8d-cp310-cp310-linux_x86_64.whl#sha256=46525c02fb7eaa8dafea860428de3d01e437ba8d6ff2cc228d7c71975ac4054b +torchdata @ file:///apps/torchdata-0.6.1%2Be1feeb2-py3-none-any.whl#sha256=aa2dc1a7732ea68adfad186978049bf68cc1afdbbdd1e17a8024227ab770e433 +torchtext @ file:///apps/torchtext-0.15.2a0%2B4571036-cp310-cp310-linux_x86_64.whl#sha256=7e42c684ba366f97b59ec37488bf95e416cce3892b6589200d2b3ad159ee5788 +torchvision @ file:///apps/torchvision-0.15.1a0%2B42759b1-cp310-cp310-linux_x86_64.whl#sha256=4b904db2d50102415536bc764bbc31c669b90b1b014f90964e9eccaadb2fd9eb +tqdm==4.65.2 +typer-slim==0.20.0 +typing_extensions==4.15.0 +tzdata==2025.2 +urllib3==1.26.20 +yapf==0.43.0 diff --git a/PyTorch/build-in/Classification/Sequencer2D/sequencer2D.py b/PyTorch/build-in/Classification/Sequencer2D/sequencer2D.py new file mode 100644 index 000000000..ed2ce6a22 --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/sequencer2D.py @@ -0,0 +1,78 @@ +# model_factory.py + +# sequencer +from model.vanilla_sequencer import ( + v_sequencer_s, + v_sequencer_s_h, + v_sequencer_s_pe, +) + +from model.two_dim_sequencer import ( + sequencer2d_m, + sequencer2d_l, + sequencer2d_s_392, + sequencer2d_m_392, + sequencer2d_l_392, + sequencer2d_s_unidirectional, + sequencer2d_s_add, + sequencer2d_s_h2x, + sequencer2d_s_without_fc, + sequencer2d_vertical, + sequencer2d_s_horizontal, + gru_sequencer2d_s, + rnn_sequencer2d_s, + sequencer2d_l_d4_3x, +) + +_MODEL_TABLE = { + # vanilla sequencer + "v_sequencer_s": v_sequencer_s, + "v_sequencer_s_h": v_sequencer_s_h, + "v_sequencer_s_pe": v_sequencer_s_pe, + + # 2d sequencer + "sequencer2d_m": sequencer2d_m, + "sequencer2d_l": sequencer2d_l, + "sequencer2d_s_392": sequencer2d_s_392, + "sequencer2d_m_392": sequencer2d_m_392, + "sequencer2d_l_392": sequencer2d_l_392, + "sequencer2d_s_unidirectional": sequencer2d_s_unidirectional, + "sequencer2d_s_add": sequencer2d_s_add, + "sequencer2d_s_h2x": sequencer2d_s_h2x, + "sequencer2d_s_without_fc": sequencer2d_s_without_fc, + "sequencer2d_vertical": sequencer2d_vertical, + "sequencer2d_s_horizontal": sequencer2d_s_horizontal, + "gru_sequencer2d_s": gru_sequencer2d_s, + "rnn_sequencer2d_s": rnn_sequencer2d_s, + "sequencer2d_l_d4_3x": sequencer2d_l_d4_3x, + +} + + +def Model(num_classes=100, model_name=None, **kwargs): + """ + Unified model entry (NO timm). + + Args: + num_classes (int): number of classes (可直接用位置参数传) + model_name (str, optional): key in _MODEL_TABLE, 默认使用 'sequencer2d_s_392' + **kwargs: 传给模型构造函数 + + Returns: + torch.nn.Module + """ + if model_name is None: + model_name = "sequencer2d_s_392" + + if model_name not in _MODEL_TABLE: + raise ValueError( + f"Unknown model '{model_name}'. " + f"Available models: {list(_MODEL_TABLE.keys())}" + ) + + return _MODEL_TABLE[model_name]( + pretrained=False, + num_classes=num_classes, + **kwargs + ) + diff --git a/PyTorch/build-in/Classification/Sequencer2D/utils/__init__.py b/PyTorch/build-in/Classification/Sequencer2D/utils/__init__.py new file mode 100644 index 000000000..f04000a2f --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/utils/__init__.py @@ -0,0 +1 @@ +from .timm import * \ No newline at end of file diff --git a/PyTorch/build-in/Classification/Sequencer2D/utils/helpers.py b/PyTorch/build-in/Classification/Sequencer2D/utils/helpers.py new file mode 100644 index 000000000..88f5dadbe --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/utils/helpers.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + +from itertools import repeat +import functools +import collections + +import torch +from torch import nn + + +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition('.') + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split('.')) + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def train_rnn(model): + for m in model.children(): + if isinstance(m, (nn.LSTM, nn.GRU, nn.RNN)): + m.train() + else: + train_rnn(m) + + +def normalize_fn(tensor, mean, std): + mean = mean[None, :, None, None] + std = std[None, :, None, None] + return tensor.sub(mean).div(std) + + +class NormalizeByChannelMeanStd(nn.Module): + def __init__(self, mean, std): + super(NormalizeByChannelMeanStd, self).__init__() + if not isinstance(mean, torch.Tensor): + mean = torch.tensor(mean) + if not isinstance(std, torch.Tensor): + std = torch.tensor(std) + self.register_buffer("mean", mean) + self.register_buffer("std", std) + + def forward(self, tensor): + return normalize_fn(tensor, self.mean, self.std) + + def extra_repr(self): + return 'mean={}, std={}'.format(self.mean, self.std) + +class WithNone: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + pass diff --git a/PyTorch/build-in/Classification/Sequencer2D/utils/timm/__init__.py b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/__init__.py new file mode 100644 index 000000000..269d3443c --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/__init__.py @@ -0,0 +1,2 @@ +from .checkpoint_saver import * +from .summary import * \ No newline at end of file diff --git a/PyTorch/build-in/Classification/Sequencer2D/utils/timm/checkpoint_saver.py b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/checkpoint_saver.py new file mode 100644 index 000000000..ea657b0a0 --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/checkpoint_saver.py @@ -0,0 +1,163 @@ + +""" Checkpoint Saver + +Track top-n training checkpoints and maintain recovery checkpoints on specified intervals. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import glob +import operator +import os +import logging + +import torch + +from timm.utils.model import unwrap_model, get_state_dict + + +_logger = logging.getLogger(__name__) + + +class CheckpointSaver: + def __init__( + self, + model, + optimizer, + args=None, + model_ema=None, + amp_scaler=None, + checkpoint_prefix='checkpoint', + recovery_prefix='recovery', + checkpoint_dir='', + recovery_dir='', + decreasing=False, + max_history=10, + unwrap_fn=unwrap_model, + log_clearml=False, + log_s3=False, + ): + + # objects to save state_dicts of + self.model = model + self.optimizer = optimizer + self.args = args + self.model_ema = model_ema + self.amp_scaler = amp_scaler + + # state + self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness + self.best_epoch = None + self.best_metric = None + self.curr_recovery_file = '' + self.last_recovery_file = '' + + # config + self.checkpoint_dir = checkpoint_dir + self.recovery_dir = recovery_dir + self.save_prefix = checkpoint_prefix + self.recovery_prefix = recovery_prefix + self.extension = '.pth.tar' + self.decreasing = decreasing # a lower metric is better if True + self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs + self.max_history = max_history + self.unwrap_fn = unwrap_fn + self.log_s3 = log_clearml and log_s3 + assert self.max_history >= 1 + + if self.log_s3: + from clearml import Task + self.task = Task.current_task() + + def save_checkpoint(self, epoch, metric=None): + assert epoch >= 0 + tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) + last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) + self._save(tmp_save_path, epoch, metric) + if os.path.exists(last_save_path): + os.unlink(last_save_path) # required for Windows support. + os.rename(tmp_save_path, last_save_path) + worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None + if (len(self.checkpoint_files) < self.max_history + or metric is None or self.cmp(metric, worst_file[1])): + if len(self.checkpoint_files) >= self.max_history: + self._cleanup_checkpoints(1) + filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension + save_path = os.path.join(self.checkpoint_dir, filename) + os.link(last_save_path, save_path) + self.checkpoint_files.append((save_path, metric)) + self.checkpoint_files = sorted( + self.checkpoint_files, key=lambda x: x[1], + reverse=not self.decreasing) # sort in descending order if a lower metric is not better + + checkpoints_str = "Current checkpoints:\n" + for c in self.checkpoint_files: + checkpoints_str += ' {}\n'.format(c) + _logger.info(checkpoints_str) + + if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): + self.best_epoch = epoch + self.best_metric = metric + best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) + if os.path.exists(best_save_path): + os.unlink(best_save_path) + os.link(last_save_path, best_save_path) + if self.log_s3: + self.task.update_output_model(best_save_path) + if self.log_s3: + self.task.update_output_model(last_save_path) + + return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) + + def _save(self, save_path, epoch, metric=None): + save_state = { + 'epoch': epoch, + 'arch': type(self.model).__name__.lower(), + 'state_dict': get_state_dict(self.model, self.unwrap_fn), + 'optimizer': self.optimizer.state_dict(), + 'version': 2, # version < 2 increments epoch before save + } + if self.args is not None: + save_state['arch'] = self.args.model + save_state['args'] = self.args + if self.amp_scaler is not None: + save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + if self.model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) + if metric is not None: + save_state['metric'] = metric + torch.save(save_state, save_path) + + def _cleanup_checkpoints(self, trim=0): + trim = min(len(self.checkpoint_files), trim) + delete_index = self.max_history - trim + if delete_index < 0 or len(self.checkpoint_files) <= delete_index: + return + to_delete = self.checkpoint_files[delete_index:] + for d in to_delete: + try: + _logger.debug("Cleaning checkpoint: {}".format(d)) + os.remove(d[0]) + except Exception as e: + _logger.error("Exception '{}' while deleting checkpoint".format(e)) + self.checkpoint_files = self.checkpoint_files[:delete_index] + + def save_recovery(self, epoch, batch_idx=0): + assert epoch >= 0 + filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension + save_path = os.path.join(self.recovery_dir, filename) + self._save(save_path, epoch) + if os.path.exists(self.last_recovery_file): + try: + _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) + os.remove(self.last_recovery_file) + except Exception as e: + _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) + self.last_recovery_file = self.curr_recovery_file + self.curr_recovery_file = save_path + + def find_recovery(self): + recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) + files = glob.glob(recovery_path + '*' + self.extension) + files = sorted(files) + return files[0] if len(files) else '' diff --git a/PyTorch/build-in/Classification/Sequencer2D/utils/timm/dataset_factory.py b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/dataset_factory.py new file mode 100644 index 000000000..dbac8b6b3 --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/dataset_factory.py @@ -0,0 +1,158 @@ +""" Dataset Factory + +Hacked together by / Copyright 2021, Ross Wightman +""" +import os + +from timm.data import IterableImageDataset, ImageDataset +from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder + +try: + from torchvision.datasets import Places365 + has_places365 = True +except ImportError: + has_places365 = False +try: + from torchvision.datasets import INaturalist + has_inaturalist = True +except ImportError: + has_inaturalist = False + + +from datasets import Flowers102, StanfordCars + +_TORCH_BASIC_DS = dict( + cifar10=CIFAR10, + cifar100=CIFAR100, + mnist=MNIST, + qmist=QMNIST, + kmnist=KMNIST, + fashion_mnist=FashionMNIST, +) +_TRAIN_SYNONYM = {'train', 'training'} +_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'} + + +def _search_split(root, split): + # look for sub-folder with name of split in root and use that if it exists + split_name = split.split('[')[0] + try_root = os.path.join(root, split_name) + if os.path.exists(try_root): + return try_root + + def _try(syn): + for s in syn: + try_root = os.path.join(root, s) + if os.path.exists(try_root): + return try_root + return root + if split_name in _TRAIN_SYNONYM: + root = _try(_TRAIN_SYNONYM) + elif split_name in _EVAL_SYNONYM: + root = _try(_EVAL_SYNONYM) + return root + + +def create_dataset( + name, + root, + split='validation', + search_split=True, + class_map=None, + load_bytes=False, + is_training=False, + download=False, + batch_size=None, + repeats=0, + **kwargs +): + """ Dataset factory method + + In parenthesis after each arg are the type of dataset supported for each arg, one of: + * folder - default, timm folder (or tar) based ImageDataset + * torch - torchvision based datasets + * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset + * all - any of the above + + Args: + name: dataset name, empty is okay for folder based datasets + root: root folder of dataset (all) + split: dataset split (all) + search_split: search for split specific child fold from root so one can specify + `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) + class_map: specify class -> index mapping via text file or dict (folder) + load_bytes: load data, return images as undecoded bytes (folder) + download: download dataset if not present and supported (TFDS, torch) + is_training: create dataset in train mode, this is different from the split. + For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS) + batch_size: batch size hint for (TFDS) + repeats: dataset repeats per iteration i.e. epoch (TFDS) + **kwargs: other args to pass to dataset + + Returns: + Dataset object + """ + name = name.lower() + if name.startswith('torch/'): + name = name.split('/', 2)[-1] + torch_kwargs = dict(root=root, download=download, **kwargs) + if name in _TORCH_BASIC_DS: + ds_class = _TORCH_BASIC_DS[name] + use_train = split in _TRAIN_SYNONYM + ds = ds_class(train=use_train, **torch_kwargs) + elif name == 'flowers': + if split in _TRAIN_SYNONYM: + split = 'train' + elif split in _EVAL_SYNONYM: + split = 'test' + ds = Flowers102(split=split, **torch_kwargs) + elif name == 'cars': + if split in _TRAIN_SYNONYM: + split = 'train' + elif split in _EVAL_SYNONYM: + split = 'test' + ds = StanfordCars(split=split, **torch_kwargs) + elif name == 'inaturalist' or name == 'inat': + assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist' + target_type = 'full' + split_split = split.split('/') + if len(split_split) > 1: + target_type = split_split[0].split('_') + if len(target_type) == 1: + target_type = target_type[0] + split = split_split[-1] + if split in _TRAIN_SYNONYM: + split = '2021_train' + elif split in _EVAL_SYNONYM: + split = '2021_valid' + ds = INaturalist(version=split, target_type=target_type, **torch_kwargs) + elif name == 'places365': + assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.' + if split in _TRAIN_SYNONYM: + split = 'train-standard' + elif split in _EVAL_SYNONYM: + split = 'val' + ds = Places365(split=split, **torch_kwargs) + elif name == 'imagenet': + if split in _EVAL_SYNONYM: + split = 'val' + ds = ImageNet(split=split, **torch_kwargs) + elif name == 'image_folder' or name == 'folder': + # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason + if search_split and os.path.isdir(root): + # look for split specific sub-folder in root + root = _search_split(root, split) + ds = ImageFolder(root, **kwargs) + else: + assert False, f"Unknown torchvision dataset {name}" + elif name.startswith('tfds/'): + ds = IterableImageDataset( + root, parser=name, split=split, is_training=is_training, + download=download, batch_size=batch_size, repeats=repeats, **kwargs) + else: + # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future + if search_split and os.path.isdir(root): + # look for split specific sub-folder in root + root = _search_split(root, split) + ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs) + return ds diff --git a/PyTorch/build-in/Classification/Sequencer2D/utils/timm/summary.py b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/summary.py new file mode 100644 index 000000000..9ed3cd36e --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/utils/timm/summary.py @@ -0,0 +1,28 @@ +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + +import csv +from collections import OrderedDict + + +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False, log_clearml=False): + rowd = OrderedDict(epoch=epoch) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + if log_wandb: + import wandb + wandb.log(rowd) + if log_clearml: + from clearml import Logger + for k, v in train_metrics.items(): + Logger.current_logger().report_scalar( + "train", k, iteration=epoch, value=v) + for k, v in eval_metrics.items(): + Logger.current_logger().report_scalar( + "eval", k, iteration=epoch, value=v) + + with open(filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd) diff --git a/PyTorch/build-in/Classification/Sequencer2D/weloTrainStep.py b/PyTorch/build-in/Classification/Sequencer2D/weloTrainStep.py new file mode 100644 index 000000000..13297c11b --- /dev/null +++ b/PyTorch/build-in/Classification/Sequencer2D/weloTrainStep.py @@ -0,0 +1,692 @@ +#!/usr/bin/env python3 +# coding: utf-8 + +import os +import random +import sys +import time +import json +import argparse +from collections import OrderedDict +from pathlib import Path +import numpy as np +import pandas as pd +from tqdm import tqdm +import importlib + +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # 强烈推荐在 shell/最顶端设置 +os.environ["PYTHONHASHSEED"] = "12345" +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" + +def ensure_cublas_workspace(config=":4096:8"): + """ + 尝试为 cuBLAS 设置可复现 workspace。强烈建议在主脚本入口处(import torch 之前) + 通过 export 设置该 env。此函数会在运行时设置,但如果 torch 已经被 import, + 则可能为时已晚——函数会打印提醒。 + """ + already = os.environ.get("CUBLAS_WORKSPACE_CONFIG") + if already: + print(f"[seed_utils] CUBLAS_WORKSPACE_CONFIG 已存在:{already}") + else: + os.environ["CUBLAS_WORKSPACE_CONFIG"] = config + print(f"[seed_utils] 已设置 CUBLAS_WORKSPACE_CONFIG={config} (注意:请在 import torch 前设置以保证生效)") + +def set_global_seed(seed: int = 42, set_threads: bool = True): + """ + 统一随机性设置。注意:若希望完全发挥效果,请在主脚本入口(import torch 之前) + 先调用 ensure_cublas_workspace(...) 或在 shell 中 export CUBLAS_WORKSPACE_CONFIG。 + """ + ensure_cublas_workspace() # 会设置 env 并提醒 + os.environ["PYTHONHASHSEED"] = str(seed) + + if set_threads: + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + + random.seed(seed) + np.random.seed(seed) + + # 现在导入 torch(晚导入以便前面 env 生效) + import torch + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + # 强制确定性(如果存在不确定性算子,PyTorch 会报错并提示) + try: + torch.use_deterministic_algorithms(True) + except Exception as e: + print("[seed_utils] 设置 deterministic 模式时出错:", e) + print("[seed_utils] 请确认 CUBLAS_WORKSPACE_CONFIG 已在 import torch 之前设置。") + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if set_threads: + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + print(f"[seed_utils] 全局 seed 已设置为 {seed}") + +set_global_seed(2025) + +""" +通用训练模版(优先从本地导入 Model -> 支持 DDP / 单卡,AMP,resume,日志,checkpoint) +保存为 train_template_localmodel.py +""" +import torch +import torch.nn as nn +import torch.optim as optim +import torch.backends.cudnn as cudnn +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as tv_models + +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from torch.sdaa import amp +# from torch.cuda import amp + + +# ---------------------------- +# Helper utilities (self-contained) +# ---------------------------- +class AverageMeter(object): + def __init__(self, name='Meter', fmt=':.4f'): + self.name = name + self.fmt = fmt + self.reset() + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / max(1, self.count) + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} (avg {avg' + self.fmt + '})' + return fmtstr.format(name=self.name, val=self.val, avg=self.avg) + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k + 返回一个 list,每个元素是 tensor(百分比形式) + """ + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + # output: (N, C) -> pred: (maxk, N) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() # (maxk, N) + correct = pred.eq(target.view(1, -1).expand_as(pred)) # (maxk, N) bool + + res = [] + for k in topk: + # 把前 k 行展平后求和(返回 0-dim tensor),随后换算为百分比 + correct_k = correct[:k].reshape(-1).float().sum() # 注意:不传 keepdim + # 乘以 100.0 / batch_size,保持返回 tensor(和之前代码兼容) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth'): + save_path = os.path.join(save_dir, filename) + torch.save(state, save_path) + if is_best: + best_path = os.path.join(save_dir, 'model_best.pth') + torch.save(state, best_path) + +def set_seed(seed, deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + cudnn.deterministic = True + cudnn.benchmark = False + else: + cudnn.deterministic = False + cudnn.benchmark = True + +# ---------------------------- +# Argument parser +# ---------------------------- +def parse_args(): + parser = argparse.ArgumentParser(description='Generic PyTorch training template (DDP/AMP) with LocalModel priority') + parser.add_argument('--name', default='run', type=str, help='experiment name (log/checkpoints dir)') + parser.add_argument('--seed', default=42, type=int, help='random seed') + parser.add_argument('--arch', default='None', type=str, help='model name') + parser.add_argument('--deterministic', action='store_true', help='set cudnn deterministic (may be slower)') + parser.add_argument('--dataset', default='cifar10', choices=['cifar10','cifar100','imagenet','custom'], help='which dataset') + parser.add_argument('--datapath', default='./data', type=str, help='dataset root / imagenet root / custom root') + parser.add_argument('--imagenet_dir', default='./imagenet', type=str, help='if dataset=imagenet, path to imagenet root') + parser.add_argument('--custom_eval_dir', default=None, help='if dataset=custom, provide val dir') + parser.add_argument('--num_workers', default=4, type=int, help='dataloader workers per process') + parser.add_argument('--epochs', default=200, type=int) + parser.add_argument('--steps', default=0, type=int, help='max steps to run (if >0, training will stop when global_step reaches this).') + parser.add_argument('--batch_size', default=128, type=int) + parser.add_argument('--model_name', default='resnet18', help='torchvision model name or python path e.g. mypkg.mymodule.Model (used if no local Model)') + parser.add_argument('--num_classes', default=None, type=int, help='override num classes (auto-detect for common sets)') + parser.add_argument('--pretrained', action='store_true', help='use torchvision pretrained weights when available') + parser.add_argument('--optimizer', default='sgd', choices=['sgd','adam','adamw'], help='optimizer') + parser.add_argument('--lr', '--learning_rate', default=0.1, type=float) + parser.add_argument('--momentum', default=0.9, type=float) + parser.add_argument('--weight_decay', default=5e-4, type=float) + parser.add_argument('--nesterov', action='store_true') + parser.add_argument('--scheduler', default='multistep', choices=['multistep','step','cosine','none'], help='lr scheduler') + parser.add_argument('--milestones', default='100,150', type=str, help='milestones for multistep (comma sep)') + parser.add_argument('--step_size', default=30, type=int, help='step size for StepLR or cosine max epochs') + parser.add_argument('--gamma', default=0.1, type=float) + parser.add_argument('--scheduler_step_per_batch', action='store_true', help='call scheduler.step() per batch (for some schedulers)') + parser.add_argument('--resume', default='', type=str, help='path to checkpoint to resume from') + parser.add_argument('--start_epoch', default=0, type=int) + parser.add_argument('--print_freq', default=100, type=int) + parser.add_argument('--save_freq', default=10, type=int, help='save checkpoint every N epochs (rank0 only)') + parser.add_argument('--amp', action='store_true', default = True,help='use automatic mixed precision (AMP)') + parser.add_argument('--grad_accum_steps', default=1, type=int, help='gradient accumulation steps') + parser.add_argument('--local_rank', default=None, type=int, help='local rank passed by torchrun (if any). Use -1 or None for non-distributed') + parser.add_argument('--cutmix_prob', default=0.0, type=float) + parser.add_argument('--beta', default=1.0, type=float) + parser.add_argument('--seed_sampler', default=False, action='store_true', help='set sampler epoch seeds to make deterministic distributed shuffling') + args = parser.parse_args() + args.milestones = [int(x) for x in args.milestones.split(',')] if args.milestones else [] + return args + +# ---------------------------- +# build model (优先 LocalModel) +# ---------------------------- +def build_model_with_local_priority(args, device=None): + """ + 用参数 args.arch 作为模块名导入 Model() + 如果模块不存在或没有 Model 类,则报错停止。 + """ + try: + # 动态导入模块,比如 args.arch = "rexnet" + mod = importlib.import_module(args.arch) + Model = getattr(mod, "Model") # 从模块中获取 Model 类 + except Exception as e: + raise RuntimeError( + f"无法导入模型模块 '{args.arch}' 或未找到类 Model。" + f"\n错误信息:{e}" + ) + + # 解析数据集类别数 + if args.dataset == 'cifar10': + num_classes = 10 + elif args.dataset == 'cifar100': + num_classes = 100 + else: + print(f"[ERROR] 不支持的数据集类型:{args.dataset},无法确定类别数。程序终止。") + sys.exit(1) + + + # 实例化 + try: + model = Model(num_classes) + except Exception as e: + raise RuntimeError( + f"Model() 实例化失败,请检查模型构造函数。\n错误信息:{e}" + ) + + return model + +# ---------------------------- +# Data loader factory +# ---------------------------- +def build_dataloaders(args, rank, world_size): + if args.dataset == 'cifar10' or args.dataset == 'cifar100': + mean = (0.4914, 0.4822, 0.4465) + std = (0.2470, 0.2435, 0.2616) if args.dataset == 'cifar10' else (0.2023, 0.1994, 0.2010) + # train_transform = transforms.Compose([ + # transforms.RandomCrop(32, padding=4), + # transforms.RandomHorizontalFlip(), + # transforms.ToTensor(), + # transforms.Normalize(mean, std), + # ]) + # test_transform = transforms.Compose([ + # transforms.ToTensor(), + # transforms.Normalize(mean, std), + # ]) + + train_transform = transforms.Compose([ # 2025/12/3 从visformer模型开始 + transforms.Resize(256), # 先放大到 256 + transforms.RandomCrop(224), # 再随机裁剪为 224(更符合 ImageNet 风格增强) + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ]) + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ]) + root = args.datapath + if args.dataset == 'cifar10': + train_set = datasets.CIFAR10(root=root, train=True, download=False, transform=train_transform) + val_set = datasets.CIFAR10(root=root, train=False, download=False, transform=test_transform) + num_classes = 10 + else: + train_set = datasets.CIFAR100(root=root, train=True, download=False, transform=train_transform) + val_set = datasets.CIFAR100(root=root, train=False, download=False, transform=test_transform) + num_classes = 100 + + elif args.dataset == 'imagenet': + train_dir = os.path.join(args.imagenet_dir, 'train') + val_dir = os.path.join(args.imagenet_dir, 'val') + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)), + ]) + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)), + ]) + train_set = datasets.ImageFolder(train_dir, train_transform) + val_set = datasets.ImageFolder(val_dir, test_transform) + num_classes = args.num_classes or 1000 + + elif args.dataset == 'custom': + train_dir = os.path.join(args.datapath, 'train') + val_dir = args.custom_eval_dir or os.path.join(args.datapath, 'val') + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + test_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + ]) + train_set = datasets.ImageFolder(train_dir, train_transform) + val_set = datasets.ImageFolder(val_dir, test_transform) + num_classes = len(train_set.classes) + else: + raise ValueError("Unknown dataset") + + if dist.is_initialized() and world_size > 1: + train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True) + else: + train_sampler = None + + train_loader = DataLoader(train_set, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.num_workers, + pin_memory=True, + sampler=train_sampler, + drop_last=False) + val_loader = DataLoader(val_set, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + pin_memory=True) + + return train_loader, val_loader, num_classes, train_sampler + +# ---------------------------- +# Train & validate +# ---------------------------- +def train_one_epoch(args, epoch, model, criterion, optimizer, train_loader, device, scaler, scheduler=None, train_sampler=None, global_step_start=0, max_global_steps=None): + """ + 现在支持:若 max_global_steps 非 None,则当 global_step 达到该值时提前退出 + 返回: epoch_summary_dict, step_logs_list, global_step_end + step_logs_list: list of dicts with per-step info (for logging to CSV if需要) + """ + batch_time = AverageMeter('Time') + data_time = AverageMeter('Data') + losses = AverageMeter('Loss') + top1 = AverageMeter('Acc@1') + top5 = AverageMeter('Acc@5') + + model.train() + end = time.time() + optimizer.zero_grad() + + iters = len(train_loader) + step_logs = [] + global_step = global_step_start + + for i, (images, targets) in enumerate(train_loader): + # check global steps limit + if (max_global_steps is not None) and (global_step >= max_global_steps): + break + + data_time.update(time.time() - end) + images = images.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + if args.amp: + with amp.autocast(): + outputs = model(images) + loss = criterion(outputs, targets) / args.grad_accum_steps + else: + outputs = model(images) + loss = criterion(outputs, targets) / args.grad_accum_steps + + if args.amp: + scaler.scale(loss).backward() + else: + loss.backward() + + # 每当累积步满足 grad_accum_steps 就 step + if (i + 1) % args.grad_accum_steps == 0: + if args.amp: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + optimizer.zero_grad() + if scheduler is not None and args.scheduler_step_per_batch: + scheduler.step() + + with torch.no_grad(): + acc1, acc5 = accuracy(outputs, targets, topk=(1,5)) + losses.update(loss.item() * args.grad_accum_steps, images.size(0)) + top1.update(acc1.item(), images.size(0)) + top5.update(acc5.item(), images.size(0)) + + batch_time.update(time.time() - end) + end = time.time() + + # increment global step AFTER processing this batch + global_step += 1 + + # per-step print (controlled by print_freq) + if ((global_step % args.print_freq == 0) or (i == iters - 1)) and ((dist.get_rank() if dist.is_initialized() else 0) == 0): + lr = optimizer.param_groups[0]['lr'] + print(f"Epoch[{epoch}]:step[{i+1}/{iters}] step_train_loss {losses.val:.4f} acc1 {top1.val:.2f} acc5 {top5.val:.2f}") + + # collect per-step log + step_logs.append({ + 'epoch': epoch, + 'batch_idx': i, + 'global_step': global_step, + 'lr': optimizer.param_groups[0]['lr'], + 'loss': losses.val, + 'loss_avg': losses.avg, + 'acc1': top1.val, + 'acc1_avg': top1.avg, + 'acc5': top5.val, + 'acc5_avg': top5.avg, + 'time': batch_time.val + }) + + # if reached max_global_steps inside epoch, break (handled at loop start next iter) + if (max_global_steps is not None) and (global_step >= max_global_steps): + if (dist.get_rank() if dist.is_initialized() else 0) == 0: + print(f"[Info] 达到 max_global_steps={max_global_steps},将在 epoch 内提前停止。") + break + + # --- flush remaining grads if needed (handle gradient accumulation leftovers) --- + processed_batches = global_step - global_step_start # 实际处理的 batch 数 + if args.grad_accum_steps > 1 and (processed_batches % args.grad_accum_steps) != 0: + # only step if there are gradients + grads_present = any((p.grad is not None and p.requires_grad) for p in model.parameters()) + if grads_present: + if args.amp: + try: + scaler.step(optimizer) + scaler.update() + except Exception as e: + # 防御性:若 scaler.step 因某些原因失败,尝试普通 step(只在极端情况下) + print("[Warning] scaler.step 失败,尝试普通 optimizer.step():", e) + optimizer.step() + else: + optimizer.step() + optimizer.zero_grad() + if scheduler is not None and args.scheduler_step_per_batch: + scheduler.step() + if (dist.get_rank() if dist.is_initialized() else 0) == 0: + print(f"[Info] flushed remaining gradients after early stop (processed_batches={processed_batches}, grad_accum={args.grad_accum_steps}).") + + if scheduler is not None and not args.scheduler_step_per_batch: + scheduler.step() + + return OrderedDict([('loss', losses.avg), ('acc1', top1.avg), ('acc5', top5.avg)]), step_logs, global_step + +def validate(args, model, val_loader, criterion, device, max_batches=None): + """ + Validate on the val_loader. + If max_batches is not None, only process up to that many batches (useful for quick checks). + Returns an OrderedDict with loss/acc1/acc5 (averaged over processed samples). + """ + losses = AverageMeter('Loss') + top1 = AverageMeter('Acc@1') + top5 = AverageMeter('Acc@5') + + model.eval() + processed_batches = 0 + processed_samples = 0 + with torch.no_grad(): + for i, (images, targets) in enumerate(tqdm(val_loader)): + images = images.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + outputs = model(images) + loss = criterion(outputs, targets) + acc1, acc5 = accuracy(outputs, targets, topk=(1,5)) + batch_n = images.size(0) + losses.update(loss.item(), batch_n) + top1.update(acc1.item(), batch_n) + top5.update(acc5.item(), batch_n) + + processed_batches += 1 + processed_samples += batch_n + + if (max_batches is not None) and (processed_batches >= max_batches): + break + + # 如果没处理任何样本,避免除0(不太可能,但防御性) + if processed_samples == 0: + return OrderedDict([('loss', 0.0), ('acc1', 0.0), ('acc5', 0.0)]) + return OrderedDict([('loss', losses.avg), ('acc1', top1.avg), ('acc5', top5.avg)]) + +# ---------------------------- +# Main +# ---------------------------- +def main(): + args = parse_args() + + # handle local_rank from env if not provided + local_rank_env = os.environ.get('LOCAL_RANK', None) + if args.local_rank is None and local_rank_env is not None: + args.local_rank = int(local_rank_env) + + distributed = (args.local_rank is not None and args.local_rank != -1) + if distributed: + dist.init_process_group(backend='nccl', init_method='env://') + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + + if distributed: + torch.cuda.set_device(args.local_rank) + device = torch.device('cuda', args.local_rank) + else: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + set_seed(args.seed + (rank if distributed else 0), deterministic=args.deterministic) + + save_dir = os.path.join('models', args.name) + if rank == 0: + os.makedirs(save_dir, exist_ok=True) + with open(os.path.join(save_dir, 'args.json'), 'w') as f: + json.dump(vars(args), f, indent=2) + if distributed: + dist.barrier() + + train_loader, val_loader, auto_num_classes, train_sampler = build_dataloaders(args, rank, world_size) + if args.num_classes is None: + args.num_classes = auto_num_classes + + # 使用本地 Model 优先(LocalModel 已在文件顶部尝试导入) + model = build_model_with_local_priority(args, device) + model.to(device) + + if distributed: + model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) + + criterion = nn.CrossEntropyLoss().to(device) + params = [p for p in model.parameters() if p.requires_grad] + if args.optimizer == 'sgd': + optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, nesterov=args.nesterov) + elif args.optimizer == 'adam': + optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) + elif args.optimizer == 'adamw': + optimizer = optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay) + else: + raise ValueError('Unknown optimizer') + + scheduler = None + if args.scheduler == 'multistep': + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma) + elif args.scheduler == 'step': + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) + elif args.scheduler == 'cosine': + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) + elif args.scheduler == 'none': + scheduler = None + + scaler = amp.GradScaler() if args.amp else None + + start_epoch = args.start_epoch + best_acc = 0.0 + if args.resume: + if os.path.isfile(args.resume): + ckpt = torch.load(args.resume, map_location='cpu') + model_state = ckpt.get('state_dict', ckpt) + if isinstance(model, DDP): + model.module.load_state_dict(model_state) + else: + model.load_state_dict(model_state) + if 'optimizer' in ckpt: + optimizer.load_state_dict(ckpt['optimizer']) + start_epoch = ckpt.get('epoch', start_epoch) + best_acc = ckpt.get('best_acc', best_acc) + print(f"=> resumed from {args.resume}, start_epoch={start_epoch}") + else: + print(f"=> resume path {args.resume} not found") + + log_columns = ['epoch', 'lr', 'loss', 'acc1', 'acc5', 'val_loss', 'val_acc1', 'val_acc5'] + log_df = pd.DataFrame(columns=log_columns) + # step-level log + step_log_columns = ['epoch', 'batch_idx', 'global_step', 'lr', 'loss', 'loss_avg', 'acc1', 'acc1_avg', 'acc5', 'acc5_avg', 'time'] + step_log_df = pd.DataFrame(columns=step_log_columns) + + total_epochs = args.epochs + # global_step计数器(训练过程中跨epoch持续) + global_step = 0 + + epoch = start_epoch + # loop until either epoch criteria or step criteria met + while True: + if train_sampler is not None: + if args.seed_sampler: + train_sampler.set_epoch(epoch + args.seed) + else: + train_sampler.set_epoch(epoch) + + if rank == 0: + print(f"==== Epoch {epoch}/{total_epochs - 1} ====") + + # 如果传入了 args.steps (>0),则把剩余允许的 step 数传给 train_one_epoch, + # 否则 max_global_steps=None(按整 epoch 执行完) + if args.steps and args.steps > 0: + max_global_steps = args.steps + else: + max_global_steps = None + + train_log, step_logs, global_step = train_one_epoch( + args, epoch, model, criterion, optimizer, train_loader, device, scaler, + scheduler, train_sampler, global_step_start=global_step, max_global_steps=max_global_steps + ) + + # 如果启用了按 steps 的模式且已经达到上限,标记需要在做一次验证后退出 + if max_global_steps is not None and global_step >= max_global_steps: + if rank == 0: + print(f"[Main] 达到 max_global_steps={max_global_steps}(global_step={global_step}),将在完成验证后退出训练。") + # 我们不 return 立刻退出;后面的 validate / 保存逻辑会执行一次,然后 main 返回/结束 + end_due_to_steps = True + else: + end_due_to_steps = False + + # 验证并记录 epoch 级别日志(如果在 step 模式下很可能在中间某个 epoch 提前结束,但我们仍做一次 validate) + val_log = validate(args, model, val_loader, criterion, device, args.batch_size) + current_lr = optimizer.param_groups[0]['lr'] + + if rank == 0: + # epoch summary print, 格式与示例对齐 + print(f"Epoch[{epoch}]: epoch_train_loss {train_log['loss']:.4f} acc1 {train_log['acc1']:.2f} acc5 {train_log['acc5']:.2f} | " + f"val_loss {val_log['loss']:.4f} acc1 {val_log['acc1']:.2f} acc5 {val_log['acc5']:.2f} lr {current_lr:.6f}") + row = { + 'epoch': epoch, + 'lr': current_lr, + 'loss': train_log['loss'], + 'acc1': train_log['acc1'], + 'acc5': train_log['acc5'], + 'val_loss': val_log['loss'], + 'val_acc1': val_log['acc1'], + 'val_acc5': val_log['acc5'], + } + new_row_df = pd.DataFrame([row]) + log_df = pd.concat([log_df, new_row_df], ignore_index=True) + log_df.to_csv(os.path.join(save_dir, 'log.csv'), index=False) + + is_best = val_log['acc1'] > best_acc + if is_best: + best_acc = val_log['acc1'] + if (epoch % args.save_freq == 0) or is_best or ( (max_global_steps is None) and (epoch == total_epochs - 1) ) : + state = { + 'epoch': epoch, + 'state_dict': model.module.state_dict() if isinstance(model, DDP) else model.state_dict(), + 'best_acc': best_acc, + 'optimizer': optimizer.state_dict(), + 'args': vars(args) + } + save_checkpoint(state, is_best, save_dir, filename=f'checkpoint_epoch_{epoch}.pth') + + # 如果是因为 steps 模式达到上限,则在完成 validation / 保存后退出训练 + if end_due_to_steps: + if rank == 0: + print(f"[Main] 已在 steps 模式下完成最后一次验证并保存,训练结束(global_step={global_step})。") + break + + # increment epoch + epoch += 1 + + # stopping conditions: + # 1) if steps mode enabled and reached steps -> stop + if args.steps and args.steps > 0: + if global_step >= args.steps: + if rank == 0: + print(f"[Main] 已达到指定 steps={args.steps}(global_step={global_step}),训练结束。") + break + + # 2) if steps not used, stop when epoch >= epochs + else: + if epoch >= total_epochs: + if rank == 0: + print(f"[Main] 已达到指定 epochs={total_epochs}(epoch={epoch}),训练结束。") + break + + if dist.is_initialized(): + dist.barrier() + if rank == 0: + print("Training finished. Best val acc1: {:.2f}".format(best_acc)) + +if __name__ == '__main__': + main() \ No newline at end of file