import torch
import torch.nn as nn

class RevIN(nn.Module):
    """
    Reversible Instance Normalization
    - On the way in : normalize each sample along time (store mean, stdev)
    - On the way out: restore the original distribution from the stored stats
    """
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            # Learnable per-feature parameters that modulate the normalization
            self.affine_weight = nn.Parameter(torch.ones(self.num_features))  # γ
            self.affine_bias = nn.Parameter(torch.zeros(self.num_features))   # β

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError(f"unknown mode: {mode}")
        return x

    def _get_statistics(self, x):
        # Compute mean/stdev along the time axis (dim=1) and cache them
        # detach(): keep the statistics out of gradient computation
        dim2reduce = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
        ).detach()

    def _normalize(self, x):
        x = (x - self.mean) / self.stdev
        if self.affine:
            x = x * self.affine_weight + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = (x - self.affine_bias) / (self.affine_weight + 1e-9)
        x = x * self.stdev + self.mean
        return x

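# Minimal round-trip sketch of how RevIN is used (illustrative only; the
# shapes below are assumptions, not taken from the surrounding code):
#
#   revin = RevIN(num_features=17)
#   x = torch.randn(32, 120, 17)       # [Batch, Seq_Len, Features]
#   x_norm = revin(x, 'norm')          # per-sample stats are cached here
#   x_back = revin(x_norm, 'denorm')   # x_back ≈ x, up to eps
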
class PatchTST_Model(nn.Module):
    """
    PatchTST model customized for SISC
    - Input : [Batch, Seq_Len, Features], e.g. [32, 120, 17]
    - Output: [Batch, 4] → upward-probability logits for the 1/3/5/7-day horizons
    """
    def __init__(self,
                 seq_len=120,
                 enc_in=17,      # number of features (11 daily + 4 weekly + 2 monthly)
                 patch_len=16,
                 stride=8,
                 d_model=128,
                 n_heads=4,
                 e_layers=3,
                 d_ff=256,
                 dropout=0.1,
                 n_outputs=4):   # 1/3/5/7-day predictions
        super(PatchTST_Model, self).__init__()

        self.seq_len = seq_len
        self.patch_len = patch_len
        self.stride = stride
        # With the defaults: (120 - 16) / 8 + 1 = 14 patches
        self.num_patches = int((seq_len - patch_len) / stride) + 1

        # 1. RevIN
        self.revin = RevIN(enc_in)

        # 2. Patch embedding: project each patch to a d_model-dim vector
        self.patch_embedding = nn.Linear(patch_len, d_model)
        # Learnable positional embedding (not sin/cos; follows the official
        # implementation of the PatchTST paper)
        self.position_embedding = nn.Parameter(
            torch.randn(1, enc_in, self.num_patches, d_model)
        )
        self.dropout = nn.Dropout(dropout)

        # 3. Transformer encoder (channel independent: the B*F univariate
        #    series are processed independently of one another)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, e_layers)

        # 4. Flatten + MLP head
        # enc_in * num_patches * d_model → 256 → n_outputs
        self.head = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(enc_in * self.num_patches * d_model, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_outputs)  # 1/3/5/7-day logits
        )

    def forward(self, x):
        # x: [B, S, F]
        B, S, F = x.shape

        # 1. RevIN normalization
        x = self.revin(x, 'norm')  # [B, S, F]

        # 2. Channel independence: reshape so each feature becomes its own series
        x = x.permute(0, 2, 1)     # [B, F, S]
        x = x.reshape(B * F, S)    # [B*F, S]

        # 3. Patching: slice each series into (possibly overlapping) windows
        x = x.unfold(dimension=1, size=self.patch_len, step=self.stride)
        # [B*F, num_patches, patch_len]

        # 4. Patch embedding
        x = self.patch_embedding(x)  # [B*F, N, d_model]

        # 5. Positional embedding
        pos_emb = self.position_embedding.repeat(B, 1, 1, 1).reshape(B * F, self.num_patches, -1)
        x = self.dropout(x + pos_emb)  # [B*F, N, d_model]

        # 6. Transformer encoder
        x = self.encoder(x)  # [B*F, N, d_model]

        # 7. Restore the channel dimension
        x = x.reshape(B, F, self.num_patches, -1)  # [B, F, N, d_model]

        # 8. Head → 4 logits
        out = self.head(x)  # [B, n_outputs]

        return out  # use with BCEWithLogitsLoss → sigmoid is applied inside the loss
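

# Minimal smoke test (a sketch: the batch size, labels, and loss wiring below
# are illustrative assumptions, not taken from the training pipeline):
if __name__ == "__main__":
    model = PatchTST_Model()            # defaults: seq_len=120, enc_in=17
    x = torch.randn(32, 120, 17)        # [Batch, Seq_Len, Features]
    logits = model(x)
    print(logits.shape)                 # torch.Size([32, 4])

    # Training pairs the 4 logits with 4 binary up/down labels, one per horizon
    labels = torch.randint(0, 2, (32, 4)).float()
    loss = nn.BCEWithLogitsLoss()(logits, labels)
    print(loss.item())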