This repository was archived by the owner on May 11, 2023. It is now read-only.

dev: Train test split for a dataset #5

@daniel-dodd

Description


It would be nice to have a train/test split for `Dataset`, akin to scikit-learn's `train_test_split`.

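from typing import Tuple

import jax.random as jr
from jax.random import KeyArray  # no longer available in recent JAX releases; annotate with jax.Array there
from jaxutils import Dataset

# Need to define this function
def train_test_split(
    data: Dataset, key: KeyArray, test_size: float,  # ... (further arguments unspecified)
) -> Tuple[Dataset, Dataset]:
    ...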

# Example usage:
data = Dataset(...)
key = jr.PRNGKey(42)
size = 0.3
train, test = train_test_split(data, key, size)
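A minimal sketch of one possible implementation, assuming a dataset exposes inputs `X` and optional outputs `y` with data points along the first axis. To keep it self-contained, the `Dataset` dataclass below is a hypothetical stand-in for `jaxutils.Dataset`, not its real definition, and the validation and rounding rules are illustrative choices rather than anything agreed in this issue.

```python
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import jax.numpy as jnp
import jax.random as jr


@dataclass
class Dataset:
    """Hypothetical stand-in for jaxutils.Dataset: inputs X (n, d), optional outputs y (n, p)."""

    X: jnp.ndarray
    y: Optional[jnp.ndarray] = None

    @property
    def n(self) -> int:
        # Number of data points (rows of X).
        return self.X.shape[0]


def train_test_split(
    data: Dataset, key: jnp.ndarray, test_size: float = 0.25
) -> Tuple[Dataset, Dataset]:
    """Randomly partition `data` into (train, test) using a JAX PRNG key."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction strictly between 0 and 1")

    n = data.n
    # Number of test points, rounded to the nearest integer (an illustrative
    # choice; scikit-learn's own rounding convention may differ).
    n_test = int(round(test_size * n))
    if n_test == 0 or n_test == n:
        raise ValueError("test_size leaves the train or test split empty")

    # A random permutation of row indices, fully determined by `key`.
    perm = jr.permutation(key, n)
    test_idx, train_idx = perm[:n_test], perm[n_test:]

    def subset(idx: jnp.ndarray) -> Dataset:
        # Index X and y with the same rows so input/output pairs stay aligned.
        y = None if data.y is None else data.y[idx]
        return Dataset(X=data.X[idx], y=y)

    return subset(train_idx), subset(test_idx)
```

Because the split is a pure function of `key`, calling it again with the same key reproduces the same partition; splitting the key with `jr.split` gives a fresh, independent partition.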
