Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Tune function and CLI command #1397

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dffml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class DuplicateName(Exception):
"train": "high_level.ml",
"predict": "high_level.ml",
"score": "high_level.ml",
"tune": "high_level.ml",
"load": "high_level.source",
"save": "high_level.source",
"run": "high_level.dataflow",
Expand Down
3 changes: 2 additions & 1 deletion dffml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from .dataflow import Dataflow
from .config import Config
from .ml import Train, Accuracy, Predict
from .ml import Train, Accuracy, Predict, Tune
from .list import List

version = VERSION
Expand Down Expand Up @@ -366,6 +366,7 @@ class CLI(CMD):
train = Train
accuracy = Accuracy
predict = Predict
tune = Tune
service = services()
dataflow = Dataflow
config = Config
58 changes: 57 additions & 1 deletion dffml/cli/ml.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import inspect

from ..model.model import Model
from ..tuner.tuner import Tuner
from ..source.source import Sources, SubsetSources
from ..util.cli.cmd import CMD, CMDOutputOverride
from ..high_level.ml import train, predict, score
from ..high_level.ml import train, predict, score, tune
from ..util.config.fields import FIELD_SOURCES
from ..util.cli.cmds import (
SourcesCMD,
Expand All @@ -15,6 +16,7 @@
)
from ..base import config, field
from ..accuracy import AccuracyScorer

from ..feature import Features


Expand Down Expand Up @@ -118,3 +120,57 @@ class Predict(CMD):

record = PredictRecord
_all = PredictAll


@config
class TuneCMDConfig:
model: Model = field("Model used for ML", required=True)
tuner: Tuner = field("Tuner to optimize hyperparameters", required=True)
scorer: AccuracyScorer = field(
"Method to use to score accuracy", required=True
)
features: Features = field("Predict Feature(s)", default=Features())
sources: Sources = FIELD_SOURCES


class Tune(MLCMD):
"""Optimize hyperparameters of model with given sources"""

CONFIG = TuneCMDConfig

async def run(self):
# Instantiate the accuracy scorer class if for some reason it is a class
# at this point rather than an instance.
if inspect.isclass(self.scorer):
self.scorer = self.scorer.withconfig(self.extra_config)
if inspect.isclass(self.tuner):
self.tuner = self.tuner.withconfig(self.extra_config)

train_source = test_source = None

# Check for tags to determine train/test sets
for source in self.sources:

if hasattr(source, "tag") and source.tag == "train":
train_source = source
if hasattr(source, "tag") and source.tag == "test":
test_source = source

if not train_source or not test_source:
# If tags not found, default to positional
if len(self.sources) >= 2:
train_source = self.sources[0]
test_source = self.sources[1]
elif not train_source:
raise NotImplementedError("Train set not found.")
else:
raise NotImplementedError("Test set not found.")

return await tune(
self.model,
self.tuner,
self.scorer,
self.features,
[train_source],
[test_source],
)
145 changes: 145 additions & 0 deletions dffml/high_level/ml.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import contextlib
from typing import Union, Dict, Any, List


from ..record import Record
from ..source.source import BaseSource
from ..feature import Feature, Features
from ..model import Model, ModelContext
from ..util.internal import records_to_sources, list_records_to_dict
from ..accuracy.accuracy import AccuracyScorer, AccuracyContext
from ..tuner import Tuner, TunerContext


async def train(model, *args: Union[BaseSource, Record, Dict[str, Any], List]):
Expand Down Expand Up @@ -293,3 +295,146 @@ async def predict(
)
if update:
await sctx.update(record)

async def tune(
model,
tuner: Union[Tuner, TunerContext],
accuracy_scorer: Union[AccuracyScorer, AccuracyContext],
features: Union[Feature, Features],
train_ds: Union[BaseSource, Record, Dict[str, Any], List],
valid_ds: Union[BaseSource, Record, Dict[str, Any], List],
) -> float:

