-
Notifications
You must be signed in to change notification settings - Fork 4
Customize Engine
marsggbo edited this page Apr 18, 2022
·
1 revision
`engine`的调用入口
hyperbox
下有一个engine
的模块,它提供了用户一个自定义流程的接口,调用的入口在train.py
里,核心代码如下
# hyperbox/train.py
def train(config):
...
if config.get("only_test"):
# Only test the model
...
elif config.get('engine') is not None and len(config.get('engine')) > 0:
# customized engine
engine = hydra.utils.instantiate(config.engine,
trainer=trainer, model=model, datamodule=datamodule, cfg=config, _recursive_=False)
log.info(f"Running customized engine: {engine.__class__}")
result = engine.run()
else:
# Train the model
...
可以看到默认情况下我们的流程支持
-
only_test
: 即如果你设置了这个参数,那么就只会运行test -
engine
参数默认为空字典,如果你没有设置,就会执行trainer.fit
- 第三种情况就是我们指定了自定义的engine,就表示按照我们自己的的流程执行代码。可以看到首先代码会拿到
config.engine
定义的参数,之后也会把trainer, model等已经实例化的对象作为参数,通过调用hydra.utils.instantiate
来实例化自定义的engine类。那么如何创建自定义engine呢?
自定义engine
- 首先你需要按照hyperbox_app里的步骤创建你自己的新项目。假设你的新项目的目录结构如下
hyperbox_app
|_my_app
|_ __init__.py
|_ configs (该目录需要与`hyperbox`的`config`目录保持一致)
|_ experiment
|_ my_exp.yaml
|_ engine
|_ my_engine.yaml
|_ my_engine.py
-
my_engine.py
示例如下
from omegaconf import DictConfig
from hyperbox.engine.base_engine import BaseEngine
class MyEngine(BaseEngine):
def __init__(
self,
trainer: "pl.trainer",
model: "pl.lightning_module",
datamodule: "pl.datamodule",
cfg: DictConfig,
new_arg1,
new_arg2
):
super().__init__(trainer, model, datamodule, cfg)
def run(self):
print('this is new engine')
result = ... # should be a dict
return result
-
my_engine.yaml
示例如下
_target_: hyperbox_app.my_app.my_engine.MyEngine
new_arg1: 16
new_arg2: 'test'
- 运行代码
假设你这个项目的绝对路径是 /abs/to/hyperbox_app/my_app
,运行命令如下:
python -m hyperbox.run hydra.searchpath=[file:///abs/to/hyperbox_app/my_app/configs] experiment=my_exp engine=my_engine