This repository was archived by the owner on May 11, 2023. It is now read-only.

feat: Scaler for the Dataset. #4

@daniel-dodd


It would be nice to have a Scaler object that scales the inputs and/or outputs of a jaxutils.Dataset and saves the mean and variance, so that test inputs can be scaled later.

import jaxutils
from jaxutils import PyTree

class Scaler(PyTree):
  ...

# The call method scales the data and "fits the scale transform".

train = jaxutils.Dataset(X=..., y=...)
test = jaxutils.Dataset(X=..., y=...)

scaler = Scaler(...)
scaled_train = scaler(train)  # learn the transform from the training data
scaled_test = scaler(test)  # scale the test data under the transform learnt from the training data
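
To make the request concrete, here is a minimal sketch of how such a scaler could behave. It is not the jaxutils API: the `Dataset` and `Scaler` classes below are hypothetical stand-ins built on NumPy and plain dataclasses, and the explicit `Scaler.fit` classmethod is an assumption made for illustration (the snippet above instead fits on the first call). Constant columns, which have zero variance, are not handled.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class Dataset:
    """Hypothetical stand-in for jaxutils.Dataset: inputs X and outputs y."""
    X: np.ndarray
    y: np.ndarray


@dataclass(frozen=True)
class Scaler:
    """Hypothetical scaler that saves the column-wise mean and variance."""
    X_mean: np.ndarray
    X_var: np.ndarray
    y_mean: np.ndarray
    y_var: np.ndarray

    @classmethod
    def fit(cls, data: Dataset) -> "Scaler":
        # Learn the transform from the given (training) data.
        return cls(
            X_mean=data.X.mean(axis=0),
            X_var=data.X.var(axis=0),
            y_mean=data.y.mean(axis=0),
            y_var=data.y.var(axis=0),
        )

    def __call__(self, data: Dataset) -> Dataset:
        # Apply the saved statistics; nothing is re-estimated here.
        # Assumes every column has non-zero variance.
        return Dataset(
            X=(data.X - self.X_mean) / np.sqrt(self.X_var),
            y=(data.y - self.y_mean) / np.sqrt(self.y_var),
        )


train = Dataset(X=np.array([[0.0], [2.0], [4.0]]), y=np.array([[1.0], [3.0], [5.0]]))
test = Dataset(X=np.array([[2.0], [6.0]]), y=np.array([[3.0], [7.0]]))

scaler = Scaler.fit(train)    # learn the transform from the training data
scaled_train = scaler(train)  # zero mean, unit variance per column
scaled_test = scaler(test)    # scaled with the training statistics
```

Keeping the fitted statistics in an immutable object means the same `scaler` can be reused on any later dataset without risk of refitting by accident. Whether fitting should instead happen implicitly on the first call, as proposed above, is a design choice this sketch leaves open.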
