Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ariv/runner/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ class EvalRequest(BaseModel):


def _registry_path() -> Path:
return Path(os.getenv("ARIV_MODELS_YAML", "ariv/models/models.yaml"))
env_path = os.getenv("ARIV_MODELS_YAML")
if env_path:
return Path(env_path).expanduser().resolve()
return (Path(__file__).resolve().parents[1] / "models" / "models.yaml").resolve()


app = FastAPI(title="ARIV Runner", version="0.2.0")
Expand Down
20 changes: 20 additions & 0 deletions ariv/runner/llama_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ async def stream_chat(
)

assert process.stdout
stderr_chunks: List[bytes] = []

async def _collect_stderr() -> None:
assert process.stderr
async for chunk in process.stderr:
stderr_chunks.append(chunk)

stderr_task = asyncio.create_task(_collect_stderr())

async for raw_line in process.stdout:
line = raw_line.decode("utf-8").strip()
if not line:
Expand All @@ -77,6 +86,17 @@ async def stream_chat(
yield str(token)

await process.wait()
await stderr_task

if process.returncode != 0:
stderr_tail = b"".join(stderr_chunks).decode("utf-8", errors="replace").strip()
if len(stderr_tail) > 1200:
stderr_tail = stderr_tail[-1200:]
raise RuntimeError(
"llama-cli failed: "
f"binary={self._binary}, model={model_path}, exit_code={process.returncode}, "
f"stderr={stderr_tail or '<empty>'}"
)

async def run_chat(
self,
Expand Down
14 changes: 12 additions & 2 deletions core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
import os
import time
from typing import Dict, Optional, Any, List, Callable
from llama_cpp import Llama
try:
from llama_cpp import Llama
_HAS_LLAMA = True
except ModuleNotFoundError: # pragma: no cover - optional runtime dependency
Llama = Any # type: ignore[assignment]
_HAS_LLAMA = False
from .vram_manager import VRAMManager, MemoryProfiler
from ..tools.registry import ToolRegistry
from tools.registry import ToolRegistry
import logging

logger = logging.getLogger("JugaadOrchestrator")
Expand Down Expand Up @@ -87,6 +92,11 @@ def load_model(self,
load_start = time.time()

try:
if not _HAS_LLAMA:
raise ModelLoadError(
"llama_cpp is not installed. Install project dependencies to load GGUF models."
)

# Get memory optimization recommendations
model_size_gb = os.path.getsize(model_path) / (1024**3)
mem_opt = self.vram_manager.optimize_for_model(model_size_gb)
Expand Down
4 changes: 2 additions & 2 deletions core/trv_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import time
from typing import Dict, List, Optional, Any
from .orchestrator import JugaadOrchestrator
from ..config import (INDIAN_LANGUAGES_22, PIPELINE_CONFIG,
COT_CONFIG, TOOL_CONFIG, ARC_CONFIG)
from config import (INDIAN_LANGUAGES_22, PIPELINE_CONFIG,
COT_CONFIG, TOOL_CONFIG, ARC_CONFIG)
import json

logger = logging.getLogger("TRVPipeline")
Expand Down
10 changes: 7 additions & 3 deletions core/vram_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""

import gc
import torch
import logging
import time
from typing import Optional, Dict, Any
Expand All @@ -14,6 +13,11 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("VRAMManager")

try:
import torch
except ModuleNotFoundError: # pragma: no cover - optional runtime dependency
torch = None

@dataclass
class MemoryStats:
"""Memory statistics data class"""
Expand Down Expand Up @@ -60,7 +64,7 @@ def flush(self, aggressive: bool = True) -> Dict[str, Any]:
gc.collect()

# Step 2: Clear PyTorch CUDA cache multiple times
if torch.cuda.is_available():
if torch is not None and torch.cuda.is_available():
for i in range(2 if aggressive else 1):
torch.cuda.empty_cache()

Expand Down Expand Up @@ -92,7 +96,7 @@ def flush(self, aggressive: bool = True) -> Dict[str, Any]:

def get_memory_stats(self) -> MemoryStats:
"""Get current GPU memory statistics with fragmentation analysis"""
if not torch.cuda.is_available():
if torch is None or not torch.cuda.is_available():
return MemoryStats(
allocated_gb=0,
reserved_gb=0,
Expand Down
57 changes: 57 additions & 0 deletions tests/test_llama_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio

from ariv.runner.llama_cli import LlamaCLI


def test_stream_chat_raises_on_nonzero_exit(monkeypatch, tmp_path) -> None:
    """stream_chat must raise RuntimeError naming binary, exit code and stderr tail."""
    # Clear the mock flag so the real subprocess code path is exercised.
    monkeypatch.delenv("ARIV_MOCK_LLAMA", raising=False)
    model = tmp_path / "dummy.gguf"
    model.write_text("x", encoding="utf-8")

    class FakeStream:
        # Minimal async iterable standing in for a subprocess pipe.
        def __init__(self, chunks):
            self._chunks = chunks

        def __aiter__(self):
            self._iter = iter(self._chunks)
            return self

        async def __anext__(self):
            try:
                return next(self._iter)
            except StopIteration as exc:
                raise StopAsyncIteration from exc

    class FakeProcess:
        # Fake process: one valid stdout token, stderr noise, non-zero exit code.
        def __init__(self):
            self.stdout = FakeStream([b'{"token":"ok"}\n'])
            self.stderr = FakeStream([b"fatal error\n"])
            self.returncode = 17

        async def wait(self):
            return self.returncode

    async def fake_exec(*args, **kwargs):
        return FakeProcess()

    # Intercept subprocess creation so no real llama-cli binary is needed.
    monkeypatch.setattr("asyncio.create_subprocess_exec", fake_exec)

    cli = LlamaCLI(binary="llama-cli")

    async def _consume() -> None:
        # Drain the stream fully so the post-wait failure check runs.
        async for _ in cli.stream_chat(
            model_path=str(model),
            prompt="hello",
            num_gpu_layers=0,
            max_tokens=2,
        ):
            pass

    try:
        asyncio.run(_consume())
        raise AssertionError("Expected RuntimeError")
    except RuntimeError as exc:
        message = str(exc)
        # The diagnostic must carry the exit code, stderr text and binary name.
        assert "exit_code=17" in message
        assert "fatal error" in message
        assert "binary=llama-cli" in message
6 changes: 4 additions & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def test_pipeline_phases():
"""Test that all 4 phases execute"""
mock_orch = Mock()
mock_orch.models_config = {}
mock_orch.generate.return_value = "test output"

pipeline = TRVPipeline(mock_orch, {})
Expand All @@ -19,7 +20,7 @@ def test_pipeline_phases():
pipeline._phase3_critic = Mock(return_value="PASS")
pipeline._phase4_synthesis = Mock(return_value="final")

result = pipeline.execute("query", "hindi", enable_critic=True)
result = pipeline.execute("query", "hindi", enable_critic=True, enable_deep_cot=False)

# Verify all phases called
pipeline._phase1_ingestion.assert_called_once()
Expand All @@ -32,6 +33,7 @@ def test_pipeline_phases():
def test_critic_loop():
"""Test critic iteration logic"""
mock_orch = Mock()
mock_orch.models_config = {}
pipeline = TRVPipeline(mock_orch, {})

# First critic fails, second passes
Expand All @@ -40,7 +42,7 @@ def test_critic_loop():
pipeline._phase1_ingestion = Mock(return_value="english")
pipeline._phase4_synthesis = Mock(return_value="final")

result = pipeline.execute("query", "hindi", enable_critic=True)
result = pipeline.execute("query", "hindi", enable_critic=True, enable_deep_cot=False)

# Should have 2 critic calls (fail then pass)
assert pipeline._phase3_critic.call_count == 2
Expand Down
9 changes: 9 additions & 0 deletions tests/test_runner_app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os

from fastapi.testclient import TestClient
Expand All @@ -15,3 +16,11 @@ def test_chat_streaming_response() -> None:
assert response.status_code == 200
text = response.text
assert "metadata" in text


def test_registry_path_independent_of_cwd(monkeypatch, tmp_path) -> None:
    """Reloading the runner app from a foreign cwd must still resolve models.yaml."""
    # Drop the env override so the module falls back to its file-relative default path.
    monkeypatch.delenv("ARIV_MODELS_YAML", raising=False)
    monkeypatch.chdir(tmp_path)
    runner_module = importlib.import_module("ariv.runner.app")
    # Reload so module-level registry construction re-runs under the changed cwd.
    importlib.reload(runner_module)
    assert runner_module.registry.get("mock-0.1b-q4_0").name == "mock-0.1b-q4_0"
24 changes: 12 additions & 12 deletions tests/test_vram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,32 @@
"""Tests for VRAM Manager"""

import pytest
import torch
from core.vram_manager import VRAMManager
from core.vram_manager import VRAMManager, torch

def test_vram_stats():
"""Test VRAM stats collection"""
if not torch.cuda.is_available():
if torch is None or not torch.cuda.is_available():
pytest.skip("CUDA not available")

stats = VRAMManager.get_memory_stats()
assert stats['available'] == True
assert 'allocated_gb' in stats
assert 'total_gb' in stats
assert stats['total_gb'] > 0
manager = VRAMManager()
stats = manager.get_memory_stats()
assert stats.total_gb > 0
assert stats.available_gb >= 0
assert stats.allocated_gb >= 0

def test_flush_protocol():
"""Test VRAM flush doesn't crash"""
if not torch.cuda.is_available():
if torch is None or not torch.cuda.is_available():
pytest.skip("CUDA not available")

# Allocate some memory
x = torch.randn(1000, 1000).cuda()
del x

# Should not raise
VRAMManager.flush()
manager = VRAMManager()
manager.flush()

# Memory should be reduced (or at least not crash)
stats = VRAMManager.get_memory_stats()
assert stats['available'] == True
stats = manager.get_memory_stats()
assert stats.available_gb >= 0
Loading