
Commit 92bc3b3

[AI] [FEAT] Store model weights via volume mount (#335)
* [AI] [FEAT] Store model weights via volume mount * [AI] [FIX] Fix xai * [AI] [FIX] Clean up environment variables
1 parent ae48e58 commit 92bc3b3

25 files changed

Lines changed: 771 additions & 744 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ AI/.venv/
 AI/data/weights/tcn/
 AI/config/trading.local.json
 AI/tests/out/
+AI/docs/
 
 # ===== Backend =====
 backend/src/main/java/org/sejongisc/backend/stock/TestController.java

AI/config/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
     CONFIG_ENV_VAR,
     DEFAULT_CONFIG_PATH,
     DEFAULT_LOCAL_CONFIG_PATH,
+    MODEL_WEIGHTS_DIR_ENV_VAR,
     DataConfig,
     ExecutionConfig,
     MacroFallbackConfig,
@@ -18,6 +19,7 @@
     "CONFIG_ENV_VAR",
     "DEFAULT_CONFIG_PATH",
     "DEFAULT_LOCAL_CONFIG_PATH",
+    "MODEL_WEIGHTS_DIR_ENV_VAR",
     "DataConfig",
     "ExecutionConfig",
     "MacroFallbackConfig",

AI/config/trading.py

Lines changed: 7 additions & 1 deletion
@@ -11,6 +11,7 @@
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 CONFIG_DIR = Path(__file__).resolve().parent
 CONFIG_ENV_VAR = "AI_TRADING_CONFIG_PATH"
+MODEL_WEIGHTS_DIR_ENV_VAR = "AI_MODEL_WEIGHTS_DIR"
 DEFAULT_CONFIG_PATH = CONFIG_DIR / "trading.default.json"
 DEFAULT_LOCAL_CONFIG_PATH = CONFIG_DIR / "trading.local.json"
 
@@ -116,6 +117,11 @@ def _read_json(path: Path) -> dict[str, Any]:
 
 
 def _build_config(raw: dict[str, Any]) -> TradingConfig:
+    env_model_weights_dir = os.getenv(MODEL_WEIGHTS_DIR_ENV_VAR)
+    if env_model_weights_dir and env_model_weights_dir.strip():
+        model_weights_dir = env_model_weights_dir.strip()
+    else:
+        model_weights_dir = raw["model"]["weights_dir"]
     risk_overlay = RiskOverlayConfig(**raw["portfolio"]["risk_overlay"])
     macro_fallback = MacroFallbackConfig(**raw["pipeline"]["macro_fallback"])
     config = TradingConfig(
@@ -141,7 +147,7 @@ def _build_config(raw: dict[str, Any]) -> TradingConfig:
             prediction_horizons=tuple(raw["data"]["prediction_horizons"]),
         ),
         model=ModelConfig(
-            weights_dir=_resolve_path(raw["model"]["weights_dir"]),
+            weights_dir=_resolve_path(model_weights_dir),
             weights_file=raw["model"]["weights_file"],
             scaler_file=raw["model"]["scaler_file"],
         ),
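
For context, a minimal sketch (not part of this commit) of how the new override is expected to behave, assuming `load_trading_config` re-reads the environment on each call and with `/mnt/model-weights` as a purely illustrative mount point:

import os

from AI.config import load_trading_config

# Without AI_MODEL_WEIGHTS_DIR, weights_dir comes from raw["model"]["weights_dir"]
# in the JSON config, resolved via _resolve_path.
config = load_trading_config()
print(config.model.weights_dir)

# With the variable set (e.g. pointing at a mounted volume), the env value wins.
os.environ["AI_MODEL_WEIGHTS_DIR"] = "/mnt/model-weights"
config = load_trading_config()
print(config.model.weights_dir)  # -> /mnt/model-weights, resolved to an absolute path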

AI/libs/llm/__init__.py

Lines changed: 11 additions & 2 deletions
@@ -2,11 +2,20 @@
 from .base_client import BaseLLMClient
 from .groq import GroqClient
 from .ollama import OllamaClient
-from .gemini import GeminiClient
+
+try:
+    from .gemini import GeminiClient
+except Exception as gemini_import_error:
+    class GeminiClient:  # type: ignore[no-redef]
+        def __init__(self, *args, **kwargs):
+            raise ImportError(
+                "GeminiClient requires the `google-genai` package. "
+                "Install it with `pip install -U google-genai`."
+            ) from gemini_import_error
 
 __all__ = [
     "BaseLLMClient",
     "GroqClient",
     "OllamaClient",
     "GeminiClient"
-]
+]
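
The try/except turns a missing optional dependency into an error that is deferred until the client is actually constructed. A small sketch of the resulting behaviour, assuming `google-genai` is not installed:

from AI.libs.llm import GeminiClient  # package import no longer fails

try:
    client = GeminiClient()
except ImportError as err:
    # Raised only at instantiation time, carrying the install hint from the stub.
    print(err)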

AI/libs/llm/ollama.py

Lines changed: 52 additions & 9 deletions
@@ -15,28 +15,68 @@ class OllamaClient(BaseLLMClient):
     def __init__(
         self,
         base_url: str = "http://localhost:11434",
-        model_name: str = os.environ.get("OLLAMA_MODEL", "llama3-ko"),
+        model_name: Optional[str] = None,
     ):
-        super().__init__(model_name=model_name)
+        env_model_name = os.environ.get("OLLAMA_MODEL")
+        resolved_model_name = model_name or env_model_name or "llama3:latest"
+        super().__init__(model_name=resolved_model_name)
         self.base_url = base_url
+        self._model_explicitly_set = bool(model_name or env_model_name)
+
+    def _list_local_models(self) -> list[str]:
+        try:
+            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
+            response.raise_for_status()
+            result = response.json()
+            return [model.get("name", "") for model in result.get("models", []) if model.get("name")]
+        except Exception:
+            return []
+
+    def _ensure_model_available(self) -> bool:
+        local_models = self._list_local_models()
+        if not local_models:
+            self.set_last_error(
+                "No local Ollama model found. Pull one first (e.g. `ollama pull llama3:latest`)."
+            )
+            return False
+
+        if self.model_name in local_models:
+            return True
+
+        if self._model_explicitly_set:
+            self.set_last_error(
+                f"Model '{self.model_name}' is not installed. Installed models: {', '.join(local_models)}"
+            )
+            return False
+
+        fallback_model = local_models[0]
+        print(
+            f"[OllamaClient][Warning] Default model '{self.model_name}' is unavailable. "
+            f"Using '{fallback_model}' instead."
+        )
+        self.model_name = fallback_model
+        return True
 
     def generate_text(self, prompt: str, system_prompt: Optional[str] = None, **kwargs) -> str:
         url = f"{self.base_url}/api/generate"
-        full_prompt = prompt
-        if system_prompt:
-            full_prompt = f"System: {system_prompt}\n\nUser: {prompt}"
+
+        if not self._ensure_model_available():
+            print(f"[OllamaClient][Error] Text generation failed: {self.last_error}")
+            return ""
 
         payload = {
             "model": self.model_name,
-            "prompt": full_prompt,
+            "prompt": prompt,
             "stream": False,
             "options": {
                 "temperature": kwargs.get("temperature", 0.7),
             },
         }
+        if system_prompt:
+            payload["system"] = system_prompt
 
         try:
-            response = requests.post(url, json=payload)
+            response = requests.post(url, json=payload, timeout=kwargs.get("timeout", 120))
             response.raise_for_status()
             result = response.json()
             text = result.get("response", "")
@@ -49,8 +89,11 @@ def generate_text(self, prompt: str, system_prompt: Optional[str] = None, **kwar
 
     def get_health(self) -> bool:
         try:
-            res = requests.get(self.base_url, timeout=5)
-            return res.status_code == 200
+            res = requests.get(f"{self.base_url}/api/tags", timeout=5)
+            is_healthy = res.status_code == 200
+            if is_healthy:
+                self.clear_last_error()
+            return is_healthy
         except Exception as e:
             self.set_last_error(e)
             return False
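
A brief usage sketch of the new model resolution (the URL and prompts are illustrative; `last_error` is the same attribute the client prints in its own error path):

from AI.libs.llm import OllamaClient

# Resolution order: explicit model_name argument > OLLAMA_MODEL env var > "llama3:latest".
client = OllamaClient(base_url="http://localhost:11434")

# generate_text() first checks /api/tags; if the resolved default model is not installed,
# it falls back to the first locally installed model. An explicitly requested model that
# is missing is reported via set_last_error() instead of being swapped out.
text = client.generate_text(
    "Summarize today's market in one sentence.",
    system_prompt="You are a concise financial assistant.",
)
print(text or client.last_error)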

AI/modules/signal/core/__init__.py

Lines changed: 13 additions & 1 deletion
@@ -1,8 +1,20 @@
 # AI/modules/signal/core/__init__.py
 from .base_model import BaseSignalModel
 from .data_loader import DataLoader
+from .artifact_paths import (
+    ARTIFACT_ROOT_ENV_VAR,
+    ModelArtifactPaths,
+    resolve_artifact_file,
+    resolve_artifact_root,
+    resolve_model_artifacts,
+)
 
 __all__ = [
+    "ARTIFACT_ROOT_ENV_VAR",
     "BaseSignalModel",
     "DataLoader",
-]
+    "ModelArtifactPaths",
+    "resolve_artifact_file",
+    "resolve_artifact_root",
+    "resolve_model_artifacts",
+]

AI/modules/signal/core/artifact_paths.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+
+
+ARTIFACT_ROOT_ENV_VAR = "AI_MODEL_WEIGHTS_DIR"
+DEFAULT_ARTIFACT_ROOT = Path("AI/data/weights")
+PROJECT_ROOT = Path(__file__).resolve().parents[4]
+
+
+@dataclass(frozen=True, slots=True)
+class ModelArtifactPaths:
+    root_dir: str
+    model_dir: str
+    model_path: str
+    scaler_path: str | None = None
+    metadata_path: str | None = None
+
+
+def _resolve_absolute(raw_path: str | Path) -> Path:
+    path = Path(raw_path)
+    if not path.is_absolute():
+        path = PROJECT_ROOT / path
+    return path.resolve()
+
+
+def _normalize_mode(raw_mode: str | None) -> str:
+    mode = (raw_mode or "prod").strip().lower()
+    if mode in {"simulation", "sim", "test", "tests", "dev", "development", "qa"}:
+        return "tests"
+    if mode in {"live", "production", "prod"}:
+        return "prod"
+    return mode
+
+
+def resolve_artifact_root(config_weights_dir: str | None = None) -> str:
+    env_root = os.getenv(ARTIFACT_ROOT_ENV_VAR)
+    selected_root = (
+        env_root.strip()
+        if env_root and env_root.strip()
+        else (config_weights_dir or str(DEFAULT_ARTIFACT_ROOT))
+    )
+    return str(_resolve_absolute(selected_root))
+
+
+def resolve_artifact_file(*relative_parts: str, config_weights_dir: str | None = None) -> str:
+    if not relative_parts:
+        raise ValueError("At least one path part is required.")
+    artifact_root = Path(resolve_artifact_root(config_weights_dir))
+    return str((artifact_root.joinpath(*relative_parts)).resolve())
+
+
+def resolve_model_artifacts(
+    model_name: str,
+    mode: str | None = None,
+    config_weights_dir: str | None = None,
+    model_dir: str | None = None,
+) -> ModelArtifactPaths:
+    normalized_model = model_name.strip().lower()
+    normalized_mode = _normalize_mode(mode)
+    root_dir = Path(resolve_artifact_root(config_weights_dir))
+
+    if normalized_model == "transformer":
+        suffix = "_prod"
+        mode_dir = "prod"
+        if normalized_mode == "tests":
+            suffix = "_test"
+            mode_dir = "tests"
+
+        resolved_model_dir = _resolve_absolute(model_dir) if model_dir else (root_dir / "transformer" / mode_dir)
+        model_path = resolved_model_dir / f"multi_horizon_model{suffix}.keras"
+        scaler_path = resolved_model_dir / f"multi_horizon_scaler{suffix}.pkl"
+        return ModelArtifactPaths(
+            root_dir=str(root_dir),
+            model_dir=str(resolved_model_dir),
+            model_path=str(model_path),
+            scaler_path=str(scaler_path),
+            metadata_path=None,
+        )
+
+    if normalized_model in {"itransformer", "i_transformer", "i-transformer"}:
+        resolved_model_dir = _resolve_absolute(model_dir) if model_dir else (root_dir / "itransformer")
+        return ModelArtifactPaths(
+            root_dir=str(root_dir),
+            model_dir=str(resolved_model_dir),
+            model_path=str(resolved_model_dir / "multi_horizon_model.keras"),
+            scaler_path=str(resolved_model_dir / "multi_horizon_scaler.pkl"),
+            metadata_path=str(resolved_model_dir / "metadata.json"),
+        )
+
+    if normalized_model == "tcn":
+        resolved_model_dir = _resolve_absolute(model_dir) if model_dir else (root_dir / "tcn")
+        return ModelArtifactPaths(
+            root_dir=str(root_dir),
+            model_dir=str(resolved_model_dir),
+            model_path=str(resolved_model_dir / "model.pt"),
+            scaler_path=str(resolved_model_dir / "scaler.pkl"),
+            metadata_path=str(resolved_model_dir / "metadata.json"),
+        )
+
+    if normalized_model == "patchtst":
+        resolved_model_dir = _resolve_absolute(model_dir) if model_dir else (root_dir / "patchtst")
+        return ModelArtifactPaths(
+            root_dir=str(root_dir),
+            model_dir=str(resolved_model_dir),
+            model_path=str(resolved_model_dir / "PatchTST_best.pt"),
+            scaler_path=None,
+            metadata_path=None,
+        )
+
+    raise ValueError(f"Unsupported model name for artifact resolution: {model_name}")
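
A short usage sketch of the resolver (the printed paths assume the default AI/data/weights root; the actual root follows AI_MODEL_WEIGHTS_DIR or the trading config):

from AI.modules.signal.core import resolve_artifact_root, resolve_model_artifacts

# Root resolution order: AI_MODEL_WEIGHTS_DIR > config_weights_dir > AI/data/weights,
# with relative paths anchored at the project root and returned as absolute strings.
print(resolve_artifact_root())

# The transformer keeps prod/test artifacts apart; "simulation" normalizes to "tests".
paths = resolve_model_artifacts("transformer", mode="simulation")
print(paths.model_path)   # .../transformer/tests/multi_horizon_model_test.keras
print(paths.scaler_path)  # .../transformer/tests/multi_horizon_scaler_test.pkl

# PatchTST uses a flat layout with no scaler or metadata file.
print(resolve_model_artifacts("patchtst").model_path)  # .../patchtst/PatchTST_best.pt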

AI/modules/signal/models/PatchTST/train.py

Lines changed: 15 additions & 2 deletions
@@ -6,6 +6,19 @@
 import os
 from torch.utils.data import DataLoader, TensorDataset
 from .architecture import PatchTST_Model
+from AI.config import load_trading_config
+from AI.modules.signal.core.artifact_paths import resolve_model_artifacts
+
+
+def _default_model_save_path() -> str:
+    try:
+        trading_config = load_trading_config()
+        return resolve_model_artifacts(
+            model_name="patchtst",
+            config_weights_dir=trading_config.model.weights_dir,
+        ).model_path
+    except Exception:
+        return resolve_model_artifacts(model_name="patchtst").model_path
 
 # Configuration values
 CONFIG = {
@@ -15,7 +28,7 @@
     'learning_rate': 0.0001,
     'epochs': 100,
     'patience': 10,
-    'model_save_path': 'AI/data/weights/PatchTST_best.pt'
+    'model_save_path': _default_model_save_path()
 }
 
 def train_model(train_loader, val_loader, device):
@@ -100,4 +113,4 @@ def run_training(X_train, y_train, X_val, y_val):
     val_loader = DataLoader(val_data, batch_size=CONFIG['batch_size'], shuffle=False)
 
     trained_model = train_model(train_loader, val_loader, device)
-    return trained_model
+    return trained_model
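
One consequence worth noting: `model_save_path` is now computed at import time, when the `CONFIG` dict is built, so any weights-dir override has to be in place before the module is imported. A sketch, assuming the project root is on `sys.path` and with `/mnt/model-weights` as an illustrative mount point:

import os

# Set the override before importing the training module; _default_model_save_path()
# runs while CONFIG is being defined.
os.environ["AI_MODEL_WEIGHTS_DIR"] = "/mnt/model-weights"

from AI.modules.signal.models.PatchTST.train import CONFIG

print(CONFIG["model_save_path"])  # e.g. /mnt/model-weights/patchtst/PatchTST_best.pt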
