Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aixplain/v1/modules/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def poll(self, poll_url: Text, name: Text = "model_process") -> ModelResponse:
used_credits=resp.pop("usedCredits", 0),
run_time=resp.pop("runTime", 0),
usage=resp.pop("usage", None),
asset=resp.pop("asset", None),
error_code=resp.get("error_code", None),
**resp,
)
Expand Down Expand Up @@ -421,6 +422,7 @@ def run(
used_credits=response.pop("usedCredits", 0),
run_time=response.pop("runTime", 0),
usage=response.pop("usage", None),
asset=response.pop("asset", None),
error_code=response.get("error_code", None),
**response,
)
Expand Down
1 change: 1 addition & 0 deletions aixplain/v1/modules/model/llm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def run(
used_credits=response.pop("usedCredits", 0),
run_time=response.pop("runTime", 0),
usage=response.pop("usage", None),
asset=response.pop("asset", None),
error_code=response.get("error_code", None),
**response,
)
Expand Down
9 changes: 8 additions & 1 deletion aixplain/v1/modules/model/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
usage: Optional[Dict] = None,
url: Optional[Text] = None,
error_code: Optional[ErrorCode] = None,
asset: Optional[Dict] = None,
**kwargs,
):
"""Initialize a new ModelResponse instance.
Expand All @@ -35,9 +36,11 @@ def __init__(
error_message (Text): The error message if the response is not successful.
used_credits (float): The amount of credits used for the response.
run_time (float): The time taken to generate the response.
usage (Optional[Dict]): Usage information about the response.
usage (Optional[Dict]): Usage information about the response (prompt_tokens,
completion_tokens, total_tokens).
url (Optional[Text]): The URL of the response.
error_code (Optional[ErrorCode]): The error code if the response is not successful.
asset (Optional[Dict]): Asset information (assetId, id) from model serving.
**kwargs: Additional keyword arguments.
"""
self.status = status
Expand All @@ -54,6 +57,7 @@ def __init__(
self.usage = usage
self.url = url
self.error_code = error_code
self.asset = asset
self.additional_fields = kwargs

def __getitem__(self, key: Text) -> Any:
Expand Down Expand Up @@ -137,6 +141,8 @@ def __repr__(self) -> str:
fields.append(f"run_time={self.run_time}")
if self.usage:
fields.append(f"usage={self.usage}")
if self.asset:
fields.append(f"asset={self.asset}")
if self.url:
fields.append(f"url='{self.url}'")
if self.error_code:
Expand Down Expand Up @@ -175,6 +181,7 @@ def to_dict(self) -> Dict[Text, Any]:
"used_credits": self.used_credits,
"run_time": self.run_time,
"usage": self.usage,
"asset": self.asset,
"url": self.url,
"error_code": self.error_code,
}
Expand Down
55 changes: 52 additions & 3 deletions aixplain/v2/agent_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,9 @@ def _format_step_line(
step_line += f" · API {api_calls:2d}"
step_credits = step.get("used_credits") or step.get("usedCredits") or 0
step_line += f" · ${step_credits:.6f}"
token_text = self._format_token_usage_inline(step)
if token_text:
step_line += f" · ⊘ {token_text}"

step_line += f" · {agent_action_part}"
return step_line
Expand Down Expand Up @@ -588,35 +591,65 @@ def _print_step_details(self, step: Dict, idx: int) -> None:
else:
print(f" {self._format_multiline(str(output_data))}")

def _print_token_usage(self, step: Dict) -> None:
    """Emit a token-usage line for *step*; stays silent when no counts exist."""
    summary = self._format_token_usage_inline(step)
    if not summary:
        return
    print(f" ⊘ Tokens: {summary}")

def _format_token_usage_inline(self, step: Dict) -> str:
"""Format token usage as a compact inline string, or empty if unavailable."""
input_tok = step.get("input_tokens")
output_tok = step.get("output_tokens")
total_tok = step.get("total_tokens")

if input_tok is None and output_tok is None and total_tok is None:
return ""

parts = []
if input_tok is not None:
parts.append(f"in={input_tok}")
if output_tok is not None:
parts.append(f"out={output_tok}")
if total_tok is not None:
parts.append(f"total={total_tok}")
return " · ".join(parts)

def _print_completion_message(self, status: str, steps: List[Dict]) -> None:
    """Print final completion message with stats."""
    step_count = len(steps) if steps else 0
    elapsed = (self._now() - self._total_start_time) if self._total_start_time else 0

    # STATUS format renders an in-place updating line, so break onto a new one.
    lead = "\n" if self._format == ProgressFormat.STATUS else ""

    in_tokens = getattr(self, "_total_input_tokens", 0)
    out_tokens = getattr(self, "_total_output_tokens", 0)
    tokens = f" · ⊘ {in_tokens}→{out_tokens} tokens" if (in_tokens or out_tokens) else ""

    # Shared trailing stats segment used by every outcome branch.
    stats = (
        f"⏱ {self._format_elapsed(elapsed)} · "
        f"API {self._total_api_calls} · "
        f"${self._total_credits:.6f}{tokens}"
    )

    if status == "SUCCESS":
        print(f"{lead}✓ Completed {step_count} steps · {stats}")
    elif status in {"FAILED", "ABORTED", "CANCELLED", "ERROR"}:
        print(f"{lead}✗ Agent failed with status: {status} · {step_count} steps · {stats}")
    else:
        print(
            f"{lead}⏸ Stopped: reached max polling limit ({self.max_polls}) · "
            f"{step_count} steps · {stats}"
        )

# =========================================================================
Expand Down Expand Up @@ -667,6 +700,8 @@ def _update_metrics(self, steps: List[Dict]) -> None:
"""Update tracking metrics from steps data."""
self._total_credits = 0.0
self._total_api_calls = 0
self._total_input_tokens = 0
self._total_output_tokens = 0
for idx, s in enumerate(steps):
sid = s.get("_progress_id")
if sid not in self._first_seen:
Expand All @@ -681,6 +716,20 @@ def _update_metrics(self, steps: List[Dict]) -> None:
if api_calls:
self._total_api_calls += int(api_calls)

input_tokens = s.get("input_tokens")
if input_tokens is not None:
try:
self._total_input_tokens += int(input_tokens)
except (ValueError, TypeError):
pass

output_tokens = s.get("output_tokens")
if output_tokens is not None:
try:
self._total_output_tokens += int(output_tokens)
except (ValueError, TypeError):
pass

def _display_logs_format(self, steps: List[Dict]) -> None:
"""Handle display for LOGS format (event timeline)."""
for idx, step in enumerate(steps):
Expand Down
59 changes: 52 additions & 7 deletions aixplain/v2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,57 @@ class Detail:
finish_reason: Optional[str] = field(default=None, metadata=config(field_name="finish_reason"))


def _safe_token_count(val: Any) -> Optional[int]:
"""Coerce a token count to int, returning None for unparseable values.

The model-serving backend returns token counts inconsistently:
valid ints, numeric strings (``"20"``), ``"NaN"``, ``null``, or ``"0"``.
This helper normalises all of those without raising.
"""
if val is None:
return None
if isinstance(val, int):
return val
if isinstance(val, float):
import math

return None if math.isnan(val) else int(val)
s = str(val).strip()
if not s or s.lower() == "nan" or s.lower() == "null" or s.lower() == "none":
return None
try:
return int(s)
except (ValueError, TypeError):
try:
import math

f = float(s)
return None if math.isnan(f) else int(f)
except (ValueError, TypeError):
return None


@dataclass_json
@dataclass
class Usage:
    """Token usage structure from the API response.

    All three counts are decoded through ``_safe_token_count`` and are
    nullable because some model providers (GPT-5.4, Claude, Mistral Large)
    return ``"NaN"`` or ``null`` instead of integers, and some omit the
    fields entirely.
    """

    # Tokens consumed by the input prompt; None when the provider omits
    # or garbles the value.
    prompt_tokens: Optional[int] = field(
        default=None,
        metadata=config(field_name="prompt_tokens", decoder=_safe_token_count),
    )
    # Tokens produced in the model's completion; None when unavailable.
    completion_tokens: Optional[int] = field(
        default=None,
        metadata=config(field_name="completion_tokens", decoder=_safe_token_count),
    )
    # Provider-reported total; NOTE(review): not recomputed locally, so it
    # may disagree with prompt + completion — confirm against the backend.
    total_tokens: Optional[int] = field(
        default=None,
        metadata=config(field_name="total_tokens", decoder=_safe_token_count),
    )


@dataclass_json
Expand All @@ -73,6 +116,7 @@ class ModelResult(Result):
run_time: Optional[float] = field(default=None, metadata=config(field_name="runTime"))
used_credits: Optional[float] = field(default=None, metadata=config(field_name="usedCredits"))
usage: Optional[Usage] = None
asset: Optional[Dict[str, Any]] = None


@dataclass
Expand Down Expand Up @@ -602,9 +646,10 @@ def run(self, **kwargs: Unpack[ModelRunParams]) -> ModelResult:
raise ValueError(f"Parameter validation failed: {'; '.join(param_errors)}")

if self.is_sync_only:
# Sync-only models: Call V2 endpoint directly (bypass run_async which would route to V1)
# V2 returns result directly for sync models, no polling needed
return self._run_sync_v2(**effective_params)
result = self._run_sync_v2(**effective_params)
if result.url and not result.completed:
result = self.sync_poll(result.url, **effective_params)
return result
else:
# Async-capable models: Use base run() which calls run_async() and polls
return super().run(**effective_params)
Expand Down
17 changes: 16 additions & 1 deletion aixplain/v2/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,8 @@ def poll(self, poll_url: str) -> ResultT:
"usedCredits": response.get("usedCredits", 0.0),
"runTime": response.get("runTime", 0.0),
"requestId": response.get("requestId"),
"usage": response.get("usage"),
"asset": response.get("asset"),
}
status = response.get("status", "IN_PROGRESS")

