Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
import argparse
import json
import re

from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator, model_validator
from typing import List, Optional

from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator, model_validator

IMAGE_FORMATS = ["jpeg", "png", "gif", "webp"]
VIDEO_FORMATS = ["mov", "mkv", "mp4", "webm"]
MAX_NUM_IMAGES = 10
MODEL_TO_NUM_SAMPLES_MAP = {"micro": (8, 20000), "lite": (8, 20000), "pro": (8, 20000)}

INVALID_TOKENS_TEXT = [
"System:",
"SYSTEM:",
"User:",
"USER:",
"Bot:",
"BOT:",
"Assistant:",
"ASSISTANT:",
"Thought:",
"[EOS]",
"<image>",
"<video>",
]


class ConverseRoles:
"""Defines the possible roles in a conversation according to converse format"""
Expand All @@ -29,10 +43,13 @@ class NovaClientError(ValueError):
def __init__(self, message):
super().__init__(message)


class NovaInternalError(Exception):
"""Base exception for Nova Fine Tuning validation errors"""

pass


def check_jsonl_file(file_path):
"""Validates that the input file has a .jsonl extension."""
if not file_path.endswith(".jsonl"):
Expand Down Expand Up @@ -67,7 +84,7 @@ class S3Location(BaseModel):
def validate_format(cls, uri):
"""Validates that the URI starts with 's3://'."""
if not uri.startswith("s3://"):
raise ValueError(f"Invalid S3 URI, must start with 's3://'")
raise ValueError("Invalid S3 URI, must start with 's3://'")
is_valid_path(uri.replace("s3://", ""))
return uri

Expand Down Expand Up @@ -122,6 +139,14 @@ def validate_model_fields(cls, values):
)
return values

@field_validator("text")
def validate_text(cls, text: str):
if not text:
return text

validate_invalid_tokens(text)
return text


class Message(BaseModel):
"""Represents a conversation message with role and content."""
Expand Down Expand Up @@ -207,6 +232,14 @@ class SystemMessage(BaseModel):

text: str

@field_validator("text")
def validate_text(cls, text: str):
if not text:
return text

validate_invalid_tokens(text)
return text


class ConverseDatasetSample(BaseModel):
"""Represents a complete conversation sample with system message and message turns."""
Expand All @@ -224,41 +257,68 @@ def validate_data_sample_rules(cls, messages):

def validate_converse_dataset(args):
"""Validates the entire conversation dataset against Nova format requirements."""
samples = load_jsonl_data(args.input_file)
num_samples = len(samples)
validate_data_record_bounds(num_samples, args.model_name)
try:
samples = load_jsonl_data(args.input_file)
num_samples = len(samples)
print(f"Loaded {num_samples} samples from {args.input_file}")
validate_data_record_bounds(num_samples, args.model_name)
except Exception as e:
print(f"Error loading or validating file bounds: {e}")
raise

error_message = ""
failed_samples_id_list = []

print(f"Validating samples for model: {args.model_name}")
for i, sample in enumerate(samples):
try:
ConverseDatasetSample.model_validate(sample, context={"model_name": args.model_name})
except ValidationError as e:
failed_samples_id_list.append(i)
error_message += f"Sample {i} - "
error_message += f"\nSample {i}:\n"
for err in e.errors():
err["msg"] = err["msg"].replace("Value error, ", "")
sample_error_message = f"{err['loc']}: {err['msg']} (type={err['type']}). "
sample_error_message = (
f" - Location {err['loc']}: {err['msg']} (type={err['type']})\n"
)
error_message += sample_error_message
except Exception as e:
raise NovaInternalError(f"Error occured: {e}")
raise NovaInternalError(f"Unexpected error occurred in sample {i}: {e}")

if error_message:
prefix_str = f"Problematic samples: "

if len(failed_samples_id_list) > 3:
first_sample_id = failed_samples_id_list[0]
second_sample_id = failed_samples_id_list[1]
last_sample_id = failed_samples_id_list[-1]
failed_samples_str = f"[{first_sample_id}, {second_sample_id}, ...{last_sample_id}]. "
failed_samples_str = f"[{first_sample_id}, {second_sample_id}, ...{last_sample_id}]"
else:
failed_samples_str = f"{failed_samples_id_list}. "
failed_samples_str = f"{failed_samples_id_list}"

final_err_msg = prefix_str + failed_samples_str + error_message
final_err_msg = (
f"Validation failed for samples: {failed_samples_str}\n\n"
f"Note: Sample IDs are zero-indexed.\n"
f"{error_message}"
)
raise NovaClientError(final_err_msg)
else:
print("Validation successful, all samples passed")
print("Validation successful, all samples passed!")


def validate_invalid_tokens(text: str):
"""Validates that the input text does not contain any disallowed tokens"""

stripped_text = text.strip()
client_invalid_tokens = []
for invalid_token in INVALID_TOKENS_TEXT:
if invalid_token in stripped_text:
client_invalid_tokens.append(f"`{invalid_token}`")

if client_invalid_tokens:
client_invalid_tokens_str = ", ".join(client_invalid_tokens)
raise ValueError(
f"Invalid text content, following tokens are invalid: {client_invalid_tokens_str}. Please check documentation for other invalid tokens"
)


def check_roles_order(messages):
Expand Down Expand Up @@ -289,7 +349,7 @@ def is_valid_path(file_path):
pattern = r"^[\w\-/\.]+$"
if not re.match(pattern, file_path):
raise ValueError(
f"Invalid characters in 'uri'. Only alphanumeric, underscores, hyphens, slashes, and dots are allowed"
"Invalid characters in 'uri'. Only alphanumeric, underscores, hyphens, slashes, and dots are allowed"
)


Expand Down