This repository has been archived by the owner on May 11, 2023. It is now read-only.

feat: Scaler for the Dataset. #4

Open
daniel-dodd opened this issue Dec 20, 2022 · 2 comments
Labels
enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@daniel-dodd
Member

It would be nice to have a Scaler object that scales the inputs and/or outputs of a jaxutils.Dataset and saves the mean and variance, so that test inputs can be scaled later.

import jaxutils
from jaxutils import PyTree

class Scaler(PyTree):
  ...

# Calling the scaler scales the data; the first call also "fits the scale transform".

train = jaxutils.Dataset(X=..., y=...)
test = jaxutils.Dataset(X=..., y=...)

scaler = Scaler(...)
scaled_train = scaler(train) # learn the transform from the train data and apply it
scaled_test = scaler(test) # scale the test data under the transform learnt from the train data
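
To make the proposal concrete, here is a minimal sketch of how such a scaler could behave. It is not an implementation from jaxutils: the `Dataset` dataclass is a hypothetical stand-in for `jaxutils.Dataset` (only the `X` and `y` attributes are assumed), NumPy is used in place of JAX to keep the example dependency-light, and the "fit on first call" behaviour is one possible reading of the comments above.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Dataset:
    """Hypothetical stand-in for jaxutils.Dataset: just X and y arrays."""
    X: np.ndarray
    y: np.ndarray


class Scaler:
    """Standardises X and y with the mean and variance of the first dataset seen."""

    def __init__(self):
        # Statistics are unset until the first call "fits" the transform.
        self.x_mean = self.x_var = None
        self.y_mean = self.y_var = None

    def __call__(self, data: Dataset) -> Dataset:
        if self.x_mean is None:
            # First call: learn and save the per-column mean and variance.
            self.x_mean, self.x_var = data.X.mean(axis=0), data.X.var(axis=0)
            self.y_mean, self.y_var = data.y.mean(axis=0), data.y.var(axis=0)
        # Every call (including the first) applies the saved statistics.
        # Assumes no column is constant, i.e. no zero variance.
        return Dataset(
            X=(data.X - self.x_mean) / np.sqrt(self.x_var),
            y=(data.y - self.y_mean) / np.sqrt(self.y_var),
        )


train = Dataset(X=np.array([[1.0], [2.0], [3.0]]), y=np.array([[10.0], [20.0], [30.0]]))
test = Dataset(X=np.array([[4.0]]), y=np.array([[40.0]]))

scaler = Scaler()
scaled_train = scaler(train)  # learns the transform from the train data
scaled_test = scaler(test)    # reuses the train statistics on the test data
```

Fitting implicitly on the first call keeps the call syntax from the proposal, but a separate, explicit fit step would arguably be less surprising; that choice is left open here.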
daniel-dodd added the enhancement and good first issue labels on Dec 20, 2022
@st--
st-- commented Feb 23, 2023

Instead of recoding from scratch, how about interfacing with sklearn's preprocessing tools? That way all of them would become available in one go. Could be simply a wrapper that unbundles the Dataset X/y attributes?
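
A rough sketch of what such a wrapper might look like, assuming a dataset object that exposes 2-D `X` and `y` arrays. The `Dataset` dataclass and the `DatasetTransformer` name are hypothetical and not part of jaxutils; only `StandardScaler`, `MinMaxScaler`, and their `fit_transform`/`transform` methods come from scikit-learn.

```python
from dataclasses import dataclass

import numpy as np
from sklearn.preprocessing import MinMaxScaler, StandardScaler


@dataclass
class Dataset:
    """Hypothetical stand-in for jaxutils.Dataset: just X and y arrays."""
    X: np.ndarray
    y: np.ndarray


class DatasetTransformer:
    """Hypothetical wrapper: unbundles X/y and delegates to sklearn transformers."""

    def __init__(self, x_transformer, y_transformer):
        # Any objects with sklearn's fit_transform/transform interface will do.
        self.x_transformer = x_transformer
        self.y_transformer = y_transformer

    def fit_transform(self, data: Dataset) -> Dataset:
        # Learn each transform from the data and apply it in one step.
        return Dataset(
            X=self.x_transformer.fit_transform(data.X),
            y=self.y_transformer.fit_transform(data.y),
        )

    def transform(self, data: Dataset) -> Dataset:
        # Apply the already-fitted transforms without refitting.
        return Dataset(
            X=self.x_transformer.transform(data.X),
            y=self.y_transformer.transform(data.y),
        )


train = Dataset(X=np.array([[1.0], [2.0], [3.0]]), y=np.array([[10.0], [20.0], [30.0]]))
test = Dataset(X=np.array([[4.0]]), y=np.array([[40.0]]))

# Different preprocessing tools can be mixed freely for inputs and outputs.
wrapper = DatasetTransformer(StandardScaler(), MinMaxScaler())
wrapped_train = wrapper.fit_transform(train)
wrapped_test = wrapper.transform(test)
```

One caveat with this route: scikit-learn transformers work on NumPy arrays and are not JAX PyTrees, so a wrapper like this would likely sit outside any jitted code.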

@daniel-dodd
Member Author

Thanks @st--, this is a nice suggestion. I agree, and would love to see this functionality. :)
