
Commit fc9aea5
Merge pull request #8 from S-FM/feat/limix: added tabular inference
2 parents 5fdba0a + 9f1fd11
39 files changed: +5033 −220 lines

.DS_Store (0 Bytes): binary file not shown.

README.md
Lines changed: 194 additions & 16 deletions
@@ -4,18 +4,21 @@
 [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
 [![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

-Production-ready Python SDK for FAIM (Foundation AI Models) - a high-performance time-series forecasting platform powered by foundation models.
+Production-ready Python SDK for FAIM (Foundation AI Models) - a unified platform for time-series forecasting and tabular inference powered by foundation models.

 ## Features

-- **🚀 Multiple Foundation Models**: FlowState, Amazon Chronos 2.0, TiRex
+- **🚀 Multiple Foundation Models**:
+  - **Time-Series**: FlowState, Amazon Chronos 2.0, TiRex
+  - **Tabular**: LimiX (classification & regression)
 - **🔒 Type-Safe API**: Full type hints with Pydantic validation
 - **⚡ High Performance**: Optimized Apache Arrow serialization with zero-copy operations
-- **🎯 Probabilistic & Deterministic**: Point forecasts, quantiles, and samples
+- **🎯 Probabilistic & Deterministic**: Point forecasts, quantiles, samples, and probabilistic predictions
 - **🔄 Async Support**: Built-in async/await support for concurrent requests
 - **📊 Rich Error Handling**: Machine-readable error codes with detailed diagnostics
 - **🧪 Battle-Tested**: Production-ready with comprehensive error handling
 - **📈 Evaluation Tools**: Built-in metrics (MSE, MASE, CRPS) and visualization utilities
+- **🔎 Retrieval-Augmented Inference**: Optional RAI for improved accuracy on small datasets

 ## Installation
@@ -67,7 +70,9 @@ print(response.metadata)  # Model version, inference time, etc.

 ### Input Data Format

-**All models require 3D input arrays:**
+#### Time-Series Models (FlowState, Chronos2, TiRex)
+
+**All time-series models require 3D input arrays:**

 ```python
 # Shape: (batch_size, sequence_length, features)
@@ -83,8 +88,25 @@ x = np.array([

 **Important**: 2D input will raise a validation error. Always provide 3D arrays.

+#### Tabular Models (LimiX)
+
+**Tabular models require 2D input arrays:**
+
+```python
+# Shape: (n_samples, n_features)
+X_train = np.array([
+    [1.0, 2.0, 3.0],  # Sample 1
+    [4.0, 5.0, 6.0],  # Sample 2
+])  # Shape: (2, 3)
+```
+
+- **n_samples**: Number of training/test samples
+- **n_features**: Number of input features per sample
+
 ### Output Data Format

+#### Time-Series Output
+
 **Point Forecasts** (3D):
 ```python
 response.point  # Shape: (batch_size, horizon, features)
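The shape contract above (3D for time-series, 2D for tabular) is easy to sanity-check locally before any network call. A minimal NumPy sketch — the `ensure_3d` helper is illustrative only and not part of the SDK, which instead rejects 2D time-series input with a validation error:

```python
import numpy as np

def ensure_3d(x: np.ndarray) -> np.ndarray:
    """Promote 2D (batch, seq_len) input to 3D by adding a singleton feature axis."""
    if x.ndim == 3:
        return x
    if x.ndim == 2:
        return x[:, :, np.newaxis]  # add the missing feature axis
    raise ValueError(f"expected 2D or 3D input, got {x.ndim}D")

x2d = np.random.randn(32, 512).astype(np.float32)
x3d = ensure_3d(x2d)
print(x3d.shape)  # (32, 512, 1)
```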
@@ -96,11 +118,27 @@ response.quantiles  # Shape: (batch_size, horizon, num_quantiles, features)
 # Example: (32, 24, 5, 1) = 32 series, 24 steps ahead, 5 quantiles, 1 feature
 ```

-### Univariate vs Multivariate
+#### Tabular Output
+
+**Predictions** (1D):
+```python
+response.predictions  # Shape: (n_samples,)
+# Classification: class labels or indices
+# Regression: continuous values
+```
+
+**Classification Probabilities** (2D):
+```python
+response.probabilities  # Shape: (n_samples, n_classes) - classification only
+# Probability for each class
+```
+
+### Univariate vs Multivariate (Time-Series Only)

 - **Chronos2**: ✅ Supports multivariate forecasting (multiple features)
 - **FlowState**: ⚠️ Univariate only - automatically transforms multivariate input
 - **TiRex**: ⚠️ Univariate only - automatically transforms multivariate input
+- **LimiX**: ✅ Supports multivariate tabular features (standard in tabular inference)

 When you provide multivariate input (features > 1) to FlowState or TiRex, the SDK automatically:
 1. Issues a warning
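The automatic univariate transform described here can be pictured with plain NumPy reshapes. This is only a sketch of the batching idea, not the SDK's actual implementation:

```python
import numpy as np

batch, seq_len, feats, horizon = 2, 512, 3, 24
x = np.random.randn(batch, seq_len, feats).astype(np.float32)

# One univariate series per (series, feature) pair: (batch * feats, seq_len, 1)
x_uni = x.transpose(0, 2, 1).reshape(batch * feats, seq_len, 1)

# Stand-in for the univariate point forecast: (batch * feats, horizon, 1)
point_uni = np.zeros((batch * feats, horizon, 1), dtype=np.float32)

# Fold the univariate forecasts back into the original multivariate layout
point = point_uni.reshape(batch, feats, horizon).transpose(0, 2, 1)
print(point.shape)  # (2, 24, 3)
```

The final shape matches the "(2, 24, 3) - original structure preserved" behavior the README documents.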
@@ -121,7 +159,19 @@ print(response.point.shape)  # (2, 24, 3) - original structure preserved

 ## Available Models

-### FlowState
+### Model Selection Guide
+
+Choose your client and model based on your task:
+
+| Task | Client | Models | Input | Output |
+|------|--------|--------|-------|--------|
+| **Time-Series Forecasting** | `ForecastClient` | FlowState, Chronos2, TiRex | 3D: `(batch, seq_len, features)` | 3D/4D point/quantiles |
+| **Tabular Classification** | `TabularClient` | LimiX | 2D: `(n_samples, n_features)` | 1D predictions + 2D probabilities |
+| **Tabular Regression** | `TabularClient` | LimiX | 2D: `(n_samples, n_features)` | 1D continuous predictions |
+
+### Time-Series Models
+
+#### FlowState

 ```python
 from faim_sdk import FlowStateForecastRequest
@@ -139,7 +189,7 @@ response = client.forecast(request)
 print(response.point.shape)  # (batch_size, 24, features)
 ```

-### Chronos 2.0
+#### Chronos 2.0

 ```python
 from faim_sdk import Chronos2ForecastRequest
@@ -156,7 +206,7 @@ response = client.forecast(request)
 print(response.quantiles.shape)  # (batch_size, 24, 5)
 ```

-### TiRex
+#### TiRex

 ```python
 from faim_sdk import TiRexForecastRequest
@@ -171,9 +221,117 @@ response = client.forecast(request)
 print(response.point.shape)  # (batch_size, 24, features)
 ```

-## Response Format
+### LimiX
+
+The SDK also supports **LimiX**, a foundation model for tabular classification and regression:
+
+```python
+from faim_sdk import TabularClient, LimiXPredictRequest
+import numpy as np
+
+# Initialize tabular client
+client = TabularClient(api_key="your-api-key")
+
+# Prepare tabular data (2D arrays)
+X_train = np.random.randn(100, 10).astype(np.float32)
+y_train = np.random.randint(0, 2, 100).astype(np.float32)
+X_test = np.random.randn(20, 10).astype(np.float32)
+
+# Create classification request
+request = LimiXPredictRequest(
+    X_train=X_train,
+    y_train=y_train,
+    X_test=X_test,
+    task_type="Classification",  # or "Regression"
+    use_retrieval=False  # Set to True for retrieval-augmented inference
+)
+
+# Generate predictions
+response = client.predict(request)
+print(response.predictions.shape)  # (20,)
+print(response.probabilities.shape)  # (20, n_classes) - classification only
+```
+
+### Classification Example
+
+```python
+from sklearn.datasets import load_breast_cancer
+from sklearn.model_selection import train_test_split
+
+# Load dataset
+X, y = load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
+
+# Convert to float32
+X_train = X_train.astype(np.float32)
+X_test = X_test.astype(np.float32)
+y_train = y_train.astype(np.float32)
+
+# Create and send request
+request = LimiXPredictRequest(
+    X_train=X_train,
+    y_train=y_train,
+    X_test=X_test,
+    task_type="Classification"
+)
+
+response = client.predict(request)
+
+# Evaluate
+from sklearn.metrics import accuracy_score
+accuracy = accuracy_score(y_test, response.predictions.astype(int))
+print(f"Accuracy: {accuracy:.4f}")
+```

-All forecasts return a `ForecastResponse` object with predictions and metadata:
+### Regression Example
+
+```python
+from sklearn.datasets import fetch_california_housing
+
+# Load dataset
+house_data = fetch_california_housing()
+X, y = house_data.data, house_data.target
+
+# Split data (50/50 for demo)
+split_idx = len(X) // 2
+X_train, X_test = X[:split_idx].astype(np.float32), X[split_idx:].astype(np.float32)
+y_train, y_test = y[:split_idx].astype(np.float32), y[split_idx:].astype(np.float32)
+
+# Create and send request
+request = LimiXPredictRequest(
+    X_train=X_train,
+    y_train=y_train,
+    X_test=X_test,
+    task_type="Regression"
+)
+
+response = client.predict(request)
+
+# Evaluate
+from sklearn.metrics import mean_squared_error
+rmse = np.sqrt(mean_squared_error(y_test, response.predictions))
+print(f"RMSE: {rmse:.4f}")
+```
+
+### Retrieval-Augmented Inference
+
+For better accuracy on small datasets, enable retrieval-augmented inference:
+
+```python
+request = LimiXPredictRequest(
+    X_train=X_train,
+    y_train=y_train,
+    X_test=X_test,
+    task_type="Classification",
+    use_retrieval=True  # Enable RAI (slower but more accurate)
+)
+
+response = client.predict(request)
+```
+
+## Response Format (Time-Series Forecasting)
+
+Time-series forecasts return a `ForecastResponse` object with predictions and metadata:

 ```python
 response = client.forecast(request)
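Since the tabular `predictions` and `probabilities` shown in the hunk above follow scikit-learn-style conventions, their relationship can be sanity-checked offline. A mocked sketch (no server needed; the random arrays merely stand in for a real classification response):

```python
import numpy as np

# Mocked classification response: 20 test samples, 2 classes
probabilities = np.random.rand(20, 2).astype(np.float32)
probabilities /= probabilities.sum(axis=1, keepdims=True)  # normalize rows to sum to 1

# The predicted label is the most probable class per row
predictions = probabilities.argmax(axis=1)

print(predictions.shape)  # (20,)
```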
@@ -197,9 +355,11 @@ print(response.metadata)
 # {'model_name': 'chronos2', 'model_version': '1.0', 'inference_time_ms': 123}
 ```

-## Evaluation & Metrics
+## Evaluation & Metrics (Time-Series Forecasting)
+
+The SDK includes a comprehensive evaluation toolkit (`faim_sdk.eval`) for measuring time-series forecast quality with standard metrics and visualizations.

-The SDK includes a comprehensive evaluation toolkit (`faim_sdk.eval`) for measuring forecast quality with standard metrics and visualizations.
+**Note**: These metrics are designed for time-series forecasting evaluation. For tabular model evaluation (classification/regression), use standard scikit-learn metrics like `accuracy_score`, `mean_squared_error`, etc. (see tabular examples above).

 ### Installation

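For the time-series side of that note, MSE on a point forecast is just the mean squared difference over all axes. A plain-NumPy rendition, illustrative only — `faim_sdk.eval` provides the supported implementations:

```python
import numpy as np

# Ground truth and point forecast: (batch, horizon, features)
y_true = np.random.randn(32, 24, 1).astype(np.float32)
y_pred = np.random.randn(32, 24, 1).astype(np.float32)

mse = float(np.mean((y_true - y_pred) ** 2))  # average over every axis
rmse = float(np.sqrt(mse))
print(f"MSE: {mse:.4f}, RMSE: {rmse:.4f}")
```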
@@ -209,7 +369,7 @@ For visualization support, install with the viz extra:
 pip install faim-sdk[viz]
 ```

-### Available Metrics
+### Available Metrics for Time-Series

 #### Mean Squared Error (MSE)

@@ -261,9 +421,9 @@ crps_score = crps_from_quantiles(
 print(f"CRPS: {crps_score:.4f}")
 ```

-### Visualization
+### Visualization (Time-Series Only)

-Plot forecasts with training context and ground truth:
+Plot time-series forecasts with training context and ground truth:

 ```python
 from faim_sdk.eval import plot_forecast
@@ -463,7 +623,21 @@ responses = asyncio.run(forecast_multiple_series())

 See the `examples/` directory for complete Jupyter notebook examples:

-- **`toy_example.ipynb`** - A toy example showing how to get started with FAIM and generate both point and probabilistic forecasts.
+### Time-Series Forecasting
+- **`toy_example.ipynb`** - Get started with FAIM and generate both point and probabilistic forecasts
+- **`airpassengers_dataset.ipynb`** - End-to-end example with AirPassengers dataset
+
+### Tabular Inference with LimiX
+- **`limix_classification_example.ipynb`** - Binary classification on breast cancer dataset
+  - Standard approach with LimiX
+  - Retrieval-Augmented Inference (RAI) comparison
+  - Side-by-side metrics comparison (Accuracy, Precision, Recall, F1-Score)
+
+- **`limix_regression_example.ipynb`** - Regression on California housing dataset
+  - Standard approach with LimiX
+  - Retrieval-Augmented Inference (RAI) comparison
+  - Comprehensive metrics comparison (MSE, RMSE, MAE, R²)
+  - Residual statistics analysis

 ## Requirements
@@ -475,6 +649,8 @@ See the `examples/` directory for complete Jupyter notebook examples:

 ## Performance Tips

+### Time-Series Forecasting
+
 1. **Batch Processing**: Process multiple time series in a single request for optimal throughput
    ```python
    # Good: Single request with 32 series
@@ -488,6 +664,8 @@

 3. **Async for Concurrent Requests**: Use `forecast_async()` with `asyncio.gather()` for parallel processing

+### General (All Models)
+
 4. **Connection Pooling**: Reuse client instances across requests instead of creating new ones

 ## Support
