Cross-Validators (official feature or contrib?) #1384

Open
nsinenian opened this issue Oct 14, 2020 · 6 comments

@nsinenian

🚀 Feature

It would be useful to adapt the existing Engine class (with its use of decorators/callbacks) to provide an extensible framework for k-fold, stratified, etc. cross-validation. The k-fold process, for instance, always involves the same steps: creating K models and K trainers, running them, and then aggregating the results from the K runs.

A usage pattern similar to the following might be sufficiently general and useful:

cross_validator = CrossValidator()

@cross_validator.on(CrossValidator.TRAIN_FOLD)
def train_fold(model, optimizer, device, criterion):
    # Use the engine state to determine the number of folds and the fold index,
    # then generate the datasets for that fold (or have them prepared beforehand).
    # Make the per-fold trainer and evaluator, and call run() to train that fold.
    # This is essentially what you do now per the docs, except wrapped in a function.
    ...

@cross_validator.on(CrossValidator.MODEL_SETUP)
def model_setup_step():
    # Make a new model, optimizer and criterion; called by the engine at the start of each fold.
    ...
    return model, optimizer, criterion

@cross_validator.on(CrossValidator.TRAINING_COMPLETED)
def combine_fold_metrics(array_of_evaluators_across_folds):
    # Compute aggregate metrics from the folds and assign them to the CrossValidator state dict.
    ...

cross_validator.run(folds=5)

A quick look at the source code of the Engine class suggests that such a pattern could be implemented by defining new events and replacing the run and _internal_run methods. This approach would reuse existing code and provide modularity (e.g., one style of cross-validation could easily be swapped for another, different engines could still be used for training/evaluation on a per-fold basis, and users would decide how to aggregate data from different folds).
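To make the proposal concrete, here is a minimal, framework-free sketch of a CrossValidator that supports the three events from the usage pattern above. It does not subclass or reuse ignite's Engine. The event enum, the state dict, passing the CrossValidator as each handler's first argument, and dropping the device argument are all assumptions of this sketch, and the "training" is a toy stand-in.

```python
from collections import defaultdict
from enum import Enum


class CVEvents(Enum):
    # Hypothetical event names taken from the proposal; not part of ignite.
    MODEL_SETUP = "model_setup"
    TRAIN_FOLD = "train_fold"
    TRAINING_COMPLETED = "training_completed"


class CrossValidator:
    """Minimal event-driven k-fold driver (illustrative sketch only)."""

    def __init__(self):
        self._handlers = defaultdict(list)
        self.state = {}

    def on(self, event):
        # Decorator that registers a handler for the given event.
        def decorator(fn):
            self._handlers[event].append(fn)
            return fn
        return decorator

    def _fire(self, event, *args):
        # Call every handler for the event; return the last handler's result.
        result = None
        for fn in self._handlers[event]:
            result = fn(self, *args)
        return result

    def run(self, folds):
        self.state.update(num_folds=folds, fold_results=[])
        for fold_index in range(folds):
            self.state["fold_index"] = fold_index
            # A fresh model/optimizer/criterion for every fold.
            setup = self._fire(CVEvents.MODEL_SETUP) or ()
            result = self._fire(CVEvents.TRAIN_FOLD, *setup)
            self.state["fold_results"].append(result)
        self._fire(CVEvents.TRAINING_COMPLETED, self.state["fold_results"])
        return self.state


cv = CrossValidator()


@cv.on(CVEvents.MODEL_SETUP)
def model_setup(cv):
    # Toy "model"; a real handler would build an nn.Module, optimizer and loss.
    return {"weights": 0.0}, "sgd", "mse"


@cv.on(CVEvents.TRAIN_FOLD)
def train_fold(cv, model, optimizer, criterion):
    # Toy "training": the score depends only on the fold index.
    return {"accuracy": 0.8 + 0.01 * cv.state["fold_index"]}


@cv.on(CVEvents.TRAINING_COMPLETED)
def combine(cv, fold_results):
    scores = [r["accuracy"] for r in fold_results]
    cv.state["mean_accuracy"] = sum(scores) / len(scores)


state = cv.run(folds=5)
```

Because aggregation lives in a handler, swapping in a different combination rule (e.g. summing confusion matrices) does not touch the driver loop.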

@sdesrozis
Contributor

sdesrozis commented Oct 14, 2020

@nsinenian Thank you for suggesting this feature. I strongly agree with you; it would be an awesome feature!

About implementation, we should not modify the Engine class. The CrossValidator class sounds like a high-level contrib tool and should be built around existing concepts (engines, events, etc.). IMO, designing a smart API for that feature is the hard part of the job. You've provided some ideas based on decorators, but I think we need to dig deeper to make sure everything is OK.

Would you like to contribute this? It would be great! A good starting point would be a PR with an API proposal that we can discuss.

Thank you!

@vfdev-5
Collaborator

vfdev-5 commented Oct 14, 2020

@nsinenian Thanks for the feature request! It could be useful and could be added to the contrib module.
Recently I did something similar with Engine and a few methods like the ones you describe. I think CrossValidator could be a class that reuses Engine (instead of deriving from it).

I propose we start the discussion here before moving to a PR.

@nsinenian
Author

@sdesrozis @vfdev-5 I should have been more precise with my language in the original post: I was not suggesting that Engine be altered, but rather that a class be derived from Engine to provide the requisite functionality, and only if, as you say, that pattern is what is ultimately desired (otherwise there's going to be a lot of rewriting of event-driven/callback registration code).

That brings us to the next essential point: whether the decorator/events pattern is ideal, and that (as part of the API) is the hard part, as you say. My experience is limited to certain use cases, so I think getting this right will require input from a wider range of users.

@vfdev-5 I can see how it might be reused (without subclassing): register custom events, etc. But you still have the leftover standard/default events that are epoch/iteration-driven. I assume there is nothing there that would result in side effects? I will take some time to better understand the Engine source code as written. Can you provide details about what you were doing and your specific use case?

My typical usage pattern involves cross-validation with a scaled-down model/dataset at first to get the hyperparameters right. Then I re-run the code on the full dataset to collect metrics. I combine metrics from each fold (e.g., confusion matrices) to get a sense of how the model will perform (aggregate precision/recall), then I use the same parameters to train the model on all of the data. The loose proposal I made earlier falls short on that last point: the final training. I'm curious how you would use this hypothetical tool.
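One common way to do the aggregation described above is to sum the per-fold confusion matrices and compute precision/recall from the total, which weights every sample equally (unlike averaging per-fold precision). A minimal pure-Python sketch, assuming each fold yields a square confusion matrix with rows as true classes and columns as predicted classes; the fold matrices below are made-up numbers for illustration.

```python
def sum_confusion_matrices(matrices):
    """Element-wise sum of equally sized square confusion matrices."""
    n = len(matrices[0])
    total = [[0] * n for _ in range(n)]
    for cm in matrices:
        for i in range(n):
            for j in range(n):
                total[i][j] += cm[i][j]
    return total


def precision_recall(cm):
    """Per-class precision and recall; rows = true class, columns = predicted class."""
    n = len(cm)
    precision, recall = [], []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision.append(tp / predicted_k if predicted_k else 0.0)
        recall.append(tp / actual_k if actual_k else 0.0)
    return precision, recall


# Illustrative 2-class confusion matrices from three folds (made-up numbers).
fold_cms = [
    [[40, 10], [5, 45]],
    [[42, 8], [6, 44]],
    [[38, 12], [4, 46]],
]
total_cm = sum_confusion_matrices(fold_cms)
precision, recall = precision_recall(total_cm)
```

Here the summed matrix is [[120, 30], [15, 135]], so class 0 has precision 120/135 and recall 120/150.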

@vfdev-5
Collaborator

vfdev-5 commented Oct 14, 2020

@nsinenian thanks for the details !

> rather that a class be derived from engine to provide the requisite functionality, and only if, as you say, the pattern is what is ultimately desired (otherwise there's going to be a lot of rewriting of event-driven/callback registration code).

Yes, you are right. I think a small refactoring of Engine that separates the two-loop run logic (epochs and iterations) from the event-driven/callback registration code could be useful.
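One way to picture that split is a small, hypothetical EventsDriven base class that only handles event registration and firing, with the epoch/iteration loop written separately on top of it. The class name matches the one used in the CVTrainer snippet later in this thread; the method names, the built-in "STARTED"/"COMPLETED" events and the shared `state` namespace are assumptions, and the event names are borrowed from ignite's Events only for familiarity. This is not ignite code.

```python
from collections import defaultdict
from types import SimpleNamespace


class EventsDriven:
    """Event registration/firing only; no training-loop logic (hypothetical)."""

    def __init__(self):
        self._allowed_events = {"STARTED", "COMPLETED"}
        self._handlers = defaultdict(list)
        self.state = SimpleNamespace()  # "transition state" shared by handlers

    def register_events(self, *names):
        self._allowed_events.update(names)

    def on(self, name):
        if name not in self._allowed_events:
            raise ValueError(f"Unknown event: {name}")

        def decorator(fn):
            self._handlers[name].append(fn)
            return fn
        return decorator

    def fire_event(self, name, *args):
        for fn in self._handlers[name]:
            fn(self, *args)


class TwoLoopRunner(EventsDriven):
    """The epoch + iteration loop, built on top of EventsDriven."""

    def __init__(self, process_fn):
        super().__init__()
        self.process_fn = process_fn
        self.register_events("EPOCH_STARTED", "EPOCH_COMPLETED",
                             "ITERATION_STARTED", "ITERATION_COMPLETED")

    def run(self, data, max_epochs=1):
        self.state.epoch, self.state.iteration = 0, 0
        self.fire_event("STARTED")
        for _ in range(max_epochs):
            self.state.epoch += 1
            self.fire_event("EPOCH_STARTED")
            for batch in data:
                self.state.iteration += 1
                self.fire_event("ITERATION_STARTED")
                self.state.output = self.process_fn(self, batch)
                self.fire_event("ITERATION_COMPLETED")
            self.fire_event("EPOCH_COMPLETED")
        self.fire_event("COMPLETED")
        return self.state


# Usage: accumulate the output of every iteration in the shared state.
runner = TwoLoopRunner(lambda r, batch: sum(batch))
runner.state.outputs = []


@runner.on("ITERATION_COMPLETED")
def collect(r):
    r.state.outputs.append(r.state.output)


final_state = runner.run([[1, 2], [3, 4]], max_epochs=2)
```

With this split, a cross-validation driver could derive from EventsDriven alone and define a fold loop instead of the epoch/iteration loop, without inheriting events it does not need.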

> That brings us to the next essential point - whether the decoration/events pattern is ideal, and that (as part of the API) is the hard part as you say. My experience is limited to certain use cases, so I think this will require much input from a wider range of users to get right.

That's a good point! For instance, I have no strong opinion about either deriving from a base class or using the events pattern.
If we use the events pattern for customization, common logic such as the epoch loop + iteration loop with events should be defined in a general way. In addition, we should have a transition state where handlers can provide their results (like engine.state). We have to try both and see which is better...

Here is how I did it recently. It is not a perfect solution, and it is fine-tuned to my needs.

config = {...}

def get_model(config):
    ...
    return model

def initialize_solver(model, config):
    ...
    return optimizer, criterion, lr_scheduler


def train_model(config, train_transforms, model, initialize_solver, debug=False):
    fold_index = config["fold_index"]

    # ... set up dataflow, trainer, evaluator, logging, etc.

    return train_results, val_results, output_path, best_checkpoint
    

results_per_fold = []
for fold_index in range(config["num_folds"]):
    print("\n\n\nTrain model on fold {}/{}".format(fold_index + 1, config["num_folds"]))
    config["fold_index"] = fold_index
    model = get_model(config)
    train_results, val_results, output_path, best_checkpoint = train_model(
        config, train_transforms, model, initialize_solver, debug=debug
    )
    results_per_fold.append((train_results, val_results, output_path, best_checkpoint))
    print("-> output_path:", output_path)
    print("-> best_checkpoint:", best_checkpoint)

Anyway, the above code could be made a bit more generic...
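A slightly more generic version of that loop could take the per-fold pieces as callables. This is a sketch: `run_kfold`, `make_model` and `train_one_fold` are hypothetical names, and `train_one_fold` is assumed to return whatever per-fold results the caller wants to keep (a dict in the toy usage).

```python
import copy


def run_kfold(config, make_model, train_one_fold, log=print):
    """Train a fresh model on each fold and collect the per-fold results."""
    results_per_fold = []
    num_folds = config["num_folds"]
    for fold_index in range(num_folds):
        log(f"Train model on fold {fold_index + 1}/{num_folds}")
        # Copy the config so per-fold changes don't leak into the next fold.
        fold_config = copy.deepcopy(config)
        fold_config["fold_index"] = fold_index
        model = make_model(fold_config)
        results_per_fold.append(train_one_fold(fold_config, model))
    return results_per_fold


# Toy usage: the "model" is a dict and "training" just reports what it saw.
results = run_kfold(
    {"num_folds": 3, "lr": 0.1},
    make_model=lambda cfg: {"lr": cfg["lr"]},
    train_one_fold=lambda cfg, model: {"fold": cfg["fold_index"], "lr": model["lr"]},
    log=lambda msg: None,
)
```

Copying the config per fold is a design choice here: the snippet above mutates the shared `config["fold_index"]`, which works but makes it easier for state from one fold to leak into the next.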

> My typical usage pattern involves cross-validation with a scaled down model/dataset at first to get the hyperparameters right. Then I re-run the code on the full dataset to collect metrics. I combine metrics from each fold (e.g., confusion matrix) to get a sense of how the model will perform (aggregate precision/recall), then I use the same parameters to train the model using all of the data. The loose proposal I made earlier falls short on that latter point - the final training. I'm curious as to how you would use this hypothetical tool.

Interesting. Correct me if I misunderstood your approach, which is:

for fold in range(num_folds):
    small_model, small_data = get_downscaled_md(fold, other_config)
    hp = compute_hp(small_model, small_data, other_config)
    large_model, large_data = get_full_md(fold, other_config)
    results = train_model(large_model, large_data, hp, other_config)

There are probably various ways to factor that, but for the sake of transparency I'd probably keep it like that.
Otherwise, something like:

class CVState:
    # ... is subscriptable: state[fold] holds per-fold attributes such as .hp ...
    pass

# EventsDriven: hypothetical base class providing on(), register_events() and fire_event(),
# with "STARTED"/"COMPLETED" assumed to be built-in events
class CVTrainer(EventsDriven):

    def __init__(self, train_fold_fn):
        super().__init__()
        self.train_fold_fn = train_fold_fn
        self.state = CVState()
        self.state.per_fold = []
        self.register_events("TRAIN_FOLD_STARTED", "TRAIN_FOLD_COMPLETED")

    def run(self, num_folds):
        self.fire_event("STARTED")
        for fold in range(num_folds):
            self.fire_event("TRAIN_FOLD_STARTED", fold)
            # The fold function receives the trainer so it can read the shared state
            result = self.train_fold_fn(self, fold)
            self.state.per_fold.append(result)
            self.fire_event("TRAIN_FOLD_COMPLETED", fold)
        self.fire_event("COMPLETED")
        return self.state


def train_model_on_fold(trainer, fold):
    hp = trainer.state[fold].hp  # set by the TRAIN_FOLD_STARTED handler below
    large_model, large_data = get_full_md(fold, trainer.state.other_config)
    return train_model(large_model, large_data, hp, trainer.state.other_config)


cvtrainer = CVTrainer(train_model_on_fold)
cvtrainer.state.other_config = other_config


@cvtrainer.on("TRAIN_FOLD_STARTED")
def compute_hyperparams(trainer, fold):
    small_model, small_data = get_downscaled_md(fold, trainer.state.other_config)
    trainer.state[fold].hp = compute_hp(small_model, small_data, trainer.state.other_config)

cvtrainer.run(5)

@nsinenian
Author

@vfdev-5 My use case is more like this (pseudocode):

Step 1: Coarse Tune

generate small data (and small/full model if needed)
train-test split small data (1:1 or 2:1)
set hyperparameters
train model
collect metrics
if poor convergence/param issues:
    manually adjust hyperparameters and re-try
else:
    proceed to Step 2

Step 1 is meant to be quick, to get a feel for the parameter space. This step can easily be implemented with the existing Ignite framework. Now on to Step 2.

Step 2: Fine Tune

generate full data and full model
make K folds (or stratified or whatever)
for each fold:
    train each fold using parameters established in Step 1
    collect metrics for each fold
combine metrics from each fold
if performance (as determined by metrics) acceptable:
    goto Step 3
else:
    adjust parameters (currently by manual intervention, but this could pave the way for grid search, etc.)
    goto "for each fold again" step above
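The "make K folds (or stratified or whatever)" step can be done with plain index arithmetic. Below is a minimal pure-Python sketch of plain and stratified k-fold index splitting; in practice scikit-learn's KFold and StratifiedKFold cover this. The function names, the round-robin assignment and the toy labels are assumptions for illustration.

```python
import random
from collections import defaultdict


def kfold_indices(n_samples, k, seed=0):
    """Yield (train_idx, val_idx) pairs; each sample is in exactly one val fold."""
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]  # round-robin assignment to folds
    for i in range(k):
        val = sorted(folds[i])
        train = sorted(j for f in folds[:i] + folds[i + 1:] for j in f)
        yield train, val


def stratified_kfold_indices(labels, k, seed=0):
    """Like kfold_indices, but each class is spread evenly across the k folds."""
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    rng = random.Random(seed)
    folds = [[] for _ in range(k)]
    offset = 0
    for y in sorted(by_class):
        members = by_class[y]
        rng.shuffle(members)
        for pos, i in enumerate(members):
            # Continue the round-robin across classes so fold sizes stay balanced.
            folds[(offset + pos) % k].append(i)
        offset += len(members)
    all_idx = set(range(len(labels)))
    for i in range(k):
        val = sorted(folds[i])
        yield sorted(all_idx - set(val)), val


labels = [0] * 6 + [1] * 3
splits = list(stratified_kfold_indices(labels, k=3))
```

With six samples of class 0 and three of class 1, each validation fold gets two class-0 samples and one class-1 sample, so per-fold metrics are computed on the same class balance as the full dataset.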

Step 3: Train "final" model

Here, the entire dataset is used for training, and this model is the one used for predictions. No metrics are computed; the combined metrics from Step 2 are the benchmark for this model.

Step 1 is not too critical, but Step 3 is essential. I am assuming that for a use case like this, with serial steps, we would apply the Ignite framework 3 times. Steps 1 and 3 can be readily implemented using the current version of the framework, so it is really Step 2 that matters. I assume that trying to fit something this complex into a single class/instance/what have you of Ignite would require the framework to be too specialized.

Thoughts?

BTW, not so sure about Step 1, but Steps 2-3 are fairly common when doing K-fold, from what I understand (in general, not specific to Torch or Ignite).

@vfdev-5
Collaborator

vfdev-5 commented Oct 14, 2020

@nsinenian OK, I see: in step 2 you would like to perform hyperparameter tuning on K folds, and in step 3 retrain the model.

For hyperparameters tuning, we have an example of doing that with Ax: https://github.com/pytorch/ignite/blob/master/examples/notebooks/Cifar10_Ax_hyperparam_tuning.ipynb
Optuna also provides an example for that: blog and code

For the K-fold case, we have to write a proper train_model_kfolds method and provide another method to aggregate metrics. Then the run_experiment function (from the Ax notebook) can simply call those functions and return the final score. With the current API we have to build those methods manually. Let me check if we can provide a simplifying interface for that.
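A sketch of how those pieces could fit together so that a tuner only sees a single objective function. The names train_model_kfolds, aggregate_metrics and run_experiment follow the comment above, but their signatures, the `parameters` dict and the toy per-fold score are assumptions; `train_fold_fn` stands in for real per-fold training with Ignite, and a tiny grid search stands in for Ax/Optuna, whose APIs are not shown.

```python
from statistics import mean, stdev


def train_model_kfolds(parameters, num_folds, train_fold_fn):
    """Train one model per fold with the same hyperparameters; return per-fold metrics."""
    return [train_fold_fn(parameters, fold) for fold in range(num_folds)]


def aggregate_metrics(per_fold_metrics, key="accuracy"):
    """Combine per-fold metrics into a mean and a spread (sample std)."""
    values = [m[key] for m in per_fold_metrics]
    return {"mean": mean(values), "std": stdev(values) if len(values) > 1 else 0.0}


def run_experiment(parameters, train_fold_fn, num_folds=5):
    """Objective for a hyperparameter tuner: one score per parameter set."""
    per_fold = train_model_kfolds(parameters, num_folds, train_fold_fn)
    return aggregate_metrics(per_fold)["mean"]


# Toy per-fold "training" (hypothetical): the score peaks at lr=0.01
# and varies slightly with the fold index.
def toy_train_fold(parameters, fold):
    return {"accuracy": 0.9 - abs(parameters["lr"] - 0.01) * 10 + 0.001 * fold}


# A tiny grid search stands in for Ax/Optuna here.
candidates = [{"lr": lr} for lr in (0.001, 0.01, 0.1)]
scores = {c["lr"]: run_experiment(c, toy_train_fold) for c in candidates}
best_lr = max(scores, key=scores.get)
```

After picking the best parameters, Step 3 would retrain once on the full dataset with `best_lr`, and the mean cross-validation score from `run_experiment` serves as the benchmark for that final model.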
