Skip to content

Latest commit

 

History

History
86 lines (51 loc) · 1.96 KB

File metadata and controls

86 lines (51 loc) · 1.96 KB

PyTorch Lightning训练工具包

Requirements

  • numpy
  • pynvml
  • pytorch_lightning
  • torch
  • torchinfo
  • rich

Install Dependences

假设已经完成了PyTorchPytorch_lightning安装!

pip install pynvml torchinfo

convert_str_params_list

Wandb传入的str根据指定分隔符转译为指定数据类型的list

Params

  • str_params: strWandb输入的字符串。
  • params_type: str,转译后的数据类型。
  • split_marker: str,default: #,分隔符。

Return

  • list<eval('params_type')>,使用指定数据类型的数组。

set_random_seed

PyTorchPytorch_lightning设置固定随机种子。

Params

  • random_seed: float / int,指定的随机种子。

Return

  • null

predict_model_memory_usage

使用torchinfosummary函数预估在Float32格式下训练模型所需的显存占用(GB)。

Params

  • model: torch.nn.Module,指定的模型。
  • input_shape: list[...],与summary输入格式一致的模型输入数据形状。

Return

  • float,显存占用(GB)。

auto_find_memory_free_card

仅适用于单卡训练,根据模型的显存占用情况自动选择剩余显存最大的显卡,如果显存均不足,则根据设定的等待时间(秒)自动等待直到超时退出或有满足要求的显卡激活训练过程。

Params

  • card_list: list[int, ...],指定使用的显卡编号列表。
  • model_memory_usage: floatFloat32下模型的显存占用(GB)。
  • idle: bool,default: false,是否启动等待模式;若不启用,则未找到满足显存要求的显卡时直接报错退出,若启用,则按照idle_max_seconds的设置等待空余显卡直到超时。
  • idle_max_seconds: int,超时时长,仅在启用idle时生效。

Return

  • int,可供使用的显卡编号。

Made By Egg Targaryen

MIT License