diff --git a/README.md b/README.md
index 4b6d184..604b971 100644
--- a/README.md
+++ b/README.md
@@ -300,9 +300,11 @@ const App = () => {
### Embedding
-Convert text into numerical vector representations that capture semantic meaning, useful for similarity search and semantic understanding.
+Convert text and images into numerical vector representations that capture semantic meaning, useful for similarity search and semantic understanding.
-#### Class
+#### Text Embedding
+
+##### Class
```typescript
import { CactusLM } from 'cactus-react-native';
@@ -314,7 +316,7 @@ console.log('Embedding vector:', result.embedding);
console.log('Embedding vector length:', result.embedding.length);
```
-#### Hook
+##### Hook
```tsx
import { useCactusLM } from 'cactus-react-native';
@@ -332,6 +334,38 @@ const App = () => {
};
```
+#### Image Embedding
+
+##### Class
+
+```typescript
+import { CactusLM } from 'cactus-react-native';
+
+const cactusLM = new CactusLM({ model: 'lfm2-vl-450m' });
+
+const result = await cactusLM.imageEmbed({ imagePath: 'path/to/your/image.jpg' });
+console.log('Image embedding vector:', result.embedding);
+console.log('Embedding vector length:', result.embedding.length);
+```
+
+##### Hook
+
+```tsx
+import { useCactusLM } from 'cactus-react-native';
+
+const App = () => {
+ const cactusLM = useCactusLM({ model: 'lfm2-vl-450m' });
+
+ const handleImageEmbed = async () => {
+ const result = await cactusLM.imageEmbed({ imagePath: 'path/to/your/image.jpg' });
+ console.log('Image embedding vector:', result.embedding);
+ console.log('Embedding vector length:', result.embedding.length);
+ };
+
+ return ;
+};
+```
+
### Hybrid Mode (Cloud Fallback)
The CactusLM supports a hybrid completion mode that falls back to a cloud-based LLM provider `OpenRouter` if local inference fails.
@@ -383,6 +417,96 @@ const App = () => {
};
```
+## Speech-to-Text (STT)
+
+The `CactusSTT` class provides audio transcription and audio embedding capabilities using Whisper models.
+
+### Transcription
+
+Transcribe audio files to text with streaming support.
+
+#### Class
+
+```typescript
+import { CactusSTT } from 'cactus-react-native';
+
+const cactusSTT = new CactusSTT({ model: 'whisper-small' });
+
+await cactusSTT.init();
+
+const result = await cactusSTT.transcribe({
+ audioFilePath: 'path/to/audio.wav',
+ onToken: (token) => console.log('Token:', token)
+});
+
+console.log('Transcription:', result.response);
+```
+
+#### Hook
+
+```tsx
+import { useCactusSTT } from 'cactus-react-native';
+
+const App = () => {
+ const cactusSTT = useCactusSTT({ model: 'whisper-small' });
+
+ const handleTranscribe = async () => {
+ const result = await cactusSTT.transcribe({
+ audioFilePath: 'path/to/audio.wav',
+ });
+ console.log('Transcription:', result.response);
+ };
+
+ return (
+ <>
+
+ {cactusSTT.response}
+ >
+ );
+};
+```
+
+### Audio Embedding
+
+Generate embeddings from audio files for audio understanding.
+
+#### Class
+
+```typescript
+import { CactusSTT } from 'cactus-react-native';
+
+const cactusSTT = new CactusSTT();
+
+await cactusSTT.init();
+
+const result = await cactusSTT.audioEmbed({
+ audioPath: 'path/to/audio.wav'
+});
+
+console.log('Audio embedding vector:', result.embedding);
+console.log('Embedding vector length:', result.embedding.length);
+```
+
+#### Hook
+
+```tsx
+import { useCactusSTT } from 'cactus-react-native';
+
+const App = () => {
+ const cactusSTT = useCactusSTT();
+
+ const handleAudioEmbed = async () => {
+ const result = await cactusSTT.audioEmbed({
+ audioPath: 'path/to/audio.wav'
+ });
+ console.log('Audio embedding vector:', result.embedding);
+ console.log('Embedding vector length:', result.embedding.length);
+ };
+
+ return ;
+};
+```
+
## API Reference
### CactusLM Class
@@ -400,14 +524,14 @@ const App = () => {
**`download(params?: CactusLMDownloadParams): Promise`**
-Downloads the model. If the model is already downloaded, returns immediately with progress at 100%. Throws an error if a download is already in progress. Automatically refreshes the models list after successful download.
+Downloads the model. If the model is already downloaded, returns immediately with progress `1`. Throws an error if a download is already in progress.
**Parameters:**
- `onProgress` - Callback for download progress (0-1).
**`init(): Promise`**
-Initializes the model and prepares it for inference. Safe to call multiple times (idempotent). Throws an error if the model is not downloaded yet. Automatically initializes telemetry if not already done.
+Initializes the model and prepares it for inference. Safe to call multiple times (idempotent). Throws an error if the model is not downloaded yet.
**`complete(params: CactusLMCompleteParams): Promise`**
@@ -423,7 +547,7 @@ Performs text completion with optional streaming and tool support. Automatically
- `stopSequences` - Array of strings to stop generation (default: `undefined`).
- `tools` - Array of `Tool` objects for function calling (default: `undefined`).
- `onToken` - Callback for streaming tokens.
-- `mode` - Completion mode (default: `local`)
+- `mode` - Completion mode: `'local'` | `'hybrid'` (default: `'local'`)
**`embed(params: CactusLMEmbedParams): Promise`**
@@ -432,6 +556,13 @@ Generates embeddings for the given text. Automatically calls `init()` if not alr
**Parameters:**
- `text` - Text to embed.
+**`imageEmbed(params: CactusLMImageEmbedParams): Promise`**
+
+Generates embeddings for the given image. Requires a vision-capable model. Automatically calls `init()` if not already initialized. Throws an error if a generation (completion or embedding) is already in progress.
+
+**Parameters:**
+- `imagePath` - Path to the image file.
+
**`stop(): Promise`**
Stops ongoing generation.
@@ -444,12 +575,9 @@ Resets the model's internal state, clearing any cached context. Automatically ca
Releases all resources associated with the model. Automatically calls `stop()` first. Safe to call even if the model is not initialized.
-**`getModels(params?: CactusLMGetModelsParams): Promise`**
-
-Fetches available models and persists the results locally for caching. Returns cached results if available, unless `forceRefresh` is `true`. Checks the download status for each model and includes it in the results.
+**`getModels(): Promise`**
-**Parameters:**
-- `forceRefresh` - If `true`, fetches from the server and updates the local cache (default: `false`).
+Fetches available models from the database and checks their download status. Results are cached in memory after the first call and subsequent calls return the cached results.
### useCactusLM Hook
@@ -471,10 +599,97 @@ The `useCactusLM` hook manages a `CactusLM` instance with reactive state. When m
- `init(): Promise` - Initializes the model for inference. Sets `isInitializing` to `true` during initialization.
- `complete(params: CactusLMCompleteParams): Promise` - Generates text completions. Automatically accumulates tokens in the `completion` state during streaming. Sets `isGenerating` to `true` while generating. Clears `completion` before starting.
- `embed(params: CactusLMEmbedParams): Promise` - Generates embeddings for the given text. Sets `isGenerating` to `true` during operation.
+- `imageEmbed(params: CactusLMImageEmbedParams): Promise` - Generates embeddings for the given image. Sets `isGenerating` to `true` while generating.
- `stop(): Promise` - Stops ongoing generation. Clears any errors.
- `reset(): Promise` - Resets the model's internal state, clearing cached context. Also clears the `completion` state.
- `destroy(): Promise` - Releases all resources associated with the model. Clears the `completion` state. Automatically called when the component unmounts.
-- `getModels(params?: CactusLMGetModelsParams): Promise` - Fetches available models and returns them. Results are cached locally.
+- `getModels(): Promise` - Fetches available models from the database and checks their download status. Results are cached in memory and reused on subsequent calls.
+
+### CactusSTT Class
+
+#### Constructor
+
+**`new CactusSTT(params?: CactusSTTParams)`**
+
+**Parameters:**
+- `model` - Model slug (default: `'whisper-small'`).
+- `contextSize` - Context window size (default: `2048`).
+
+#### Methods
+
+**`download(params?: CactusSTTDownloadParams): Promise`**
+
+Downloads the model. If the model is already downloaded, returns immediately with progress `1`. Throws an error if a download is already in progress.
+
+**Parameters:**
+- `onProgress` - Callback for download progress (0-1).
+
+**`init(): Promise`**
+
+Initializes the model and prepares it for inference. Safe to call multiple times (idempotent). Throws an error if the model is not downloaded yet.
+
+**`transcribe(params: CactusSTTTranscribeParams): Promise`**
+
+Transcribes audio to text with optional streaming support. Automatically calls `init()` if not already initialized. Throws an error if a generation is already in progress.
+
+**Parameters:**
+- `audioFilePath` - Path to the audio file.
+- `prompt` - Optional prompt to guide transcription (default: `'<|startoftranscript|><|en|><|transcribe|><|notimestamps|>'`).
+- `options` - Transcription options:
+ - `temperature` - Sampling temperature (default: model-optimized).
+ - `topP` - Nucleus sampling threshold (default: model-optimized).
+ - `topK` - Top-K sampling limit (default: model-optimized).
+ - `maxTokens` - Maximum number of tokens to generate (default: `512`).
+ - `stopSequences` - Array of strings to stop generation (default: `undefined`).
+- `onToken` - Callback for streaming tokens.
+
+**`audioEmbed(params: CactusSTTAudioEmbedParams): Promise`**
+
+Generates embeddings for the given audio file. Automatically calls `init()` if not already initialized. Throws an error if a generation is already in progress.
+
+**Parameters:**
+- `audioPath` - Path to the audio file.
+
+**`stop(): Promise`**
+
+Stops ongoing transcription or embedding generation.
+
+**`reset(): Promise`**
+
+Resets the model's internal state. Automatically calls `stop()` first.
+
+**`destroy(): Promise`**
+
+Releases all resources associated with the model. Automatically calls `stop()` first. Safe to call even if the model is not initialized.
+
+**`getModels(): Promise`**
+
+Fetches available models from the database and checks their download status. Results are cached in memory after the first call and subsequent calls return the cached results.
+
+### useCactusSTT Hook
+
+The `useCactusSTT` hook manages a `CactusSTT` instance with reactive state. When model parameters (`model`, `contextSize`) change, the hook creates a new instance and resets all state. The hook automatically cleans up resources when the component unmounts.
+
+#### State
+
+- `response: string` - Current transcription text. Automatically accumulated during streaming. Cleared before each new transcription and when calling `reset()` or `destroy()`.
+- `isGenerating: boolean` - Whether the model is currently generating (transcription or embedding). Both operations share this flag.
+- `isInitializing: boolean` - Whether the model is initializing.
+- `isDownloaded: boolean` - Whether the model is downloaded locally. Automatically checked when the hook mounts or model changes.
+- `isDownloading: boolean` - Whether the model is being downloaded.
+- `downloadProgress: number` - Download progress (0-1). Reset to `0` after download completes.
+- `error: string | null` - Last error message from any operation, or `null` if there is no error. Cleared before starting new operations.
+
+#### Methods
+
+- `download(params?: CactusSTTDownloadParams): Promise` - Downloads the model. Updates `isDownloading` and `downloadProgress` state during download. Sets `isDownloaded` to `true` on success.
+- `init(): Promise` - Initializes the model for inference. Sets `isInitializing` to `true` during initialization.
+- `transcribe(params: CactusSTTTranscribeParams): Promise` - Transcribes audio to text. Automatically accumulates tokens in the `response` state during streaming. Sets `isGenerating` to `true` while generating. Clears `response` before starting.
+- `audioEmbed(params: CactusSTTAudioEmbedParams): Promise` - Generates embeddings for the given audio. Sets `isGenerating` to `true` during operation.
+- `stop(): Promise` - Stops ongoing generation. Clears any errors.
+- `reset(): Promise` - Resets the model's internal state. Also clears the `response` state.
+- `destroy(): Promise` - Releases all resources associated with the model. Clears the `response` state. Automatically called when the component unmounts.
+- `getModels(): Promise` - Fetches available models from the database and checks their download status. Results are cached in memory and reused on subsequent calls.
## Type Definitions
@@ -506,10 +721,10 @@ interface Message {
}
```
-### Options
+### CompleteOptions
```typescript
-interface Options {
+interface CompleteOptions {
temperature?: number;
topP?: number;
topK?: number;
@@ -542,7 +757,7 @@ interface Tool {
```typescript
interface CactusLMCompleteParams {
messages: Message[];
- options?: Options;
+ options?: CompleteOptions;
tools?: Tool[];
onToken?: (token: string) => void;
mode?: 'local' | 'hybrid';
@@ -584,11 +799,19 @@ interface CactusLMEmbedResult {
}
```
-### CactusLMGetModelsParams
+### CactusLMImageEmbedParams
```typescript
-interface CactusLMGetModelsParams {
- forceRefresh?: boolean;
+interface CactusLMImageEmbedParams {
+ imagePath: string;
+}
+```
+
+### CactusLMImageEmbedResult
+
+```typescript
+interface CactusLMImageEmbedResult {
+ embedding: number[];
}
```
@@ -608,6 +831,79 @@ interface CactusModel {
}
```
+### CactusSTTParams
+
+```typescript
+interface CactusSTTParams {
+ model?: string;
+ contextSize?: number;
+}
+```
+
+### CactusSTTDownloadParams
+
+```typescript
+interface CactusSTTDownloadParams {
+ onProgress?: (progress: number) => void;
+}
+
+```
+
+### TranscribeOptions
+
+```ts
+interface TranscribeOptions {
+ temperature?: number;
+ topP?: number;
+ topK?: number;
+ maxTokens?: number;
+ stopSequences?: string[];
+}
+```
+
+### CactusSTTTranscribeParams
+
+```typescript
+interface CactusSTTTranscribeParams {
+ audioFilePath: string;
+ prompt?: string;
+ options?: TranscribeOptions;
+ onToken?: (token: string) => void;
+}
+```
+
+### CactusSTTTranscribeResult
+
+```typescript
+interface CactusSTTTranscribeResult {
+ success: boolean;
+ response: string;
+ timeToFirstTokenMs: number;
+ totalTimeMs: number;
+ tokensPerSecond: number;
+ prefillTokens: number;
+ decodeTokens: number;
+ totalTokens: number;
+}
+
+```
+
+### CactusSTTAudioEmbedParams
+
+```typescript
+interface CactusSTTAudioEmbedParams {
+ audioPath: string;
+}
+```
+
+### CactusSTTAudioEmbedResult
+
+```typescript
+interface CactusSTTAudioEmbedResult {
+ embedding: number[];
+}
+```
+
## Configuration
### Telemetry
diff --git a/android/src/main/jniLibs/arm64-v8a/libcactus.a b/android/src/main/jniLibs/arm64-v8a/libcactus.a
index 359ab4c..7fb0617 100644
Binary files a/android/src/main/jniLibs/arm64-v8a/libcactus.a and b/android/src/main/jniLibs/arm64-v8a/libcactus.a differ
diff --git a/cpp/HybridCactus.cpp b/cpp/HybridCactus.cpp
index ae59ea0..ea9d913 100644
--- a/cpp/HybridCactus.cpp
+++ b/cpp/HybridCactus.cpp
@@ -75,6 +75,53 @@ std::shared_ptr> HybridCactus::complete(
});
}
+std::shared_ptr> HybridCactus::transcribe(
+ const std::string &audioFilePath, const std::string &prompt,
+ double responseBufferSize, const std::optional &optionsJson,
+ const std::optional> &callback) {
+ return Promise::async([this, audioFilePath, prompt, optionsJson,
+ callback,
+ responseBufferSize]() -> std::string {
+ std::lock_guard lock(this->_modelMutex);
+
+ if (!this->_model) {
+ throw std::runtime_error("Cactus model is not initialized");
+ }
+
+ struct CallbackCtx {
+ const std::function *callback;
+ } callbackCtx{callback.has_value() ? &callback.value() : nullptr};
+
+ auto cactusTokenCallback = [](const char *token, uint32_t tokenId,
+ void *userData) {
+ auto *callbackCtx = static_cast(userData);
+ if (!callbackCtx || !callbackCtx->callback || !(*callbackCtx->callback))
+ return;
+ (*callbackCtx->callback)(token, tokenId);
+ };
+
+ std::string responseBuffer;
+ responseBuffer.resize(responseBufferSize);
+
+ int result =
+ cactus_transcribe(this->_model, audioFilePath.c_str(), prompt.c_str(),
+ responseBuffer.data(), responseBufferSize,
+ optionsJson ? optionsJson->c_str() : nullptr,
+ cactusTokenCallback, &callbackCtx);
+
+ if (result < 0) {
+ throw std::runtime_error("Cactus transcription failed");
+ }
+
+ // Remove null terminator
+ responseBuffer.resize(strlen(responseBuffer.c_str()));
+
+ return responseBuffer;
+ });
+}
+
std::shared_ptr>>
HybridCactus::embed(const std::string &text, double embeddingBufferSize) {
return Promise>::async(
@@ -103,6 +150,64 @@ HybridCactus::embed(const std::string &text, double embeddingBufferSize) {
});
}
+std::shared_ptr>>
+HybridCactus::imageEmbed(const std::string &imagePath,
+ double embeddingBufferSize) {
+ return Promise>::async(
+ [this, imagePath, embeddingBufferSize]() -> std::vector {
+ std::lock_guard lock(this->_modelMutex);
+
+ if (!this->_model) {
+ throw std::runtime_error("Cactus model is not initialized");
+ }
+
+ std::vector embeddingBuffer(embeddingBufferSize);
+ size_t embeddingDim;
+
+ int result = cactus_image_embed(
+ this->_model, imagePath.c_str(), embeddingBuffer.data(),
+ embeddingBufferSize * sizeof(float), &embeddingDim);
+
+ if (result < 0) {
+ throw std::runtime_error("Cactus image embedding failed");
+ }
+
+ embeddingBuffer.resize(embeddingDim);
+
+ return std::vector(embeddingBuffer.begin(),
+ embeddingBuffer.end());
+ });
+}
+
+std::shared_ptr>>
+HybridCactus::audioEmbed(const std::string &audioPath,
+ double embeddingBufferSize) {
+ return Promise>::async(
+ [this, audioPath, embeddingBufferSize]() -> std::vector {
+ std::lock_guard lock(this->_modelMutex);
+
+ if (!this->_model) {
+ throw std::runtime_error("Cactus model is not initialized");
+ }
+
+ std::vector embeddingBuffer(embeddingBufferSize);
+ size_t embeddingDim;
+
+ int result = cactus_audio_embed(
+ this->_model, audioPath.c_str(), embeddingBuffer.data(),
+ embeddingBufferSize * sizeof(float), &embeddingDim);
+
+ if (result < 0) {
+ throw std::runtime_error("Cactus audio embedding failed");
+ }
+
+ embeddingBuffer.resize(embeddingDim);
+
+ return std::vector(embeddingBuffer.begin(),
+ embeddingBuffer.end());
+ });
+}
+
std::shared_ptr> HybridCactus::reset() {
return Promise::async([this]() -> void {
std::lock_guard lock(this->_modelMutex);
diff --git a/cpp/HybridCactus.hpp b/cpp/HybridCactus.hpp
index d866561..575294d 100644
--- a/cpp/HybridCactus.hpp
+++ b/cpp/HybridCactus.hpp
@@ -23,9 +23,22 @@ class HybridCactus : public HybridCactusSpec {
double /* tokenId */)>> &callback)
override;
+ std::shared_ptr> transcribe(
+ const std::string &audioFilePath, const std::string &prompt,
+ double responseBufferSize, const std::optional &optionsJson,
+ const std::optional> &callback)
+ override;
+
std::shared_ptr>>
embed(const std::string &text, double embeddingBufferSize) override;
+ std::shared_ptr>>
+ imageEmbed(const std::string &imagePath, double embeddingBufferSize) override;
+
+ std::shared_ptr>>
+ audioEmbed(const std::string &audioPath, double embeddingBufferSize) override;
+
std::shared_ptr> reset() override;
std::shared_ptr> stop() override;
diff --git a/cpp/cactus_ffi.h b/cpp/cactus_ffi.h
index faac947..6bb3f27 100644
--- a/cpp/cactus_ffi.h
+++ b/cpp/cactus_ffi.h
@@ -33,6 +33,17 @@ CACTUS_FFI_EXPORT int cactus_complete(
void* user_data
);
+CACTUS_FFI_EXPORT int cactus_transcribe(
+ cactus_model_t model,
+ const char* audio_file_path,
+ const char* prompt,
+ char* response_buffer,
+ size_t buffer_size,
+ const char* options_json,
+ cactus_token_callback callback,
+ void* user_data
+);
+
CACTUS_FFI_EXPORT int cactus_embed(
cactus_model_t model,
@@ -42,6 +53,22 @@ CACTUS_FFI_EXPORT int cactus_embed(
size_t* embedding_dim
);
+CACTUS_FFI_EXPORT int cactus_image_embed(
+ cactus_model_t model,
+ const char* image_path,
+ float* embeddings_buffer,
+ size_t buffer_size,
+ size_t* embedding_dim
+);
+
+CACTUS_FFI_EXPORT int cactus_audio_embed(
+ cactus_model_t model,
+ const char* audio_path,
+ float* embeddings_buffer,
+ size_t buffer_size,
+ size_t* embedding_dim
+);
+
CACTUS_FFI_EXPORT void cactus_reset(cactus_model_t model);
CACTUS_FFI_EXPORT void cactus_stop(cactus_model_t model);
diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock
index 1912daa..fcb4602 100644
--- a/example/ios/Podfile.lock
+++ b/example/ios/Podfile.lock
@@ -1,6 +1,6 @@
PODS:
- boost (1.84.0)
- - Cactus (1.0.2):
+ - Cactus (1.1.0):
- boost
- DoubleConversion
- fast_float
@@ -1808,6 +1808,34 @@ PODS:
- React-RCTFBReactNativeSpec
- ReactCommon/turbomodule/core
- SocketRocket
+ - react-native-document-picker (11.0.0):
+ - boost
+ - DoubleConversion
+ - fast_float
+ - fmt
+ - glog
+ - hermes-engine
+ - RCT-Folly
+ - RCT-Folly/Fabric
+ - RCTRequired
+ - RCTTypeSafety
+ - React-Core
+ - React-debug
+ - React-Fabric
+ - React-featureflags
+ - React-graphics
+ - React-ImageManager
+ - React-jsi
+ - React-NativeModulesApple
+ - React-RCTFabric
+ - React-renderercss
+ - React-rendererdebug
+ - React-utils
+ - ReactCodegen
+ - ReactCommon/turbomodule/bridging
+ - ReactCommon/turbomodule/core
+ - SocketRocket
+ - Yoga
- react-native-image-picker (8.2.1):
- boost
- DoubleConversion
@@ -2417,6 +2445,7 @@ DEPENDENCIES:
- React-logger (from `../node_modules/react-native/ReactCommon/logger`)
- React-Mapbuffer (from `../node_modules/react-native/ReactCommon`)
- React-microtasksnativemodule (from `../node_modules/react-native/ReactCommon/react/nativemodule/microtasks`)
+ - "react-native-document-picker (from `../node_modules/@react-native-documents/picker`)"
- react-native-image-picker (from `../node_modules/react-native-image-picker`)
- React-NativeModulesApple (from `../node_modules/react-native/ReactCommon/react/nativemodule/core/platform/ios`)
- React-oscompat (from `../node_modules/react-native/ReactCommon/oscompat`)
@@ -2543,6 +2572,8 @@ EXTERNAL SOURCES:
:path: "../node_modules/react-native/ReactCommon"
React-microtasksnativemodule:
:path: "../node_modules/react-native/ReactCommon/react/nativemodule/microtasks"
+ react-native-document-picker:
+ :path: "../node_modules/@react-native-documents/picker"
react-native-image-picker:
:path: "../node_modules/react-native-image-picker"
React-NativeModulesApple:
@@ -2612,7 +2643,7 @@ EXTERNAL SOURCES:
SPEC CHECKSUMS:
boost: 7e761d76ca2ce687f7cc98e698152abd03a18f90
- Cactus: 45d4b8148a963a719617a71b06add8bb2ef8721c
+ Cactus: 2949301f1229677c0bbeba6856d3c78e5798aace
DoubleConversion: cb417026b2400c8f53ae97020b2be961b59470cb
fast_float: b32c788ed9c6a8c584d114d0047beda9664e7cc6
FBLazyVector: b8f1312d48447cca7b4abc21ed155db14742bd03
@@ -2653,6 +2684,7 @@ SPEC CHECKSUMS:
React-logger: d27dd2000f520bf891d24f6e141cde34df41f0ee
React-Mapbuffer: 0746ffab5ac0f49b7c9347338e3d0c1d9dd634c8
React-microtasksnativemodule: b0fb3f97372df39bda3e657536039f1af227cc29
+ react-native-document-picker: 63639c144fbdc4bf7b12d31a3827ae2bfbaf7ad4
react-native-image-picker: 43e6cd4231e670030fe09b079d696fa5a634ccfc
React-NativeModulesApple: 9ec9240159974c94886ebbe4caec18e3395f6aef
React-oscompat: b12c633e9c00f1f99467b1e0e0b8038895dae436
diff --git a/example/package.json b/example/package.json
index 7813b11..88d3d3f 100644
--- a/example/package.json
+++ b/example/package.json
@@ -11,6 +11,7 @@
},
"dependencies": {
"@dr.pogodin/react-native-fs": "^2.36.1",
+ "@react-native-documents/picker": "^11.0.0",
"react": "19.1.0",
"react-native": "0.81.1",
"react-native-image-picker": "^8.2.1",
diff --git a/example/src/App.tsx b/example/src/App.tsx
index 1a4da35..63e6ea2 100644
--- a/example/src/App.tsx
+++ b/example/src/App.tsx
@@ -11,7 +11,7 @@ import CompletionScreen from './CompletionScreen';
import VisionScreen from './VisionScreen';
import ToolCallingScreen from './ToolCallingScreen';
import RAGScreen from './RAGScreen';
-import EmbeddingScreen from './EmbeddingScreen';
+import STTScreen from './STTScreen';
import ChatScreen from './ChatScreen';
import PerformanceScreen from './PerformanceScreen';
@@ -21,7 +21,7 @@ type Screen =
| 'Vision'
| 'ToolCalling'
| 'RAG'
- | 'Embedding'
+ | 'STT'
| 'Chat'
| 'Performance';
@@ -48,8 +48,8 @@ const App = () => {
setSelectedScreen('RAG');
};
- const handleGoToEmbedding = () => {
- setSelectedScreen('Embedding');
+ const handleGoToSTT = () => {
+ setSelectedScreen('STT');
};
const handleGoToChat = () => {
@@ -70,8 +70,8 @@ const App = () => {
return ;
case 'RAG':
return ;
- case 'Embedding':
- return ;
+ case 'STT':
+ return ;
case 'Chat':
return ;
case 'Performance':
@@ -106,7 +106,7 @@ const App = () => {
>
Completion
- Generate text with streaming
+ Text generation and embeddings
@@ -115,7 +115,9 @@ const App = () => {
onPress={handleGoToVision}
>
Vision
- Analyze images
+
+ Image analysis and embeddings
+
{
-
- Embedding
- Text to vectors
+
+ Speech-to-Text
+
+ Audio transcription and embeddings
+
diff --git a/example/src/CompletionScreen.tsx b/example/src/CompletionScreen.tsx
index 817352d..a3a4656 100644
--- a/example/src/CompletionScreen.tsx
+++ b/example/src/CompletionScreen.tsx
@@ -12,12 +12,16 @@ import {
useCactusLM,
type Message,
type CactusLMCompleteResult,
+ type CactusLMEmbedResult,
} from 'cactus-react-native';
const CompletionScreen = () => {
const cactusLM = useCactusLM({ model: 'qwen3-0.6' });
const [input, setInput] = useState('What is the capital of France?');
const [result, setResult] = useState(null);
+ const [embedResult, setEmbedResult] = useState(
+ null
+ );
useEffect(() => {
if (!cactusLM.isDownloaded) {
@@ -39,6 +43,10 @@ const CompletionScreen = () => {
setResult(completionResult);
};
+ const handleEmbed = async () => {
+ setEmbedResult(await cactusLM.embed({ text: input }));
+ };
+
const handleStop = () => {
cactusLM.stop();
};
@@ -87,6 +95,14 @@ const CompletionScreen = () => {
+
+ Embed
+
+
Stop
@@ -171,6 +187,26 @@ const CompletionScreen = () => {
)}
+ {embedResult && (
+
+ CactusLMEmbedResult:
+
+ embedding:
+
+
+ [
+ {embedResult.embedding
+ .slice(0, 20)
+ .map((v) => v.toFixed(4))
+ .join(', ')}
+ {embedResult.embedding.length > 20 ? ', ...' : ''}] (length:{' '}
+ {embedResult.embedding.length})
+
+
+
+
+ )}
+
{cactusLM.error && (
{cactusLM.error}
diff --git a/example/src/EmbeddingScreen.tsx b/example/src/EmbeddingScreen.tsx
deleted file mode 100644
index 2aead00..0000000
--- a/example/src/EmbeddingScreen.tsx
+++ /dev/null
@@ -1,195 +0,0 @@
-import { useEffect, useState } from 'react';
-import {
- View,
- Text,
- TextInput,
- TouchableOpacity,
- ScrollView,
- StyleSheet,
- ActivityIndicator,
-} from 'react-native';
-import { useCactusLM, type CactusLMEmbedResult } from 'cactus-react-native';
-
-const EmbeddingScreen = () => {
- const cactusLM = useCactusLM({ model: 'qwen3-0.6' });
- const [text, setText] = useState('Hello, World!');
- const [result, setResult] = useState(null);
-
- useEffect(() => {
- if (!cactusLM.isDownloaded) {
- cactusLM.download();
- }
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [cactusLM.isDownloaded]);
-
- const handleInit = () => {
- cactusLM.init();
- };
-
- const handleEmbed = async () => {
- const embedResult = await cactusLM.embed({ text });
- setResult(embedResult);
- };
-
- const handleDestroy = () => {
- cactusLM.destroy();
- };
-
- if (cactusLM.isDownloading) {
- return (
-
-
-
- Downloading model: {Math.round(cactusLM.downloadProgress * 100)}%
-
-
- );
- }
-
- return (
-
-
-
-
-
- Init
-
-
-
-
- {cactusLM.isGenerating ? 'Embedding...' : 'Embed'}
-
-
-
-
- Destroy
-
-
-
- {result && (
-
- CactusLMEmbedResult:
-
- embedding:
-
-
- [
- {result.embedding
- .slice(0, 20)
- .map((v) => v.toFixed(4))
- .join(', ')}
- {result.embedding.length > 20 ? ', ...' : ''}] (length:{' '}
- {result.embedding.length})
-
-
-
-
- )}
-
- {cactusLM.error && (
-
- {cactusLM.error}
-
- )}
-
- );
-};
-
-export default EmbeddingScreen;
-
-const styles = StyleSheet.create({
- container: {
- flex: 1,
- backgroundColor: '#fff',
- },
- content: {
- padding: 20,
- },
- centerContainer: {
- flex: 1,
- justifyContent: 'center',
- alignItems: 'center',
- padding: 20,
- },
- progressText: {
- marginTop: 16,
- fontSize: 16,
- color: '#000',
- },
- input: {
- borderWidth: 1,
- borderColor: '#ddd',
- borderRadius: 8,
- padding: 12,
- fontSize: 16,
- textAlignVertical: 'top',
- marginBottom: 16,
- color: '#000',
- },
- buttonContainer: {
- flexDirection: 'row',
- flexWrap: 'wrap',
- gap: 8,
- marginBottom: 16,
- },
- button: {
- backgroundColor: '#000',
- paddingVertical: 12,
- paddingHorizontal: 16,
- borderRadius: 8,
- alignItems: 'center',
- },
- buttonText: {
- color: '#fff',
- fontSize: 16,
- fontWeight: '600',
- },
- resultContainer: {
- marginTop: 16,
- },
- resultLabel: {
- fontSize: 16,
- fontWeight: '600',
- marginBottom: 8,
- color: '#000',
- },
- resultBox: {
- backgroundColor: '#f3f3f3',
- padding: 12,
- borderRadius: 8,
- },
- resultFieldLabel: {
- fontSize: 12,
- fontWeight: '600',
- color: '#666',
- marginBottom: 4,
- },
- resultFieldValue: {
- fontSize: 14,
- color: '#000',
- lineHeight: 20,
- },
- marginTop: {
- marginTop: 12,
- },
- errorContainer: {
- backgroundColor: '#000',
- padding: 12,
- borderRadius: 8,
- marginTop: 16,
- },
- errorText: {
- color: '#fff',
- fontSize: 14,
- },
-});
diff --git a/example/src/STTScreen.tsx b/example/src/STTScreen.tsx
new file mode 100644
index 0000000..f58fbbe
--- /dev/null
+++ b/example/src/STTScreen.tsx
@@ -0,0 +1,350 @@
+import { useEffect, useState } from 'react';
+import {
+ View,
+ Text,
+ TouchableOpacity,
+ ScrollView,
+ StyleSheet,
+ ActivityIndicator,
+} from 'react-native';
+import {
+ useCactusSTT,
+ type CactusSTTTranscribeResult,
+ type CactusSTTAudioEmbedResult,
+} from 'cactus-react-native';
+import * as DocumentPicker from '@react-native-documents/picker';
+import * as RNFS from '@dr.pogodin/react-native-fs';
+
+const STTScreen = () => {
+ const cactusSTT = useCactusSTT({ model: 'whisper-small' });
+ const [audioFile, setAudioFile] = useState(null);
+ const [audioFileName, setAudioFileName] = useState('');
+ const [result, setResult] = useState(null);
+ const [embeddingResult, setEmbeddingResult] =
+ useState(null);
+
+ useEffect(() => {
+ if (!cactusSTT.isDownloaded) {
+ cactusSTT.download();
+ }
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [cactusSTT.isDownloaded]);
+
+ const handleInit = () => {
+ cactusSTT.init();
+ };
+
+ const handleSelectAudio = async () => {
+ try {
+ const res = await DocumentPicker.pick({
+ type: [DocumentPicker.types.audio],
+ });
+ if (res && res.length > 0) {
+ const fileName = `audio_${Date.now()}.wav`;
+ const destPath = `${RNFS.CachesDirectoryPath}/${fileName}`;
+ await RNFS.copyFile(res[0].uri, destPath);
+ setAudioFile(destPath);
+ setAudioFileName(res[0].name || 'Unknown');
+ }
+ } catch (err) {
+ console.error(err);
+ }
+ };
+
+ const handleTranscribe = async () => {
+ if (!audioFile) {
+ return;
+ }
+ const transcribeResult = await cactusSTT.transcribe({
+ audioFilePath: audioFile,
+ });
+ setResult(transcribeResult);
+ };
+
+ const handleAudioEmbed = async () => {
+ if (!audioFile) {
+ return;
+ }
+ const embedResult = await cactusSTT.audioEmbed({
+ audioPath: audioFile,
+ });
+ setEmbeddingResult(embedResult);
+ };
+
+ const handleStop = () => {
+ cactusSTT.stop();
+ };
+
+ const handleReset = () => {
+ cactusSTT.reset();
+ };
+
+ const handleDestroy = () => {
+ cactusSTT.destroy();
+ };
+
+ if (cactusSTT.isDownloading) {
+ return (
+
+
+
+ Downloading model: {Math.round(cactusSTT.downloadProgress * 100)}%
+
+
+ );
+ }
+
+ return (
+
+
+
+ {audioFile ? `Selected: ${audioFileName}` : 'Select Audio File'}
+
+
+
+
+
+ Init
+
+
+
+ Transcribe
+
+
+
+ Audio Embed
+
+
+
+ Stop
+
+
+
+ Reset
+
+
+
+ Destroy
+
+
+
+ {cactusSTT.response && (
+
+ Streaming:
+
+ {cactusSTT.response}
+
+
+ )}
+
+ {result && (
+
+ CactusSTTTranscribeResult:
+
+ success:
+
+ {result.success.toString()}
+
+
+
+ response:
+
+ {result.response}
+
+
+ timeToFirstTokenMs:
+
+
+ {result.timeToFirstTokenMs.toFixed(2)}
+
+
+
+ totalTimeMs:
+
+
+ {result.totalTimeMs.toFixed(2)}
+
+
+
+ tokensPerSecond:
+
+
+ {result.tokensPerSecond.toFixed(2)}
+
+
+
+ prefillTokens:
+
+ {result.prefillTokens}
+
+
+ decodeTokens:
+
+ {result.decodeTokens}
+
+
+ totalTokens:
+
+ {result.totalTokens}
+
+
+ )}
+
+ {embeddingResult && (
+
+ CactusSTTAudioEmbedResult:
+
+ embedding:
+
+
+ [{embeddingResult.embedding.slice(0, 10).join(', ')}
+ {embeddingResult.embedding.length > 10 ? ', ...' : ''}]
+
+
+
+
+ dimensions:
+
+
+ {embeddingResult.embedding.length}
+
+
+
+ )}
+
+ {cactusSTT.error && (
+
+ {cactusSTT.error}
+
+ )}
+
+ );
+};
+
+export default STTScreen;
+
+const styles = StyleSheet.create({
+ container: {
+ flex: 1,
+ backgroundColor: '#fff',
+ },
+ content: {
+ padding: 20,
+ },
+ centerContainer: {
+ flex: 1,
+ justifyContent: 'center',
+ alignItems: 'center',
+ padding: 20,
+ },
+ progressText: {
+ marginTop: 16,
+ fontSize: 16,
+ color: '#000',
+ },
+ selectButton: {
+ backgroundColor: '#f3f3f3',
+ padding: 16,
+ borderRadius: 8,
+ marginBottom: 16,
+ },
+ selectButtonText: {
+ fontSize: 16,
+ color: '#000',
+ textAlign: 'center',
+ },
+ buttonContainer: {
+ flexDirection: 'row',
+ flexWrap: 'wrap',
+ gap: 8,
+ marginBottom: 16,
+ },
+ button: {
+ backgroundColor: '#000',
+ paddingVertical: 12,
+ paddingHorizontal: 16,
+ borderRadius: 8,
+ alignItems: 'center',
+ },
+ buttonDisabled: {
+ backgroundColor: '#ccc',
+ },
+ buttonText: {
+ color: '#fff',
+ fontSize: 16,
+ fontWeight: '600',
+ },
+ responseContainer: {
+ marginTop: 16,
+ },
+ responseLabel: {
+ fontSize: 16,
+ fontWeight: '600',
+ marginBottom: 8,
+ color: '#000',
+ },
+ responseBox: {
+ backgroundColor: '#f3f3f3',
+ padding: 12,
+ borderRadius: 8,
+ minHeight: 100,
+ },
+ responseText: {
+ fontSize: 14,
+ color: '#000',
+ lineHeight: 20,
+ },
+ resultContainer: {
+ marginTop: 16,
+ },
+ resultLabel: {
+ fontSize: 16,
+ fontWeight: '600',
+ marginBottom: 8,
+ color: '#000',
+ },
+ resultBox: {
+ backgroundColor: '#f3f3f3',
+ padding: 12,
+ borderRadius: 8,
+ },
+ resultFieldLabel: {
+ fontSize: 12,
+ fontWeight: '600',
+ color: '#666',
+ marginBottom: 4,
+ },
+ resultFieldValue: {
+ fontSize: 14,
+ color: '#000',
+ lineHeight: 20,
+ },
+ embeddingScrollView: {
+ maxHeight: 60,
+ },
+ marginTop: {
+ marginTop: 12,
+ },
+ errorContainer: {
+ backgroundColor: '#000',
+ padding: 12,
+ borderRadius: 8,
+ marginTop: 16,
+ },
+ errorText: {
+ color: '#fff',
+ fontSize: 14,
+ },
+});
diff --git a/example/src/VisionScreen.tsx b/example/src/VisionScreen.tsx
index 9e9acb8..27fbff1 100644
--- a/example/src/VisionScreen.tsx
+++ b/example/src/VisionScreen.tsx
@@ -13,6 +13,7 @@ import {
useCactusLM,
type Message,
type CactusLMCompleteResult,
+ type CactusLMEmbedResult,
} from 'cactus-react-native';
import { launchImageLibrary } from 'react-native-image-picker';
@@ -21,6 +22,10 @@ const VisionScreen = () => {
const [input, setInput] = useState("What's in the image?");
const [selectedImage, setSelectedImage] = useState(null);
const [result, setResult] = useState(null);
+ const [textEmbedResult, setTextEmbedResult] =
+ useState(null);
+ const [imageEmbedResult, setImageEmbedResult] =
+ useState(null);
useEffect(() => {
if (!cactusLM.isDownloaded) {
@@ -59,6 +64,17 @@ const VisionScreen = () => {
setResult(completionResult);
};
+ const handleEmbedText = async () => {
+ setTextEmbedResult(await cactusLM.embed({ text: input }));
+ };
+
+ const handleEmbedImage = async () => {
+ if (!selectedImage) return;
+ setImageEmbedResult(
+ await cactusLM.imageEmbed({ imagePath: selectedImage })
+ );
+ };
+
const handleInit = () => {
cactusLM.init();
};
@@ -119,6 +135,22 @@ const VisionScreen = () => {
+
+ Embed Text
+
+
+
+ Embed Image
+
+
Stop
@@ -194,6 +226,46 @@ const VisionScreen = () => {
)}
+ {textEmbedResult && (
+
+ Text Embedding Result:
+
+ embedding:
+
+
+ [
+ {textEmbedResult.embedding
+ .slice(0, 20)
+ .map((v) => v.toFixed(4))
+ .join(', ')}
+ {textEmbedResult.embedding.length > 20 ? ', ...' : ''}] (length:{' '}
+ {textEmbedResult.embedding.length})
+
+
+
+
+ )}
+
+ {imageEmbedResult && (
+
+ Image Embedding Result:
+
+ embedding:
+
+
+ [
+ {imageEmbedResult.embedding
+ .slice(0, 20)
+ .map((v) => v.toFixed(4))
+ .join(', ')}
+ {imageEmbedResult.embedding.length > 20 ? ', ...' : ''}]
+ (length: {imageEmbedResult.embedding.length})
+
+
+
+
+ )}
+
{cactusLM.error && (
{cactusLM.error}
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h
index faac947..6bb3f27 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/cactus_ffi.h
@@ -33,6 +33,17 @@ CACTUS_FFI_EXPORT int cactus_complete(
void* user_data
);
+CACTUS_FFI_EXPORT int cactus_transcribe(
+ cactus_model_t model,
+ const char* audio_file_path,
+ const char* prompt,
+ char* response_buffer,
+ size_t buffer_size,
+ const char* options_json,
+ cactus_token_callback callback,
+ void* user_data
+);
+
CACTUS_FFI_EXPORT int cactus_embed(
cactus_model_t model,
@@ -42,6 +53,22 @@ CACTUS_FFI_EXPORT int cactus_embed(
size_t* embedding_dim
);
+CACTUS_FFI_EXPORT int cactus_image_embed(
+ cactus_model_t model,
+ const char* image_path,
+ float* embeddings_buffer,
+ size_t buffer_size,
+ size_t* embedding_dim
+);
+
+CACTUS_FFI_EXPORT int cactus_audio_embed(
+ cactus_model_t model,
+ const char* audio_path,
+ float* embeddings_buffer,
+ size_t buffer_size,
+ size_t* embedding_dim
+);
+
CACTUS_FFI_EXPORT void cactus_reset(cactus_model_t model);
CACTUS_FFI_EXPORT void cactus_stop(cactus_model_t model);
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h
index da734e3..e233cd2 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/engine.h
@@ -7,11 +7,28 @@
#include
#include "../graph/graph.h"
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc99-extensions"
+#pragma clang diagnostic ignored "-Wunused-parameter"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+
extern "C" {
#include "../../libs/stb/stb_image.h"
#include "../../libs/stb/stb_image_resize2.h"
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
class CactusGraph;
namespace cactus {
@@ -68,7 +85,7 @@ struct Config {
float max_pixels_tolerance = 2.0f;
bool do_image_splitting = true;
- enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6};
+ enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
ModelType model_type = ModelType::QWEN;
enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -139,7 +156,7 @@ class Tokenizer {
void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
protected:
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT };
+ enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
ModelType model_type_ = ModelType::UNKNOWN;
enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -365,28 +382,43 @@ class Model {
const std::vector& get_debug_nodes() const;
virtual bool init(const std::string& model_folder, size_t context_size, const std::string& system_prompt = "", bool do_warmup = true);
+
virtual bool init(CactusGraph* external_graph, const std::string& model_folder, size_t context_size,
const std::string& system_prompt = "", bool do_warmup = true);
+
virtual uint32_t generate(const std::vector& tokens, float temperature = -1.0f, float top_p = -1.0f,
size_t top_k = 0, const std::string& profile_file = "");
virtual uint32_t generate_with_images(const std::vector& tokens, const std::vector& image_paths,
float temperature = -1.0f, float top_p = -1.0f,
size_t top_k = 0, const std::string& profile_file = "");
+
+ virtual uint32_t generate_with_audio(const std::vector& tokens, const std::vector& mel_bins, float temperature = 0.0f, float top_p = 0.0f,
+ size_t top_k = 0, const std::string& profile_file = "");
std::vector get_embeddings(const std::vector& tokens, bool pooled = true, const std::string& profile_file = "");
+
+ virtual std::vector get_image_embeddings(const std::string& image_path);
+
+ virtual std::vector get_audio_embeddings(const std::vector& mel_bins);
virtual void reset_cache() { kv_cache_.reset(); }
+
void set_cache_window(size_t window_size, size_t sink_size = 4) { kv_cache_.set_window_size(window_size, sink_size); }
void* graph_handle_;
protected:
virtual size_t forward(const std::vector& tokens, bool use_cache = false) = 0;
+
+ virtual size_t forward(const std::vector& mel_bins, const std::vector& tokens, bool use_cache = false);
+
virtual void load_weights_to_graph(CactusGraph* gb) = 0;
+
virtual size_t build_attention(CactusGraph* gb, size_t normalized_input, uint32_t layer_idx,
ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
- virtual size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
+
+ virtual size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
ComputeBackend backend) const = 0;
virtual size_t build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/ffi_utils.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/ffi_utils.h
index 9f464e4..1bbccc4 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/ffi_utils.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/ffi_utils.h
@@ -8,6 +8,8 @@
#include
#include
#include
+#include
+#include
#include
#include
@@ -177,8 +179,8 @@ inline void parse_options_json(const std::string& json,
float& temperature, float& top_p,
size_t& top_k, size_t& max_tokens,
std::vector& stop_sequences) {
- temperature = -1.0f;
- top_p = -1.0f;
+ temperature = 0.0f;
+ top_p = 0.0f;
top_k = 0;
max_tokens = 100;
stop_sequences.clear();
@@ -233,15 +235,14 @@ inline std::string format_tools_for_prompt(const std::vector& tool
std::string formatted_tools_json;
for (size_t i = 0; i < tools.size(); i++) {
if (i > 0) formatted_tools_json += ",\n";
- formatted_tools_json += " {\n";
- formatted_tools_json += " \"type\": \"function\",\n";
- formatted_tools_json += " \"function\": {\n";
- formatted_tools_json += " \"name\": \"" + tools[i].name + "\",\n";
- formatted_tools_json += " \"description\": \"" + tools[i].description + "\"";
+ formatted_tools_json += "{\"type\":\"function\",\"function\":{\"name\":\""
+ + tools[i].name
+ + "\",\"description\":\""
+ + tools[i].description + "\"";
if (tools[i].parameters.find("schema") != tools[i].parameters.end()) {
- formatted_tools_json += ",\n \"parameters\": " + tools[i].parameters.at("schema");
+ formatted_tools_json += ",\"parameters\":" + tools[i].parameters.at("schema");
}
- formatted_tools_json += "\n }\n }";
+ formatted_tools_json += "}}";
}
return formatted_tools_json;
}
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h
index c6df4e6..bf4bc19 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/graph.h
@@ -32,7 +32,7 @@ enum class OpType {
SUM, MEAN, VARIANCE, MIN, MAX,
RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
- SILU, GELU,
+ SILU, GELU, GELU_ERF,
SAMPLE, CONCAT,
SCATTER_TOPK,
TOPK, LAYERNORM,
@@ -219,6 +219,7 @@ class CactusGraph {
size_t silu(size_t input);
size_t gelu(size_t input);
+ size_t gelu_erf(size_t input);
size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -236,6 +237,7 @@ class CactusGraph {
size_t gather(size_t embeddings, size_t indices);
size_t mmap_embeddings(const std::string& filename);
size_t mmap_weights(const std::string& filename);
+ size_t load_weights(const std::string& filename);
void set_quantization_scale(size_t node_id, float scale);
size_t embedding(const std::string& filename, size_t indices);
size_t embedding(size_t embedding_tensor, size_t indices);
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h
index cdd45ee..5c1e59d 100644
--- a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h
+++ b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/Headers/kernel.h
@@ -174,6 +174,15 @@ void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
float input_scale, float output_scale);
+void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
+void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
+void cactus_gelu_int8_erf(
+ const int8_t* input,
+ int8_t* output,
+ size_t num_elements,
+ float scale_in,
+ float scale_out);
+
void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
@@ -225,6 +234,28 @@ void cactus_conv1d_causal_depthwise_int8(
float weight_scale,
float output_scale);
+void cactus_conv1d_f32_k3(
+ const float* input,
+ const float* weight,
+ float* output,
+ size_t N,
+ size_t L,
+ size_t C_in,
+ size_t C_out,
+ size_t stride
+);
+
+void cactus_conv1d_f16_k3(
+ const __fp16* input,
+ const __fp16* weight,
+ __fp16* output,
+ size_t N,
+ size_t L,
+ size_t C_in,
+ size_t C_out,
+ size_t stride
+);
+
void cactus_conv1d_f32_k3(
const float* input,
const float* weight,
diff --git a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus
index ccc94c3..449c8ac 100755
Binary files a/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus and b/ios/cactus.xcframework/ios-arm64-simulator/cactus.framework/cactus differ
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h
index faac947..6bb3f27 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/cactus_ffi.h
@@ -33,6 +33,17 @@ CACTUS_FFI_EXPORT int cactus_complete(
void* user_data
);
+CACTUS_FFI_EXPORT int cactus_transcribe(
+ cactus_model_t model,
+ const char* audio_file_path,
+ const char* prompt,
+ char* response_buffer,
+ size_t buffer_size,
+ const char* options_json,
+ cactus_token_callback callback,
+ void* user_data
+);
+
CACTUS_FFI_EXPORT int cactus_embed(
cactus_model_t model,
@@ -42,6 +53,22 @@ CACTUS_FFI_EXPORT int cactus_embed(
size_t* embedding_dim
);
+CACTUS_FFI_EXPORT int cactus_image_embed(
+ cactus_model_t model,
+ const char* image_path,
+ float* embeddings_buffer,
+ size_t buffer_size,
+ size_t* embedding_dim
+);
+
+CACTUS_FFI_EXPORT int cactus_audio_embed(
+ cactus_model_t model,
+ const char* audio_path,
+ float* embeddings_buffer,
+ size_t buffer_size,
+ size_t* embedding_dim
+);
+
CACTUS_FFI_EXPORT void cactus_reset(cactus_model_t model);
CACTUS_FFI_EXPORT void cactus_stop(cactus_model_t model);
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h
index da734e3..e233cd2 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/engine.h
@@ -7,11 +7,28 @@
#include
#include "../graph/graph.h"
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc99-extensions"
+#pragma clang diagnostic ignored "-Wunused-parameter"
+#elif defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+
extern "C" {
#include "../../libs/stb/stb_image.h"
#include "../../libs/stb/stb_image_resize2.h"
}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
+
class CactusGraph;
namespace cactus {
@@ -68,7 +85,7 @@ struct Config {
float max_pixels_tolerance = 2.0f;
bool do_image_splitting = true;
- enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6};
+ enum class ModelType {QWEN = 0, GEMMA = 1, SMOL = 2, NOMIC = 3, LFM2 = 5, SIGLIP2 = 6, WHISPER = 7};
ModelType model_type = ModelType::QWEN;
enum class ModelVariant {DEFAULT = 0, VLM = 1, EXTRACT = 2, RAG = 3};
@@ -139,7 +156,7 @@ class Tokenizer {
void set_corpus_dir(const std::string& dir) { corpus_dir_ = dir; }
protected:
- enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT };
+ enum class ModelType { UNKNOWN, QWEN, GEMMA, LFM2, SMOL, BERT, WHISPER};
ModelType model_type_ = ModelType::UNKNOWN;
enum class ModelVariant { DEFAULT, VLM, EXTRACT, RAG};
ModelVariant model_variant_ = ModelVariant::DEFAULT;
@@ -365,28 +382,43 @@ class Model {
const std::vector& get_debug_nodes() const;
virtual bool init(const std::string& model_folder, size_t context_size, const std::string& system_prompt = "", bool do_warmup = true);
+
virtual bool init(CactusGraph* external_graph, const std::string& model_folder, size_t context_size,
const std::string& system_prompt = "", bool do_warmup = true);
+
virtual uint32_t generate(const std::vector& tokens, float temperature = -1.0f, float top_p = -1.0f,
size_t top_k = 0, const std::string& profile_file = "");
virtual uint32_t generate_with_images(const std::vector& tokens, const std::vector& image_paths,
float temperature = -1.0f, float top_p = -1.0f,
size_t top_k = 0, const std::string& profile_file = "");
+
+ virtual uint32_t generate_with_audio(const std::vector& tokens, const std::vector& mel_bins, float temperature = 0.0f, float top_p = 0.0f,
+ size_t top_k = 0, const std::string& profile_file = "");
std::vector get_embeddings(const std::vector& tokens, bool pooled = true, const std::string& profile_file = "");
+
+ virtual std::vector get_image_embeddings(const std::string& image_path);
+
+ virtual std::vector get_audio_embeddings(const std::vector& mel_bins);
virtual void reset_cache() { kv_cache_.reset(); }
+
void set_cache_window(size_t window_size, size_t sink_size = 4) { kv_cache_.set_window_size(window_size, sink_size); }
void* graph_handle_;
protected:
virtual size_t forward(const std::vector& tokens, bool use_cache = false) = 0;
+
+ virtual size_t forward(const std::vector& mel_bins, const std::vector& tokens, bool use_cache = false);
+
virtual void load_weights_to_graph(CactusGraph* gb) = 0;
+
virtual size_t build_attention(CactusGraph* gb, size_t normalized_input, uint32_t layer_idx,
ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
- virtual size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
+
+ virtual size_t build_mlp(CactusGraph* gb, size_t normalized_h, uint32_t layer_idx,
ComputeBackend backend) const = 0;
virtual size_t build_transformer_block(CactusGraph* gb, size_t hidden, uint32_t layer_idx,
ComputeBackend backend, bool use_cache = false, size_t position_offset = 0) = 0;
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ffi_utils.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ffi_utils.h
index 9f464e4..1bbccc4 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ffi_utils.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/ffi_utils.h
@@ -8,6 +8,8 @@
#include
#include
#include
+#include
+#include
#include
#include
@@ -177,8 +179,8 @@ inline void parse_options_json(const std::string& json,
float& temperature, float& top_p,
size_t& top_k, size_t& max_tokens,
std::vector& stop_sequences) {
- temperature = -1.0f;
- top_p = -1.0f;
+ temperature = 0.0f;
+ top_p = 0.0f;
top_k = 0;
max_tokens = 100;
stop_sequences.clear();
@@ -233,15 +235,14 @@ inline std::string format_tools_for_prompt(const std::vector& tool
std::string formatted_tools_json;
for (size_t i = 0; i < tools.size(); i++) {
if (i > 0) formatted_tools_json += ",\n";
- formatted_tools_json += " {\n";
- formatted_tools_json += " \"type\": \"function\",\n";
- formatted_tools_json += " \"function\": {\n";
- formatted_tools_json += " \"name\": \"" + tools[i].name + "\",\n";
- formatted_tools_json += " \"description\": \"" + tools[i].description + "\"";
+ formatted_tools_json += "{\"type\":\"function\",\"function\":{\"name\":\""
+ + tools[i].name
+ + "\",\"description\":\""
+ + tools[i].description + "\"";
if (tools[i].parameters.find("schema") != tools[i].parameters.end()) {
- formatted_tools_json += ",\n \"parameters\": " + tools[i].parameters.at("schema");
+ formatted_tools_json += ",\"parameters\":" + tools[i].parameters.at("schema");
}
- formatted_tools_json += "\n }\n }";
+ formatted_tools_json += "}}";
}
return formatted_tools_json;
}
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
index c6df4e6..bf4bc19 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/graph.h
@@ -32,7 +32,7 @@ enum class OpType {
SUM, MEAN, VARIANCE, MIN, MAX,
RMS_NORM, ROPE, SOFTMAX, ATTENTION, CONV1D_CAUSAL, CONV1D_K3,
SCALAR_ADD, SCALAR_SUBTRACT, SCALAR_MULTIPLY, SCALAR_DIVIDE, SCALAR_EXP, SCALAR_SQRT, SCALAR_COS, SCALAR_SIN,
- SILU, GELU,
+ SILU, GELU, GELU_ERF,
SAMPLE, CONCAT,
SCATTER_TOPK,
TOPK, LAYERNORM,
@@ -219,6 +219,7 @@ class CactusGraph {
size_t silu(size_t input);
size_t gelu(size_t input);
+ size_t gelu_erf(size_t input);
size_t matmul(size_t input1, size_t input2, bool pretransposed_rhs = false, ComputeBackend backend = ComputeBackend::CPU);
size_t transpose(size_t input, ComputeBackend backend = ComputeBackend::CPU);
@@ -236,6 +237,7 @@ class CactusGraph {
size_t gather(size_t embeddings, size_t indices);
size_t mmap_embeddings(const std::string& filename);
size_t mmap_weights(const std::string& filename);
+ size_t load_weights(const std::string& filename);
void set_quantization_scale(size_t node_id, float scale);
size_t embedding(const std::string& filename, size_t indices);
size_t embedding(size_t embedding_tensor, size_t indices);
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h
index cdd45ee..5c1e59d 100644
--- a/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h
+++ b/ios/cactus.xcframework/ios-arm64/cactus.framework/Headers/kernel.h
@@ -174,6 +174,15 @@ void cactus_gelu_f16(const __fp16* input, __fp16* output, size_t num_elements);
void cactus_gelu_int8(const int8_t* input, int8_t* output, size_t num_elements,
float input_scale, float output_scale);
+void cactus_gelu_f32_erf(const float* input, float* output, size_t num_elements);
+void cactus_gelu_f16_erf(const __fp16* input, __fp16* output, size_t num_elements);
+void cactus_gelu_int8_erf(
+ const int8_t* input,
+ int8_t* output,
+ size_t num_elements,
+ float scale_in,
+ float scale_out);
+
void cactus_attention_int8(const int8_t* queries, const int8_t* keys, const int8_t* values, int8_t* output,
size_t batch_size, size_t seq_len, size_t kv_seq_len, size_t num_q_heads, size_t num_kv_heads,
@@ -225,6 +234,28 @@ void cactus_conv1d_causal_depthwise_int8(
float weight_scale,
float output_scale);
+void cactus_conv1d_f32_k3(
+ const float* input,
+ const float* weight,
+ float* output,
+ size_t N,
+ size_t L,
+ size_t C_in,
+ size_t C_out,
+ size_t stride
+);
+
+void cactus_conv1d_f16_k3(
+ const __fp16* input,
+ const __fp16* weight,
+ __fp16* output,
+ size_t N,
+ size_t L,
+ size_t C_in,
+ size_t C_out,
+ size_t stride
+);
+
void cactus_conv1d_f32_k3(
const float* input,
const float* weight,
diff --git a/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus b/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus
index 72982a1..c71c4eb 100755
Binary files a/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus and b/ios/cactus.xcframework/ios-arm64/cactus.framework/cactus differ
diff --git a/nitrogen/generated/shared/c++/HybridCactusSpec.cpp b/nitrogen/generated/shared/c++/HybridCactusSpec.cpp
index b5eff13..b164ed2 100644
--- a/nitrogen/generated/shared/c++/HybridCactusSpec.cpp
+++ b/nitrogen/generated/shared/c++/HybridCactusSpec.cpp
@@ -16,7 +16,10 @@ namespace margelo::nitro::cactus {
registerHybrids(this, [](Prototype& prototype) {
prototype.registerHybridMethod("init", &HybridCactusSpec::init);
prototype.registerHybridMethod("complete", &HybridCactusSpec::complete);
+ prototype.registerHybridMethod("transcribe", &HybridCactusSpec::transcribe);
prototype.registerHybridMethod("embed", &HybridCactusSpec::embed);
+ prototype.registerHybridMethod("imageEmbed", &HybridCactusSpec::imageEmbed);
+ prototype.registerHybridMethod("audioEmbed", &HybridCactusSpec::audioEmbed);
prototype.registerHybridMethod("reset", &HybridCactusSpec::reset);
prototype.registerHybridMethod("stop", &HybridCactusSpec::stop);
prototype.registerHybridMethod("destroy", &HybridCactusSpec::destroy);
diff --git a/nitrogen/generated/shared/c++/HybridCactusSpec.hpp b/nitrogen/generated/shared/c++/HybridCactusSpec.hpp
index d37edfa..19c8b24 100644
--- a/nitrogen/generated/shared/c++/HybridCactusSpec.hpp
+++ b/nitrogen/generated/shared/c++/HybridCactusSpec.hpp
@@ -54,7 +54,10 @@ namespace margelo::nitro::cactus {
// Methods
virtual std::shared_ptr> init(const std::string& modelPath, double contextSize, const std::optional& corpusDir) = 0;
virtual std::shared_ptr> complete(const std::string& messagesJson, double responseBufferSize, const std::optional& optionsJson, const std::optional& toolsJson, const std::optional>& callback) = 0;
+ virtual std::shared_ptr> transcribe(const std::string& audioFilePath, const std::string& prompt, double responseBufferSize, const std::optional& optionsJson, const std::optional>& callback) = 0;
virtual std::shared_ptr>> embed(const std::string& text, double embeddingBufferSize) = 0;
+ virtual std::shared_ptr>> imageEmbed(const std::string& imagePath, double embeddingBufferSize) = 0;
+ virtual std::shared_ptr>> audioEmbed(const std::string& audioPath, double embeddingBufferSize) = 0;
virtual std::shared_ptr> reset() = 0;
virtual std::shared_ptr> stop() = 0;
virtual std::shared_ptr> destroy() = 0;
diff --git a/package.json b/package.json
index 01e7301..45e4055 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "cactus-react-native",
- "version": "1.0.2",
+ "version": "1.1.0",
"description": "Run AI models locally on mobile devices",
"main": "./lib/module/index.js",
"types": "./lib/typescript/src/index.d.ts",
diff --git a/src/api/Database.ts b/src/api/Database.ts
index 2b46933..e4ad636 100644
--- a/src/api/Database.ts
+++ b/src/api/Database.ts
@@ -54,6 +54,33 @@ export class Database {
return await CactusUtil.registerApp(await response.text());
}
+ public static async getModel(slug: string): Promise {
+ const response = await fetch(
+ `${this.url}/functions/v1/get-models?slug=${slug}&sdk_name=react&sdk_version=${packageVersion}`,
+ {
+ headers: { apikey: this.key, Authorization: `Bearer ${this.key}` },
+ }
+ );
+
+ if (!response.ok) {
+ throw new Error('Getting model failed');
+ }
+
+ const model = (await response.json()) as CactusModelResponse;
+
+ return {
+ name: model.name,
+ slug: model.slug,
+ quantization: model.quantization,
+ sizeMb: model.size_mb,
+ downloadUrl: model.download_url,
+ supportsToolCalling: model.supports_tool_calling,
+ supportsVision: model.supports_vision,
+ createdAt: model.created_at,
+ isDownloaded: false,
+ };
+ }
+
public static async getModels(): Promise {
const response = await fetch(
`${this.url}/functions/v1/get-models?sdk_name=react&sdk_version=${packageVersion}`,
diff --git a/src/api/RemoteLM.ts b/src/api/RemoteLM.ts
index 3f81daf..49b1633 100644
--- a/src/api/RemoteLM.ts
+++ b/src/api/RemoteLM.ts
@@ -3,7 +3,7 @@ import { CactusImage } from '../native/CactusImage';
import type {
CactusLMCompleteResult,
Message,
- Options,
+ CompleteOptions,
Tool,
} from '../types/CactusLM';
@@ -15,7 +15,7 @@ export class RemoteLM {
public static async complete(
messages: Message[],
- options?: Options,
+ options?: CompleteOptions,
tools?: { type: 'function'; function: Tool }[],
callback?: (token: string) => void
): Promise {
diff --git a/src/classes/CactusLM.ts b/src/classes/CactusLM.ts
index c514520..f9defea 100644
--- a/src/classes/CactusLM.ts
+++ b/src/classes/CactusLM.ts
@@ -5,7 +5,8 @@ import type {
CactusLMCompleteResult,
CactusLMEmbedParams,
CactusLMEmbedResult,
- CactusLMGetModelsParams,
+ CactusLMImageEmbedParams,
+ CactusLMImageEmbedResult,
CactusLMParams,
} from '../types/CactusLM';
import type { CactusModel } from '../types/CactusModel';
@@ -34,7 +35,7 @@ export class CactusLM {
private static readonly defaultCompleteMode = 'local';
private static readonly defaultEmbedBufferSize = 2048;
- private static readonly modelsInfoPath = 'models/info.json';
+ private static cactusModelsCache: CactusModel[] | null = null;
constructor({ model, contextSize, corpusDir }: CactusLMParams = {}) {
Telemetry.init(CactusConfig.telemetryToken);
@@ -58,8 +59,12 @@ export class CactusLM {
this.isDownloading = true;
try {
- await CactusFileSystem.downloadModel(this.model, onProgress);
- await this.getModels({ forceRefresh: true });
+ const model = await Database.getModel(this.model);
+ await CactusFileSystem.downloadModel(
+ this.model,
+ model.downloadUrl,
+ onProgress
+ );
} finally {
this.isDownloading = false;
}
@@ -175,6 +180,31 @@ export class CactusLM {
}
}
+ public async imageEmbed({
+ imagePath,
+ }: CactusLMImageEmbedParams): Promise {
+ if (this.isGenerating) {
+ throw new Error('CactusLM is already generating');
+ }
+
+ await this.init();
+
+ this.isGenerating = true;
+ try {
+ const embedding = await this.cactus.imageEmbed(
+ imagePath,
+ CactusLM.defaultEmbedBufferSize
+ );
+ Telemetry.logImageEmbedding(this.model, true);
+ return { embedding };
+ } catch (error) {
+ Telemetry.logImageEmbedding(this.model, false, getErrorMessage(error));
+ throw error;
+ } finally {
+ this.isGenerating = false;
+ }
+ }
+
public stop(): Promise {
return this.cactus.stop();
}
@@ -195,34 +225,15 @@ export class CactusLM {
this.isInitialized = false;
}
- public async getModels({
- forceRefresh = false,
- }: CactusLMGetModelsParams = {}): Promise {
- if (
- !forceRefresh &&
- (await CactusFileSystem.fileExists(CactusLM.modelsInfoPath))
- ) {
- try {
- return JSON.parse(
- await CactusFileSystem.readFile(CactusLM.modelsInfoPath)
- );
- } catch {
- // Delete corrupted models info
- await CactusFileSystem.deleteFile(CactusLM.modelsInfoPath);
- }
+ public async getModels(): Promise {
+ if (CactusLM.cactusModelsCache) {
+ return CactusLM.cactusModelsCache;
}
-
const models = await Database.getModels();
-
for (const model of models) {
model.isDownloaded = await CactusFileSystem.modelExists(model.slug);
}
-
- await CactusFileSystem.writeFile(
- CactusLM.modelsInfoPath,
- JSON.stringify(models)
- );
-
+ CactusLM.cactusModelsCache = models;
return models;
}
}
diff --git a/src/classes/CactusSTT.ts b/src/classes/CactusSTT.ts
new file mode 100644
index 0000000..2747217
--- /dev/null
+++ b/src/classes/CactusSTT.ts
@@ -0,0 +1,182 @@
+import { Cactus, CactusFileSystem } from '../native';
+import type {
+ CactusSTTDownloadParams,
+ CactusSTTTranscribeParams,
+ CactusSTTTranscribeResult,
+ CactusSTTParams,
+ CactusSTTAudioEmbedParams,
+ CactusSTTAudioEmbedResult,
+} from '../types/CactusSTT';
+import type { CactusModel } from '../types/CactusModel';
+import { Telemetry } from '../telemetry/Telemetry';
+import { CactusConfig } from '../config/CactusConfig';
+import { Database } from '../api/Database';
+import { getErrorMessage } from '../utils/error';
+
+export class CactusSTT {
+ private readonly cactus = new Cactus();
+
+ private readonly model: string;
+ private readonly contextSize: number;
+
+ private isDownloading = false;
+ private isInitialized = false;
+ private isGenerating = false;
+
+ private static readonly defaultModel = 'whisper-small';
+ private static readonly defaultContextSize = 2048;
+ private static readonly defaultTranscribeOptions = {
+ maxTokens: 512,
+ };
+ private static readonly defaultEmbedBufferSize = 32768;
+
+ private static cactusModelsCache: CactusModel[] | null = null;
+
+ constructor({ model, contextSize }: CactusSTTParams = {}) {
+ Telemetry.init(CactusConfig.telemetryToken);
+
+ this.model = model ?? CactusSTT.defaultModel;
+ this.contextSize = contextSize ?? CactusSTT.defaultContextSize;
+ }
+
+ public async download({
+ onProgress,
+ }: CactusSTTDownloadParams = {}): Promise {
+ if (this.isDownloading) {
+ throw new Error('CactusSTT is already downloading');
+ }
+
+ if (await CactusFileSystem.modelExists(this.model)) {
+ onProgress?.(1.0);
+ return;
+ }
+
+ this.isDownloading = true;
+ try {
+ await CactusFileSystem.downloadModel(
+ this.model,
+ `https://vlqqczxwyaodtcdmdmlw.supabase.co/storage/v1/object/public/voice-models/${this.model}.zip`,
+ onProgress
+ );
+ } finally {
+ this.isDownloading = false;
+ }
+ }
+
+ public async init(): Promise {
+ if (this.isInitialized) {
+ return;
+ }
+
+ if (!(await CactusFileSystem.modelExists(this.model))) {
+ throw new Error(`Model "${this.model}" is not downloaded`);
+ }
+
+ const modelPath = await CactusFileSystem.getModelPath(this.model);
+
+ try {
+ await this.cactus.init(modelPath, this.contextSize);
+ Telemetry.logInit(this.model, true);
+ this.isInitialized = true;
+ } catch (error) {
+ Telemetry.logInit(this.model, false, getErrorMessage(error));
+ throw error;
+ }
+ }
+
+ public async transcribe({
+ audioFilePath,
+ prompt = '<|startoftranscript|><|en|><|transcribe|><|notimestamps|>',
+ options,
+ onToken,
+ }: CactusSTTTranscribeParams): Promise {
+ if (this.isGenerating) {
+ throw new Error('CactusSTT is already generating');
+ }
+
+ await this.init();
+
+ options = { ...CactusSTT.defaultTranscribeOptions, ...options };
+ const responseBufferSize = 32768;
+
+ this.isGenerating = true;
+ try {
+ const result = await this.cactus.transcribe(
+ audioFilePath,
+ prompt,
+ responseBufferSize,
+ options,
+ onToken
+ );
+ Telemetry.logTranscribe(
+ this.model,
+ result.success,
+ result.success ? undefined : result.response,
+ result
+ );
+ return result;
+ } catch (error) {
+ Telemetry.logTranscribe(this.model, false, getErrorMessage(error));
+ throw error;
+ } finally {
+ this.isGenerating = false;
+ }
+ }
+
+ public async audioEmbed({
+ audioPath,
+ }: CactusSTTAudioEmbedParams): Promise {
+ if (this.isGenerating) {
+ throw new Error('CactusSTT is already generating');
+ }
+
+ await this.init();
+
+ this.isGenerating = true;
+ try {
+ const embedding = await this.cactus.audioEmbed(
+ audioPath,
+ CactusSTT.defaultEmbedBufferSize
+ );
+ Telemetry.logAudioEmbedding(this.model, true);
+ return { embedding };
+ } catch (error) {
+ Telemetry.logAudioEmbedding(this.model, false, getErrorMessage(error));
+ throw error;
+ } finally {
+ this.isGenerating = false;
+ }
+ }
+
+ public stop(): Promise {
+ return this.cactus.stop();
+ }
+
+ public async reset(): Promise {
+ await this.stop();
+ return this.cactus.reset();
+ }
+
+ public async destroy(): Promise {
+ if (!this.isInitialized) {
+ return;
+ }
+
+ await this.stop();
+ await this.cactus.destroy();
+
+ this.isInitialized = false;
+ }
+
+ public async getModels(): Promise {
+ if (CactusSTT.cactusModelsCache) {
+ return CactusSTT.cactusModelsCache;
+ }
+ const models = await Database.getModels();
+ for (const model of models) {
+ model.isDownloaded = await CactusFileSystem.modelExists(model.slug);
+ }
+ CactusSTT.cactusModelsCache = models;
+ return models;
+ }
+}
diff --git a/src/constants/packageVersion.ts b/src/constants/packageVersion.ts
index 5801820..3009f7a 100644
--- a/src/constants/packageVersion.ts
+++ b/src/constants/packageVersion.ts
@@ -1 +1 @@
-export const packageVersion = '1.0.2';
+export const packageVersion = '1.1.0';
diff --git a/src/hooks/useCactusLM.ts b/src/hooks/useCactusLM.ts
index 65594bf..5a7e951 100644
--- a/src/hooks/useCactusLM.ts
+++ b/src/hooks/useCactusLM.ts
@@ -7,7 +7,8 @@ import type {
CactusLMCompleteResult,
CactusLMEmbedParams,
CactusLMEmbedResult,
- CactusLMGetModelsParams,
+ CactusLMImageEmbedParams,
+ CactusLMImageEmbedResult,
CactusLMCompleteParams,
CactusLMDownloadParams,
} from '../types/CactusLM';
@@ -222,6 +223,30 @@ export const useCactusLM = ({
[cactusLM, isGenerating]
);
+ const imageEmbed = useCallback(
+ async ({
+ imagePath,
+ }: CactusLMImageEmbedParams): Promise => {
+ if (isGenerating) {
+ const message = 'CactusLM is already generating';
+ setError(message);
+ throw new Error(message);
+ }
+
+ setError(null);
+ setIsGenerating(true);
+ try {
+ return await cactusLM.imageEmbed({ imagePath });
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ setIsGenerating(false);
+ }
+ },
+ [cactusLM, isGenerating]
+ );
+
const stop = useCallback(async () => {
setError(null);
try {
@@ -256,20 +281,15 @@ export const useCactusLM = ({
}
}, [cactusLM]);
- const getModels = useCallback(
- async ({ forceRefresh }: CactusLMGetModelsParams = {}): Promise<
- CactusModel[]
- > => {
- setError(null);
- try {
- return await cactusLM.getModels({ forceRefresh });
- } catch (e) {
- setError(getErrorMessage(e));
- throw e;
- }
- },
- [cactusLM]
- );
+ const getModels = useCallback(async (): Promise => {
+ setError(null);
+ try {
+ return await cactusLM.getModels();
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ }
+ }, [cactusLM]);
return {
completion,
@@ -284,6 +304,7 @@ export const useCactusLM = ({
init,
complete,
embed,
+ imageEmbed,
reset,
stop,
destroy,
diff --git a/src/hooks/useCactusSTT.ts b/src/hooks/useCactusSTT.ts
new file mode 100644
index 0000000..3446df8
--- /dev/null
+++ b/src/hooks/useCactusSTT.ts
@@ -0,0 +1,285 @@
+import { useCallback, useEffect, useState, useRef } from 'react';
+import { CactusSTT } from '../classes/CactusSTT';
+import { CactusFileSystem } from '../native';
+import { getErrorMessage } from '../utils/error';
+import type {
+ CactusSTTParams,
+ CactusSTTTranscribeResult,
+ CactusSTTTranscribeParams,
+ CactusSTTDownloadParams,
+ CactusSTTAudioEmbedParams,
+ CactusSTTAudioEmbedResult,
+} from '../types/CactusSTT';
+import type { CactusModel } from '../types/CactusModel';
+
+export const useCactusSTT = ({
+ model = 'whisper-small',
+ contextSize = 2048,
+}: CactusSTTParams = {}) => {
+ const [cactusSTT, setCactusSTT] = useState(
+ () => new CactusSTT({ model, contextSize })
+ );
+
+ // State
+ const [response, setResponse] = useState('');
+ const [isGenerating, setIsGenerating] = useState(false);
+ const [isInitializing, setIsInitializing] = useState(false);
+ const [isDownloaded, setIsDownloaded] = useState(false);
+ const [isDownloading, setIsDownloading] = useState(false);
+ const [downloadProgress, setDownloadProgress] = useState(0);
+ const [error, setError] = useState(null);
+
+ const currentModelRef = useRef(model);
+ const currentDownloadIdRef = useRef(0);
+
+ useEffect(() => {
+ currentModelRef.current = model;
+ }, [model]);
+
+ useEffect(() => {
+ setCactusSTT(new CactusSTT({ model, contextSize }));
+
+ setResponse('');
+ setIsGenerating(false);
+ setIsInitializing(false);
+ setIsDownloaded(false);
+ setIsDownloading(false);
+ setDownloadProgress(0);
+ setError(null);
+
+ let mounted = true;
+ CactusFileSystem.modelExists(model)
+ .then((exists) => {
+ if (!mounted) {
+ return;
+ }
+ setIsDownloaded(exists);
+ })
+ .catch((e) => {
+ if (!mounted) {
+ return;
+ }
+ setIsDownloaded(false);
+ setError(getErrorMessage(e));
+ });
+
+ return () => {
+ mounted = false;
+ };
+ }, [model, contextSize]);
+
+ useEffect(() => {
+ return () => {
+ cactusSTT.destroy().catch(() => {});
+ };
+ }, [cactusSTT]);
+
+ const download = useCallback(
+ async ({ onProgress }: CactusSTTDownloadParams = {}) => {
+ if (isDownloading) {
+ const message = 'CactusSTT is already downloading';
+ setError(message);
+ throw new Error(message);
+ }
+
+ setError(null);
+
+ if (isDownloaded) {
+ return;
+ }
+
+ const thisModel = currentModelRef.current;
+ const thisDownloadId = ++currentDownloadIdRef.current;
+
+ setDownloadProgress(0);
+ setIsDownloading(true);
+ try {
+ await cactusSTT.download({
+ onProgress: (progress) => {
+ if (
+ currentModelRef.current !== thisModel ||
+ currentDownloadIdRef.current !== thisDownloadId
+ ) {
+ return;
+ }
+
+ setDownloadProgress(progress);
+ onProgress?.(progress);
+ },
+ });
+
+ if (
+ currentModelRef.current !== thisModel ||
+ currentDownloadIdRef.current !== thisDownloadId
+ ) {
+ return;
+ }
+
+ setIsDownloaded(true);
+ } catch (e) {
+ if (
+ currentModelRef.current !== thisModel ||
+ currentDownloadIdRef.current !== thisDownloadId
+ ) {
+ return;
+ }
+
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ if (
+ currentModelRef.current !== thisModel ||
+ currentDownloadIdRef.current !== thisDownloadId
+ ) {
+ return;
+ }
+
+ setIsDownloading(false);
+ setDownloadProgress(0);
+ }
+ },
+ [cactusSTT, isDownloading, isDownloaded]
+ );
+
+ const init = useCallback(async () => {
+ if (isInitializing) {
+ const message = 'CactusSTT is already initializing';
+ setError(message);
+ throw new Error(message);
+ }
+
+ setError(null);
+ setIsInitializing(true);
+ try {
+ await cactusSTT.init();
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ setIsInitializing(false);
+ }
+ }, [cactusSTT, isInitializing]);
+
+ const transcribe = useCallback(
+ async ({
+ audioFilePath,
+ prompt,
+ options,
+ onToken,
+ }: CactusSTTTranscribeParams): Promise => {
+ if (isGenerating) {
+ const message = 'CactusSTT is already generating';
+ setError(message);
+ throw new Error(message);
+ }
+
+ setError(null);
+ setResponse('');
+ setIsGenerating(true);
+ try {
+ return await cactusSTT.transcribe({
+ audioFilePath,
+ prompt,
+ options,
+ onToken: (token) => {
+ setResponse((prev) => prev + token);
+ onToken?.(token);
+ },
+ });
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ setIsGenerating(false);
+ }
+ },
+ [cactusSTT, isGenerating]
+ );
+
+ const audioEmbed = useCallback(
+ async ({
+ audioPath,
+ }: CactusSTTAudioEmbedParams): Promise => {
+ if (isGenerating) {
+ const message = 'CactusSTT is already generating';
+ setError(message);
+ throw new Error(message);
+ }
+
+ setError(null);
+ setIsGenerating(true);
+ try {
+ return await cactusSTT.audioEmbed({ audioPath });
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ setIsGenerating(false);
+ }
+ },
+ [cactusSTT, isGenerating]
+ );
+
+ const stop = useCallback(async () => {
+ setError(null);
+ try {
+ await cactusSTT.stop();
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ }
+ }, [cactusSTT]);
+
+ const reset = useCallback(async () => {
+ setError(null);
+ try {
+ await cactusSTT.reset();
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ setResponse('');
+ }
+ }, [cactusSTT]);
+
+ const destroy = useCallback(async () => {
+ setError(null);
+ try {
+ await cactusSTT.destroy();
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ } finally {
+ setResponse('');
+ }
+ }, [cactusSTT]);
+
+ const getModels = useCallback(async (): Promise => {
+ setError(null);
+ try {
+ return await cactusSTT.getModels();
+ } catch (e) {
+ setError(getErrorMessage(e));
+ throw e;
+ }
+ }, [cactusSTT]);
+
+ return {
+ response,
+ isGenerating,
+ isInitializing,
+ isDownloaded,
+ isDownloading,
+ downloadProgress,
+ error,
+
+ download,
+ init,
+ transcribe,
+ audioEmbed,
+ reset,
+ stop,
+ destroy,
+ getModels,
+ };
+};
diff --git a/src/index.tsx b/src/index.tsx
index 90fb864..565123e 100644
--- a/src/index.tsx
+++ b/src/index.tsx
@@ -1,8 +1,10 @@
// Classes
export { CactusLM } from './classes/CactusLM';
+export { CactusSTT } from './classes/CactusSTT';
// Hooks
export { useCactusLM } from './hooks/useCactusLM';
+export { useCactusSTT } from './hooks/useCactusSTT';
// Types
export type { CactusModel } from './types/CactusModel';
@@ -10,14 +12,24 @@ export type {
CactusLMParams,
CactusLMDownloadParams,
Message,
- Options,
+ CompleteOptions,
Tool,
CactusLMCompleteParams,
CactusLMCompleteResult,
CactusLMEmbedParams,
CactusLMEmbedResult,
- CactusLMGetModelsParams,
+ CactusLMImageEmbedParams,
+ CactusLMImageEmbedResult,
} from './types/CactusLM';
+export type {
+ CactusSTTParams,
+ CactusSTTDownloadParams,
+ TranscribeOptions,
+ CactusSTTTranscribeParams,
+ CactusSTTTranscribeResult,
+ CactusSTTAudioEmbedParams,
+ CactusSTTAudioEmbedResult,
+} from './types/CactusSTT';
// Config
export { CactusConfig } from './config/CactusConfig';
diff --git a/src/native/Cactus.ts b/src/native/Cactus.ts
index 4605e21..00679b0 100644
--- a/src/native/Cactus.ts
+++ b/src/native/Cactus.ts
@@ -1,12 +1,16 @@
import { NitroModules } from 'react-native-nitro-modules';
import type { Cactus as CactusSpec } from '../specs/Cactus.nitro';
+import { CactusImage } from './CactusImage';
import type {
CactusLMCompleteResult,
Message,
- Options,
+ CompleteOptions,
Tool,
} from '../types/CactusLM';
-import { CactusImage } from './CactusImage';
+import type {
+ CactusSTTTranscribeResult,
+ TranscribeOptions,
+} from '../types/CactusSTT';
export class Cactus {
private readonly hybridCactus =
@@ -23,7 +27,7 @@ export class Cactus {
public async complete(
messages: Message[],
responseBufferSize: number,
- options?: Options,
+ options?: CompleteOptions,
tools?: { type: 'function'; function: Tool }[],
callback?: (token: string, tokenId: number) => void
): Promise {
@@ -33,17 +37,17 @@ export class Cactus {
messagesInternal.push(message);
continue;
}
- const images: string[] = [];
- for (const image of message.images) {
+ const resizedImages: string[] = [];
+ for (const imagePath of message.images) {
const resizedImage = await CactusImage.resize(
- image.replace('file://', ''),
+ imagePath.replace('file://', ''),
128,
128,
1
);
- images.push(resizedImage);
+ resizedImages.push(resizedImage);
}
- messagesInternal.push({ ...message, images });
+ messagesInternal.push({ ...message, images: resizedImages });
}
const messagesJson = JSON.stringify(messagesInternal);
@@ -85,10 +89,76 @@ export class Cactus {
}
}
+ public async transcribe(
+ audioFilePath: string,
+ prompt: string,
+ responseBufferSize: number,
+ options?: TranscribeOptions,
+ callback?: (token: string, tokenId: number) => void
+ ): Promise {
+ const optionsJson = options
+ ? JSON.stringify({
+ temperature: options.temperature,
+ top_p: options.topP,
+ top_k: options.topK,
+ max_tokens: options.maxTokens,
+ stop_sequences: options.stopSequences,
+ })
+ : undefined;
+
+ const response = await this.hybridCactus.transcribe(
+ audioFilePath.replace('file://', ''),
+ prompt,
+ responseBufferSize,
+ optionsJson,
+ callback
+ );
+
+ try {
+ const parsed = JSON.parse(response);
+
+ return {
+ success: parsed.success,
+ response: parsed.response,
+ timeToFirstTokenMs: parsed.time_to_first_token_ms,
+ totalTimeMs: parsed.total_time_ms,
+ tokensPerSecond: parsed.tokens_per_second,
+ prefillTokens: parsed.prefill_tokens,
+ decodeTokens: parsed.decode_tokens,
+ totalTokens: parsed.total_tokens,
+ };
+ } catch {
+ throw new Error('Unable to parse transcription response');
+ }
+ }
+
public embed(text: string, embeddingBufferSize: number): Promise {
return this.hybridCactus.embed(text, embeddingBufferSize);
}
+ public async imageEmbed(
+ imagePath: string,
+ embeddingBufferSize: number
+ ): Promise {
+ const resizedImage = await CactusImage.resize(
+ imagePath.replace('file://', ''),
+ 128,
+ 128,
+ 1
+ );
+ return this.hybridCactus.imageEmbed(resizedImage, embeddingBufferSize);
+ }
+
+ public audioEmbed(
+ audioPath: string,
+ embeddingBufferSize: number
+ ): Promise {
+ return this.hybridCactus.audioEmbed(
+ audioPath.replace('file://', ''),
+ embeddingBufferSize
+ );
+ }
+
public reset(): Promise {
return this.hybridCactus.reset();
}
diff --git a/src/native/CactusFileSystem.ts b/src/native/CactusFileSystem.ts
index 5bfc904..054ef64 100644
--- a/src/native/CactusFileSystem.ts
+++ b/src/native/CactusFileSystem.ts
@@ -35,10 +35,10 @@ export class CactusFileSystem {
public static downloadModel(
model: string,
+ url: string,
onProgress?: (progress: number) => void
): Promise {
- const from = `https://vlqqczxwyaodtcdmdmlw.supabase.co/storage/v1/object/public/cactus-models/${model}.zip`;
- return this.hybridCactusFileSystem.downloadModel(model, from, onProgress);
+ return this.hybridCactusFileSystem.downloadModel(model, url, onProgress);
}
public static deleteModel(model: string): Promise {
diff --git a/src/specs/Cactus.nitro.ts b/src/specs/Cactus.nitro.ts
index 51fe724..7abd820 100644
--- a/src/specs/Cactus.nitro.ts
+++ b/src/specs/Cactus.nitro.ts
@@ -13,7 +13,16 @@ export interface Cactus extends HybridObject<{ ios: 'c++'; android: 'c++' }> {
toolsJson?: string,
callback?: (token: string, tokenId: number) => void
): Promise;
+ transcribe(
+ audioFilePath: string,
+ prompt: string,
+ responseBufferSize: number,
+ optionsJson?: string,
+ callback?: (token: string, tokenId: number) => void
+ ): Promise;
embed(text: string, embeddingBufferSize: number): Promise;
+ imageEmbed(imagePath: string, embeddingBufferSize: number): Promise;
+ audioEmbed(audioPath: string, embeddingBufferSize: number): Promise;
reset(): Promise;
stop(): Promise;
destroy(): Promise;
diff --git a/src/telemetry/Telemetry.ts b/src/telemetry/Telemetry.ts
index fa9c07c..273886c 100644
--- a/src/telemetry/Telemetry.ts
+++ b/src/telemetry/Telemetry.ts
@@ -8,6 +8,7 @@ import {
import { CactusConfig } from '../config/CactusConfig';
import { packageVersion } from '../constants/packageVersion';
import type { CactusLMCompleteResult } from '../types/CactusLM';
+import type { CactusSTTTranscribeResult } from '../types/CactusSTT';
export interface LogRecord {
// Framework
@@ -15,7 +16,13 @@ export interface LogRecord {
framework_version: string;
// Event
- event_type: 'init' | 'completion' | 'embedding';
+ event_type:
+ | 'init'
+ | 'completion'
+ | 'transcription'
+ | 'embedding'
+ | 'image_embedding'
+ | 'audio_embedding';
model: string;
success: boolean;
message?: string;
@@ -41,7 +48,10 @@ export class Telemetry {
private static readonly logBufferPaths = {
init: 'logs/init.json',
completion: 'logs/completion.json',
+ transcription: 'logs/transcription.json',
embedding: 'logs/embedding.json',
+ image_embedding: 'logs/image_embedding.json',
+ audio_embedding: 'logs/audio_embedding.json',
};
private static async handleLog(logRecord: LogRecord) {
@@ -147,6 +157,29 @@ export class Telemetry {
});
}
+ public static logTranscribe(
+ model: string,
+ success: boolean,
+ message?: string,
+ result?: CactusSTTTranscribeResult
+ ): Promise {
+ return this.handleLog({
+ framework: 'react-native',
+ framework_version: packageVersion,
+ event_type: 'transcription',
+ model,
+ success,
+ message,
+ telemetry_token: this.cactusTelemetryToken,
+ project_id: this.projectId,
+ device_id: this.deviceId,
+ tokens: result?.totalTokens,
+ response_time: result?.totalTimeMs,
+ ttft: result?.timeToFirstTokenMs,
+ tps: result?.tokensPerSecond,
+ });
+ }
+
public static logEmbedding(
model: string,
success: boolean,
@@ -164,4 +197,40 @@ export class Telemetry {
device_id: this.deviceId,
});
}
+
+ public static logImageEmbedding(
+ model: string,
+ success: boolean,
+ message?: string
+ ): Promise {
+ return this.handleLog({
+ framework: 'react-native',
+ framework_version: packageVersion,
+ event_type: 'image_embedding',
+ model,
+ success,
+ message,
+ telemetry_token: this.cactusTelemetryToken,
+ project_id: this.projectId,
+ device_id: this.deviceId,
+ });
+ }
+
+ public static logAudioEmbedding(
+ model: string,
+ success: boolean,
+ message?: string
+ ): Promise {
+ return this.handleLog({
+ framework: 'react-native',
+ framework_version: packageVersion,
+ event_type: 'audio_embedding',
+ model,
+ success,
+ message,
+ telemetry_token: this.cactusTelemetryToken,
+ project_id: this.projectId,
+ device_id: this.deviceId,
+ });
+ }
}
diff --git a/src/types/CactusLM.ts b/src/types/CactusLM.ts
index 3a11ca7..bb4eab5 100644
--- a/src/types/CactusLM.ts
+++ b/src/types/CactusLM.ts
@@ -14,7 +14,7 @@ export interface Message {
images?: string[];
}
-export interface Options {
+export interface CompleteOptions {
temperature?: number;
topP?: number;
topK?: number;
@@ -39,7 +39,7 @@ export interface Tool {
export interface CactusLMCompleteParams {
messages: Message[];
- options?: Options;
+ options?: CompleteOptions;
tools?: Tool[];
onToken?: (token: string) => void;
mode?: 'local' | 'hybrid';
@@ -68,6 +68,10 @@ export interface CactusLMEmbedResult {
embedding: number[];
}
-export interface CactusLMGetModelsParams {
- forceRefresh?: boolean;
+export interface CactusLMImageEmbedParams {
+ imagePath: string;
+}
+
+export interface CactusLMImageEmbedResult {
+ embedding: number[];
}
diff --git a/src/types/CactusSTT.ts b/src/types/CactusSTT.ts
new file mode 100644
index 0000000..8ba12d9
--- /dev/null
+++ b/src/types/CactusSTT.ts
@@ -0,0 +1,42 @@
+export interface CactusSTTParams {
+ model?: string;
+ contextSize?: number;
+}
+
+export interface CactusSTTDownloadParams {
+ onProgress?: (progress: number) => void;
+}
+
+export interface TranscribeOptions {
+ temperature?: number;
+ topP?: number;
+ topK?: number;
+ maxTokens?: number;
+ stopSequences?: string[];
+}
+
+export interface CactusSTTTranscribeParams {
+ audioFilePath: string;
+ prompt?: string;
+ options?: TranscribeOptions;
+ onToken?: (token: string) => void;
+}
+
+export interface CactusSTTTranscribeResult {
+ success: boolean;
+ response: string;
+ timeToFirstTokenMs: number;
+ totalTimeMs: number;
+ tokensPerSecond: number;
+ prefillTokens: number;
+ decodeTokens: number;
+ totalTokens: number;
+}
+
+export interface CactusSTTAudioEmbedParams {
+ audioPath: string;
+}
+
+export interface CactusSTTAudioEmbedResult {
+ embedding: number[];
+}
diff --git a/yarn.lock b/yarn.lock
index 6ab6675..7c91d4c 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -2923,6 +2923,16 @@ __metadata:
languageName: node
linkType: hard
+"@react-native-documents/picker@npm:^11.0.0":
+ version: 11.0.0
+ resolution: "@react-native-documents/picker@npm:11.0.0"
+ peerDependencies:
+ react: "*"
+ react-native: ">=0.79.0"
+ checksum: 411a9406f92724f9af618f99cf4948f14e011d505810a3fe5d67dd7ace39f218a08f49e9f8271be7bd27a957c6c942e472cdf0a27d67462f56af2ed5e51148e9
+ languageName: node
+ linkType: hard
+
"@react-native/assets-registry@npm:0.81.1":
version: 0.81.1
resolution: "@react-native/assets-registry@npm:0.81.1"
@@ -4374,6 +4384,7 @@ __metadata:
"@react-native-community/cli": 20.0.0
"@react-native-community/cli-platform-android": 20.0.0
"@react-native-community/cli-platform-ios": 20.0.0
+ "@react-native-documents/picker": ^11.0.0
"@react-native/babel-preset": 0.81.1
"@react-native/metro-config": 0.81.1
"@react-native/typescript-config": 0.81.1