Expand All @@ -1286,7 +1288,20 @@ def poll(self, poll_url: str) -> ResultT:
try:
result = response_class.from_dict(filtered_response)
except Exception:
raise
if filtered_response.get("completed"):
logger.warning(
"Poll response deserialization failed for a completed response. "
"Building fallback result from raw data."
)
result = response_class.from_dict(
{
"status": filtered_response["status"],
"completed": True,
"data": filtered_response.get("data") or {},
}
)
else:
raise

# Attach raw response
result._raw_data = response
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/llm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def test_run_with_custom_parameters():
"data": "Test Result",
"usedCredits": 10,
"runTime": 1.5,
"usage": {"prompt_tokens": 10, "completion_tokens": 20},
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
"asset": {"assetId": "test-model-id", "id": "openai/gpt-5-mini/openai"},
}

with requests_mock.Mocker() as mock:
Expand All @@ -156,4 +157,8 @@ def test_run_with_custom_parameters():
assert response.data == "Test Result"
assert response.used_credits == 10
assert response.run_time == 1.5
assert response.usage == {"prompt_tokens": 10, "completion_tokens": 20}
assert response.usage == {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
assert response.usage["prompt_tokens"] == 10
assert response.usage["completion_tokens"] == 20
assert response.usage["total_tokens"] == 30
assert response.asset == {"assetId": "test-model-id", "id": "openai/gpt-5-mini/openai"}
16 changes: 11 additions & 5 deletions tests/unit/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,10 @@ def test_run_sync():
"completed": True,
"status": "SUCCESS",
"data": "Test Model Result",
"usedCredits": 0,
"runTime": 0,
"usedCredits": 0.05,
"runTime": 1.2,
"usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40},
"asset": {"assetId": "test-model-id", "id": "openai/gpt-5-mini/openai"},
}

with requests_mock.Mocker() as mock:
Expand All @@ -292,9 +294,13 @@ def test_run_sync():
assert response.status == ResponseStatus.SUCCESS
assert response.data == "Test Model Result"
assert response.completed is True
assert response.used_credits == 0
assert response.run_time == 0
assert response.usage is None
assert response.used_credits == 0.05
assert response.run_time == 1.2
assert response.usage == {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}
assert response.usage["prompt_tokens"] == 15
assert response.usage["completion_tokens"] == 25
assert response.usage["total_tokens"] == 40
assert response.asset == {"assetId": "test-model-id", "id": "openai/gpt-5-mini/openai"}


def test_sync_poll():
Expand Down
Loading
Loading