Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ from open_autonlu.methods.data_types import SaveFormat
pipeline = TextClassificationTrainingPipeline(
train_path="train.csv",
test_path="test.csv",
config_overrides={"language": "en"} # "en" or "ru"
config_overrides={"language": "en"} # for non-en/ru also set "model_name_or_path"
)
result = pipeline.train()
pipeline.save("./model", SaveFormat.ONNX)
Expand All @@ -82,7 +82,7 @@ pipeline.save("./model", SaveFormat.ONNX)
pipeline = TokenClassificationTrainingPipeline(
train_path="train.json",
test_path="test.json",
config_overrides={"language": "en"} # "en" or "ru"
config_overrides={"language": "en"} # for non-en/ru also set "model_name_or_path"
)
result = pipeline.train()
pipeline.save("./model", SaveFormat.ONNX)
Expand Down Expand Up @@ -140,7 +140,7 @@ from open_autonlu.methods.data_types import OodMethod, SaveFormat
pipeline = TextClassificationTrainingPipeline(
train_path="train.csv",
config_overrides={
"language": "en", # Prompt language for LLM pipelines ("en" or "ru")
"language": "en", # for non-en/ru also set "model_name_or_path"
"ood_method": OodMethod.LOGIT, # OOD detection method
"batch_size": 32, # Batch size
}
Expand Down Expand Up @@ -175,7 +175,7 @@ config_overrides = {

### LLM Data Augmentation

Automatically augment underrepresented classes using LLM generation. The `language` parameter controls which prompts are sent to the LLM (`"en"` for English, `"ru"` for Russian).
Automatically augment underrepresented classes using LLM generation. The `language` parameter controls which prompts are sent to the LLM (`"en"` for English, `"ru"` for Russian). For other languages, English prompts are used with an instruction to generate text in the language of the provided examples.

```python
import os
Expand Down Expand Up @@ -259,6 +259,40 @@ config_overrides = {
}
```

## Multilingual Support

The pipeline has been tested on **English (en), Russian (ru), French (fr), Chinese (zh), Arabic (ar), and Hindi (hi)**. Correct tokenization and NER behavior are guaranteed for these languages. Other languages are also supported but have not been explicitly validated.

### Model selection for non-default languages

Default models are available only for English (`bert-base-uncased`) and Russian (`ai-forever/ruBert-base`). For any other language you **must** set `model_name_or_path` in `config_overrides`:

```python
pipeline = TextClassificationTrainingPipeline(
train_path="train.csv",
config_overrides={
"language": "fr",
"model_name_or_path": "MODEL_NAME",
}
)
```

Any HuggingFace checkpoint that supports your target language can be used.

### AncSetFit template

When the pipeline selects AncSetFit (2-5 examples per class), it prepends a `template` string to each `anc_label` to form anchor sentences. Default templates exist only for English and Russian. For other languages a custom `template` **must** be provided, otherwise the pipeline will raise an error. Even for English/Russian, setting a domain-specific template is recommended for best results:

```python
config_overrides={
"language": "fr",
"model_name_or_path": "camembert-base",
"AncSetFitMethod": {
"template": "User asks the bot to perform a request using the skill: ", # write in your target language
}
}
```

## Data Formats

### Text Classification (CSV)
Expand Down
3 changes: 1 addition & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,13 +763,12 @@ def apply_dq_filter_ner():
or "labels" not in train_ds.column_names
):
return
language = st.session_state.get("language", "en")
records = []
for row in train_ds:
text = row["text"]
tokens = row["tokens"]
labels = row["labels"]
spans = convert_bio_to_spans(text, tokens, labels, language=language)
spans = convert_bio_to_spans(text, tokens, labels)
records.append({"text": text, "spans": spans})
elif indices and st.session_state.ner_train_data:
records = [
Expand Down
48 changes: 20 additions & 28 deletions open_autonlu/data/ner_data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,28 @@ def _load_data(self) -> DatasetDict:
if path := getattr(self, f"{split_name}_path"):
with open(path) as f:
records = json.load(f)
data_set = []
if "text" in records[0]:
if "spans" in records[0]:
for record in records:
tokens, bio_tags = convert_offsets_to_bio(
record["text"],
record["spans"],
language=self.language,
)
data_set.append(
{
"text": record["text"],
"tokens": tokens,
"labels": bio_tags,
}
)
else:
for record in records:
tokens, bio_tags = self.from_brackets(record["text"])
data_set.append(
{
"text": " ".join(tokens),
"tokens": tokens,
"labels": bio_tags,
}
)

splits[split_name] = Dataset.from_list(data_set)
splits[split_name] = Dataset.from_list(
[self._parse_record(r) for r in records]
)
return DatasetDict(splits)

def _parse_record(self, record: dict) -> dict:
if "tokens" in record and "labels" in record:
tokens, bio_tags = record["tokens"], record["labels"]
text = record.get("text", " ".join(tokens))
elif "text" in record and "spans" in record:
tokens, bio_tags = convert_offsets_to_bio(record["text"], record["spans"])
text = record["text"]
elif "text" in record:
tokens, bio_tags = self.from_brackets(record["text"])
text = " ".join(tokens)
else:
raise ValueError(
"Record must contain either 'tokens'+'labels', "
"'text'+'spans', or bracket format."
)
return {"text": text, "tokens": tokens, "labels": bio_tags}

@staticmethod
def from_brackets(text: str) -> List[str]:
"""
Expand Down
102 changes: 74 additions & 28 deletions open_autonlu/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
import re
import unicodedata
from typing import NamedTuple

import spacy
from datasets import Dataset

SUPPORTED_LANGUAGES = ("ru", "en")
_CJK_RANGES = (
(0x4E00, 0x9FFF),
(0x3400, 0x4DBF),
(0x20000, 0x2A6DF),
(0x2A700, 0x2B73F),
(0x2B740, 0x2B81F),
(0x2B820, 0x2CEAF),
(0xF900, 0xFAFF),
(0x2F800, 0x2FA1F),
)

_spacy_models: dict[str, spacy.language.Language] = {}
_TOKEN_RE = re.compile(r"\w+(?:-\w+)*|\S", re.UNICODE)


def _get_spacy_model(language: str) -> spacy.language.Language:
if language not in _spacy_models:
_spacy_models[language] = spacy.blank(language)
return _spacy_models[language]
def _is_cjk(char: str) -> bool:
    """Return True when *char* lies inside one of the known CJK codepoint ranges."""
    code_point = ord(char)
    for low, high in _CJK_RANGES:
        if low <= code_point <= high:
            return True
    return False


def _is_mark(char: str) -> bool:
return unicodedata.category(char)[0] == "M"


class Token(NamedTuple):
Expand All @@ -20,25 +33,59 @@ class Token(NamedTuple):
text: str


def tokenize_with_offsets(text: str, language: str = "en") -> list[Token]:
"""Tokenize text and return list of (start, stop, text) for each token."""
if language not in SUPPORTED_LANGUAGES:
raise ValueError(
f"Unsupported language: '{language}'. Supported: {SUPPORTED_LANGUAGES}"
)

nlp = _get_spacy_model(language)
doc = nlp(text)
return [
Token(
start=token.idx if token.idx is not None else 0,
stop=token.idx + len(token.text)
if token.idx is not None
else len(token.text),
text=token.text,
)
for token in doc
]
def tokenize_with_offsets(text: str) -> list[Token]:
    """Tokenize text into words/punctuation and return character offsets.

    Uses a Unicode-aware regex. CJK characters are split individually since they have no whitespace word boundaries.
    Combining marks are merged back into the preceding token.

    Args:
        text: Raw input string.

    Returns:
        List of ``Token(start, stop, text)`` with half-open ``[start, stop)``
        character offsets into ``text``, in document order.
    """
    # Pass 1: regex tokenization, splitting any CJK characters out of
    # mixed-script matches into single-character tokens.
    raw: list[Token] = []
    for m in _TOKEN_RE.finditer(text):
        word = m.group()
        offset = m.start()
        if len(word) > 1 and any(_is_cjk(c) for c in word):
            # Mixed match: walk char by char; buf_start tracks the start of
            # the current run of non-CJK characters (None = no open run).
            buf_start: int | None = None
            for i, ch in enumerate(word):
                if _is_cjk(ch):
                    # Flush any pending non-CJK run before the CJK char.
                    if buf_start is not None:
                        raw.append(
                            Token(
                                start=offset + buf_start,
                                stop=offset + i,
                                text=word[buf_start:i],
                            )
                        )
                        buf_start = None
                    # Each CJK character becomes its own token.
                    raw.append(Token(start=offset + i, stop=offset + i + 1, text=ch))
                else:
                    if buf_start is None:
                        buf_start = i
            # Flush a trailing non-CJK run, if any.
            if buf_start is not None:
                raw.append(
                    Token(
                        start=offset + buf_start,
                        stop=offset + len(word),
                        text=word[buf_start:],
                    )
                )
        else:
            raw.append(Token(start=offset, stop=m.end(), text=word))

    if not raw:
        return raw

    # Pass 2: re-attach combining marks (split off by the regex or pass 1)
    # to the immediately preceding token when the two are adjacent.
    merged: list[Token] = [raw[0]]
    for tok in raw[1:]:
        prev = merged[-1]
        if tok.start == prev.stop and (
            _is_mark(tok.text[0]) or _is_mark(prev.text[-1])
        ):
            # Merge by re-slicing the original text so offsets stay exact.
            merged[-1] = Token(
                start=prev.start, stop=tok.stop, text=text[prev.start : tok.stop]
            )
        else:
            merged.append(tok)
    return merged


def align_labels_with_tokens(labels, word_ids, b_to_i_label, label_all_tokens=True):
Expand Down Expand Up @@ -69,9 +116,8 @@ def convert_offsets_to_bio(
text: str,
labels: list,
label_key: str = "label",
language: str = "en",
) -> tuple[list[str], list[str]]:
substrings = tokenize_with_offsets(text, language=language)
substrings = tokenize_with_offsets(text)
tokens: list[str] = []
bio_tags: list[str] = []

Expand Down
44 changes: 41 additions & 3 deletions open_autonlu/llm_pipelines/prompts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import logging
from types import SimpleNamespace

import outlines

log = logging.getLogger(__name__)

SUPPORTED_PROMPT_LANGUAGES = frozenset({"en", "ru"})
FALLBACK_PROMPT_LANGUAGE = "en"

# ---------------------------------------------------------------------------
# English prompts
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -40,8 +46,15 @@ def generate_artificial_data(topic, texts, data_size):
You should provide {{ data_size }} texts."""


LANGUAGE_OF_EXAMPLES_INSTRUCTION = (
"Generate all texts in the same language as the provided examples."
)


@outlines.prompt
def generate_texts(topic, texts, data_size, domain_desc=None, label_desc=None):
def generate_texts(
topic, texts, data_size, domain_desc=None, label_desc=None, extra_instruction=None
):
"""### TASK
Generate EXACTLY {{ data_size }} new unique texts that belong to the category "{{ topic }}", preserve the characteristic linguistic features of this category, but contain new, original content.
{% if domain_desc %}
Expand All @@ -67,6 +80,9 @@ def generate_texts(topic, texts, data_size, domain_desc=None, label_desc=None):
5. Preserve the stylistic features and emotional tone typical of the category
6. Preserve the punctuation and formatting features characteristic of the category (for example, if punctuation is absent in the examples, this indicates that the generated texts should also follow this punctuation pattern)
7. Generate EXACTLY {{ data_size }} texts — no more and no fewer
{% if extra_instruction %}
8. {{ extra_instruction }}
{% endif %}

IMPORTANT: Your goal is not to mechanically change individual words in examples, but to create new texts that could organically fit into the corpus of texts of this category, preserving their stylistic and structural features.

Expand Down Expand Up @@ -255,11 +271,19 @@ def analyze_domain_ru(examples_by_label, label_names):
# ---------------------------------------------------------------------------


def _generate_texts_with_language_instruction(*args, **kwargs):
    """Delegate to ``generate_texts``, injecting the language-of-examples
    instruction unless the caller already supplied ``extra_instruction``."""
    if "extra_instruction" not in kwargs:
        kwargs["extra_instruction"] = LANGUAGE_OF_EXAMPLES_INSTRUCTION
    return generate_texts(*args, **kwargs)


def get_prompts(language: str = "en") -> SimpleNamespace:
"""Return a namespace of prompts for the given language.

Args:
language: Language code ("en" or "ru").
language: Language code. Only "en" and "ru" have dedicated prompts.
Any other language falls back to English prompts.
In fallback mode, the generation prompt includes an instruction to
generate texts in the same language as the provided examples.

Returns:
SimpleNamespace with attributes:
Expand All @@ -268,6 +292,15 @@ def get_prompts(language: str = "en") -> SimpleNamespace:
label_prefix, default_label_desc_template,
domain_description_header, label_descriptions_header
"""
requested_language = language
if language not in SUPPORTED_PROMPT_LANGUAGES:
log.warning(
"No prompts for language '%s'. Falling back to '%s' for data generation.",
language,
FALLBACK_PROMPT_LANGUAGE,
)
language = FALLBACK_PROMPT_LANGUAGE

if language == "ru":
return SimpleNamespace(
default_system_prompt=DEFAULT_SYSTEM_PROMPT_RU,
Expand All @@ -285,10 +318,15 @@ def get_prompts(language: str = "en") -> SimpleNamespace:
label_descriptions_header="ОПИСАНИЯ МЕТОК:",
)

generate_texts_fn = (
_generate_texts_with_language_instruction
if requested_language not in SUPPORTED_PROMPT_LANGUAGES
else generate_texts
)
return SimpleNamespace(
default_system_prompt=DEFAULT_SYSTEM_PROMPT,
analyzer_system_prompt=ANALYZER_SYSTEM_PROMPT,
generate_texts=generate_texts,
generate_texts=generate_texts_fn,
analyze_domain=analyze_domain,
generate_artificial_data=generate_artificial_data,
label_prefix="LABEL",
Expand Down
Loading