Skip to content

Customize Engine

marsggbo edited this page Apr 18, 2022 · 1 revision

engine

`engine`的调用入口

hyperbox下有一个engine的模块,它提供了用户一个自定义流程的接口,调用的入口在train.py里,核心代码如下

# hyperbox/train.py
def train(config):
    ...
    if config.get("only_test"):
        # Only test the model
        ...
    elif config.get('engine') is not None and len(config.get('engine')) > 0:
        # customized engine
        engine = hydra.utils.instantiate(config.engine, 
            trainer=trainer, model=model, datamodule=datamodule, cfg=config, _recursive_=False)
        log.info(f"Running customized engine: {engine.__class__}")
        result = engine.run()
    else:
        # Train the model
        ...

可以看到默认情况下我们的流程支持

  • only_test: 即如果你设置了这个参数,那么就只会运行test
  • engine参数默认为空字典,如果你没有设置,就会执行trainer.fit
  • 第三种情况就是我们指定了自定义的engine,就表示按照我们自己的的流程执行代码。可以看到首先代码会拿到config.engine定义的参数,之后也会把trainer, model等已经实例化的对象作为参数,通过调用hydra.utils.instantiate来实例化自定义的engine类。那么如何创建自定义engine呢?
自定义engine
  1. 首先你需要按照hyperbox_app里的步骤创建你自己的新项目。假设你的新项目的目录结构如下
hyperbox_app
|_my_app
  |_ __init__.py
  |_ configs (该目录需要与`hyperbox`的`config`目录保持一致)
    |_ experiment
      |_ my_exp.yaml
    |_ engine
      |_ my_engine.yaml
  |_ my_engine.py
  1. my_engine.py示例如下
from omegaconf import DictConfig

from hyperbox.engine.base_engine import BaseEngine

class MyEngine(BaseEngine):
    def __init__(
        self,
        trainer: "pl.trainer",
        model: "pl.lightning_module",
        datamodule: "pl.datamodule",
        cfg: DictConfig,
        new_arg1,
        new_arg2
    ):
        super().__init__(trainer, model, datamodule, cfg)

    def run(self):
        print('this is new engine')
        result = ... # should be a dict
        return result
  1. my_engine.yaml示例如下
_target_: hyperbox_app.my_app.my_engine.MyEngine
new_arg1: 16
new_arg2: 'test'
  1. 运行代码

假设你这个项目的绝对路径是 /abs/to/hyperbox_app/my_app,运行命令如下:

python -m hyperbox.run hydra.searchpath=[file:///abs/to/hyperbox_app/my_app/configs] experiment=my_exp engine=my_engine
Template
Clone this wiki locally