"""
Tune the hyperparameters of a model with a given tuner.


Parameters
----------
model : Model
Machine Learning model to use. See :doc:`/plugins/dffml_model` for
models options.
tuner: Tuner
Hyperparameter tuning method to use. See :doc:`/plugins/dffml_tuner` for
tuner options.
train_ds : list
Input data for training. Could be a ``dict``, :py:class:`Record`,
filename, one of the data :doc:`/plugins/dffml_source`, or a filename
with the extension being one of the data sources.
valid_ds : list
Validation data for testing. Could be a ``dict``, :py:class:`Record`,
filename, one of the data :doc:`/plugins/dffml_source`, or a filename
with the extension being one of the data sources.


Returns
-------
float
A decimal value representing the result of the accuracy scorer on the given
test set. For instance, ClassificationAccuracy represents the percentage of correct
classifications made by the model.

Examples
--------

>>> import asyncio
>>> from dffml import *
>>>
>>> model = SLRModel(
... features=Features(
... Feature("Years", int, 1),
... ),
... predict=Feature("Salary", int, 1),
... location="tempdir",
... )
>>>
>>> async def main():
... score = await tune(
... model,
... ParameterGrid(objective="min"),
... MeanSquaredErrorAccuracy(),
... Features(
... Feature("Years", float, 1),
... ),
... [
... {"Years": 0, "Salary": 10},
... {"Years": 1, "Salary": 20},
... {"Years": 2, "Salary": 30},
... {"Years": 3, "Salary": 40}
... ],
... [
... {"Years": 6, "Salary": 70},
... {"Years": 7, "Salary": 80}
... ]
Comment on lines +359 to +368
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we want the train and test sets to be passed in as keyword arguments like this:

score = await tune(model, 
                ParameterGrid(objective="min"),
    ...         MeanSquaredErrorAccuracy(),
    ...         Features(
    ...             Feature("Years", float, 1),
    ...         ),
              train =   [
              {"Years": 0, "Salary": 10},
              {"Years": 1, "Salary": 20},
              {"Years": 2, "Salary": 30},
              {"Years": 3, "Salary": 40}
           ],
          test =  [
             {"Years": 6, "Salary": 70},
             {"Years": 7, "Salary": 80}
      ])

...
... )
... print(f"Tuner score: {score}")
...
>>> asyncio.run(main())
Tuner score: 0.0
"""

if not isinstance(features, (Feature, Features)):
raise TypeError(
f"features was {type(features)}: {features!r}. Should have been Feature or Features"
)
if isinstance(features, Feature):
features = Features(features)
if hasattr(model.config, "predict"):
if isinstance(model.config.predict, Features):
predict_feature = [
feature.name for feature in model.config.predict
]
else:
predict_feature = [model.config.predict.name]

def records_to_dict_check(ds):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's pull this out into the global scope or dffml/util/internal.py

if hasattr(model.config, "features") and any(
isinstance(td, list) for td in ds
):
return list_records_to_dict(
[feature.name for feature in model.config.features]
+ predict_feature,
*ds,
model=model,
)
return ds

train_ds = records_to_dict_check(train_ds)
valid_ds = records_to_dict_check(valid_ds)


async with contextlib.AsyncExitStack() as astack:
# Open sources
train = await astack.enter_async_context(records_to_sources(*train_ds))
test = await astack.enter_async_context(records_to_sources(*valid_ds))
# Allow for keep models open
if isinstance(model, Model):
model = await astack.enter_async_context(model)
mctx = await astack.enter_async_context(model())
elif isinstance(model, ModelContext):
mctx = model

# Allow for scorers to be kept open
if isinstance(accuracy_scorer, AccuracyScorer):
accuracy_scorer = await astack.enter_async_context(accuracy_scorer)
actx = await astack.enter_async_context(accuracy_scorer())
elif isinstance(accuracy_scorer, AccuracyContext):
actx = accuracy_scorer
else:
# TODO Replace this with static type checking and maybe dynamic
# through something like pydantic. See issue #36
raise TypeError(f"{accuracy_scorer} is not an AccuracyScorer")

if isinstance(tuner, Tuner):
tuner = await astack.enter_async_context(tuner)
tctx = await astack.enter_async_context(tuner())
elif isinstance(tuner, TunerContext):
tctx = tuner
else:
raise TypeError(f"{tuner} is not an Tuner")

return float(
await tctx.optimize(mctx, features, actx, train, test)
)

16 changes: 16 additions & 0 deletions dffml/noasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
train as high_level_train,
score as high_level_score,
predict as high_level_predict,
tune as high_level_tune,
)


Expand All @@ -24,6 +25,21 @@ def train(*args, **kwargs):
)
)

def tune(*args, **kwargs):
return asyncio.run(high_level_tune(*args, **kwargs))


tune.__doc__ = (
high_level_tune.__doc__.replace("await ", "")
.replace("async ", "")
.replace("asyncio.run(main())", "main()")
.replace(" >>> import asyncio\n", "")
.replace(
" >>> from dffml import *\n",
" >>> from dffml import *\n >>> from dffml.noasync import tune\n",
)
)


def score(*args, **kwargs):
return asyncio.run(high_level_score(*args, **kwargs))
Expand Down
1 change: 1 addition & 0 deletions dffml/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def inpath(binary):
("operations", "nlp"),
("service", "http"),
("source", "mysql"),
("tuner", "bayes_opt_gp"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets have a simpler, more understandable entrypoint

]


Expand Down
1 change: 0 additions & 1 deletion dffml/skel/config/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/config/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/model/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/model/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/operations/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/operations/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/service/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/service/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/skel/source/README.rst

This file was deleted.

1 change: 1 addition & 0 deletions dffml/skel/source/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
../common/README.rst
1 change: 0 additions & 1 deletion dffml/tuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@
TunerContext,
Tuner,
)
from .parameter_grid import ParameterGrid
Loading