diff --git a/ding/torch_utils/checkpoint_helper.py b/ding/torch_utils/checkpoint_helper.py index d51d15d07e..5d600b556e 100644 --- a/ding/torch_utils/checkpoint_helper.py +++ b/ding/torch_utils/checkpoint_helper.py @@ -11,7 +11,7 @@ def build_checkpoint_helper(cfg): - r""" + """ Overview: Use config to build checkpoint helper. Arguments: @@ -23,18 +23,18 @@ def build_checkpoint_helper(cfg): class CheckpointHelper: - r""" + """ Overview: Help to save or load checkpoint by give args. - Interface: - save, load + Interfaces: + ``__init__``, ``save``, ``load``, ``_remove_prefix``, ``_add_prefix``, ``_load_matched_model_state_dict`` """ def __init__(self): pass def _remove_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict: - r""" + """ Overview: Remove prefix in state_dict Arguments: @@ -53,7 +53,7 @@ def _remove_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict: return new_state_dict def _add_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict: - r""" + """ Overview: Add prefix in state_dict Arguments: @@ -77,7 +77,7 @@ def save( prefix_op: str = None, prefix: str = None, ) -> None: - r""" + """ Overview: Save checkpoint by given args Arguments: @@ -119,7 +119,7 @@ def save( logger.info('save checkpoint in {}'.format(path)) def _load_matched_model_state_dict(self, model: torch.nn.Module, ckpt_state_dict: dict) -> None: - r""" + """ Overview: Load matched model state_dict, and show mismatch keys between model's state_dict and checkpoint's state_dict Arguments: @@ -169,7 +169,7 @@ def load( logger_prefix: str = '', state_dict_mask: list = [], ): - r""" + """ Overview: Load checkpoint by given path Arguments: @@ -254,22 +254,36 @@ def load( class CountVar(object): - r""" + """ Overview: Number counter - Interface: - val, update, add + Interfaces: + ``__init__``, ``update``, ``add`` + Properties: + - val (:obj:`int`): the value of the counter """ def __init__(self, init_val: int) -> None: + """ + Overview: + Init the var counter + Arguments: + - init_val (:obj:`int`): the init value of the counter + """ + self._val = init_val @property def val(self) -> int: + """ + Overview: + Get the var counter + """ + return self._val def update(self, val: int) -> None: - r""" + """ Overview: Update the var counter Arguments: @@ -278,7 +292,7 @@ def update(self, val: int) -> None: self._val = val def add(self, add_num: int): - r""" + """ Overview: Add the number to counter Arguments: @@ -288,7 +302,7 @@ def add(self, add_num: int): def auto_checkpoint(func: Callable) -> Callable: - r""" + """ Overview: Create a wrapper to wrap function, and the wrapper will call the save_checkpoint method whenever an exception happens. diff --git a/ding/torch_utils/data_helper.py b/ding/torch_utils/data_helper.py index 49e107114e..b906279b1e 100644 --- a/ding/torch_utils/data_helper.py +++ b/ding/torch_utils/data_helper.py @@ -465,7 +465,7 @@ class LogDict(dict): Overview: Derived from ``dict``. Would convert ``torch.Tensor`` to ``list`` for convenient logging. Interfaces: - __setitem__, update. + ``_transform``, ``__setitem__``, ``update``. """ def _transform(self, data: Any) -> None: @@ -525,7 +525,7 @@ class CudaFetcher(object): Overview: Fetch data from source, and transfer it to a specified device. Interfaces: - __init__, run, close, __next__. + ``__init__``, ``__next__``, ``run``, ``close``. """ def __init__(self, data_source: Iterable, device: str, queue_size: int = 4, sleep: float = 0.1) -> None: @@ -577,6 +577,11 @@ def close(self) -> None: self._end_flag = True def _producer(self) -> None: + """ + Overview: + Keep fetching data from source, change the device, and put into ``queue`` for request. + """ + with torch.cuda.stream(self._stream): while not self._end_flag: if self._queue.full(): diff --git a/ding/torch_utils/dataparallel.py b/ding/torch_utils/dataparallel.py index 654ceaa7d1..f4ea14f767 100644 --- a/ding/torch_utils/dataparallel.py +++ b/ding/torch_utils/dataparallel.py @@ -3,10 +3,33 @@ class DataParallel(nn.DataParallel): + """ + Overview: + A wrapper class for nn.DataParallel. + Interfaces: + ``__init__``, ``parameters`` + """ def __init__(self, module, device_ids=None, output_device=None, dim=0): + """ + Overview: + Initialize the DataParallel object. + Arguments: + - module (:obj:`nn.Module`): The module to be parallelized. + - device_ids (:obj:`list`): The list of GPU ids. + - output_device (:obj:`int`): The output GPU id. + - dim (:obj:`int`): The dimension to be parallelized. + """ super().__init__(module, device_ids=None, output_device=None, dim=0) self.module = module def parameters(self, recurse: bool = True): + """ + Overview: + Return the parameters of the module. + Arguments: + - recurse (:obj:`bool`): Whether to return the parameters of the submodules. + Returns: + - params (:obj:`generator`): The generator of the parameters. + """ return self.module.parameters(recurse=True) diff --git a/ding/torch_utils/distribution.py b/ding/torch_utils/distribution.py index dc34fb492a..f68ef6fba0 100644 --- a/ding/torch_utils/distribution.py +++ b/ding/torch_utils/distribution.py @@ -9,11 +9,11 @@ class Pd(object): - r""" + """ Overview: Abstract class for parameterizable probability distributions and sampling functions. - Interface: - neglogp, entropy, noise_mode, mode, sample + Interfaces: + ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample`` .. tip:: @@ -21,7 +21,7 @@ class Pd(object): """ def neglogp(self, x: torch.Tensor) -> torch.Tensor: - r""" + """ Overview: Calculate cross_entropy between input x and logits Arguments: @@ -32,7 +32,7 @@ def neglogp(self, x: torch.Tensor) -> torch.Tensor: raise NotImplementedError def entropy(self) -> torch.Tensor: - r""" + """ Overview: Calculate the softmax entropy of logits Arguments: @@ -43,21 +43,21 @@ def entropy(self) -> torch.Tensor: raise NotImplementedError def noise_mode(self): - r""" + """ Overview: Add noise to logits. This method is designed for randomness """ raise NotImplementedError def mode(self): - r""" + """ Overview: Return logits argmax result. This method is designed for deterministic. """ raise NotImplementedError def sample(self): - r""" + """ Overview: Sample from logits's distribution by using softmax. This method is designed for multinomial. """ @@ -65,15 +65,15 @@ def sample(self): class CategoricalPd(Pd): - r""" + """ Overview: Catagorical probility distribution sampler - Interface: - update_logits, neglogp, entropy, noise_mode, mode, sample + Interfaces: + ``__init__``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample`` """ def __init__(self, logits: torch.Tensor = None) -> None: - r""" + """ Overview: Init the Pd with logits Arguments: @@ -82,7 +82,7 @@ def __init__(self, logits: torch.Tensor = None) -> None: self.update_logits(logits) def update_logits(self, logits: torch.Tensor) -> None: - r""" + """ Overview: Updata logits Arguments: @@ -91,7 +91,7 @@ def update_logits(self, logits: torch.Tensor) -> None: self.logits = logits def neglogp(self, x, reduction: str = 'mean') -> torch.Tensor: - r""" + """ Overview: Calculate cross_entropy between input x and logits Arguments: @@ -103,7 +103,7 @@ def neglogp(self, x, reduction: str = 'mean') -> torch.Tensor: return F.cross_entropy(self.logits, x, reduction=reduction) def entropy(self, reduction: str = 'mean') -> torch.Tensor: - r""" + """ Overview: Calculate the softmax entropy of logits Arguments: @@ -191,16 +191,22 @@ class CategoricalPdPytorch(torch.distributions.Categorical): Overview: Wrapped ``torch.distributions.Categorical`` - Interface: - update_logits, update_probs, sample, neglogp, mode, entropy + Interfaces: + ``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy`` """ def __init__(self, probs: torch.Tensor = None) -> None: + """ + Overview: + Initialize the CategoricalPdPytorch object. + Arguments: + - probs (:obj:`torch.Tensor`): The tensor of probabilities. + """ if probs is not None: self.update_probs(probs) def update_logits(self, logits: torch.Tensor) -> None: - r""" + """ Overview: Updata logits Arguments: @@ -209,7 +215,7 @@ def update_logits(self, logits: torch.Tensor) -> None: super().__init__(logits=logits) def update_probs(self, probs: torch.Tensor) -> None: - r""" + """ Overview: Updata probs Arguments: @@ -218,7 +224,7 @@ def update_probs(self, probs: torch.Tensor) -> None: super().__init__(probs=probs) def sample(self) -> torch.Tensor: - r""" + """ Overview: Sample from logits's distribution by using softmax Return: @@ -227,7 +233,7 @@ def sample(self) -> torch.Tensor: return super().sample() def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor: - r""" + """ Overview: Calculate cross_entropy between input x and logits Arguments: @@ -244,7 +250,7 @@ def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tenso return neglogp.mean(dim=0) def mode(self) -> torch.Tensor: - r""" + """ Overview: Return logits argmax result Return: @@ -253,7 +259,7 @@ def mode(self) -> torch.Tensor: return self.probs.argmax(dim=-1) def entropy(self, reduction: str = None) -> torch.Tensor: - r""" + """ Overview: Calculate the softmax entropy of logits Arguments: diff --git a/ding/torch_utils/loss/contrastive_loss.py b/ding/torch_utils/loss/contrastive_loss.py index d46d55711f..94ef62f8de 100644 --- a/ding/torch_utils/loss/contrastive_loss.py +++ b/ding/torch_utils/loss/contrastive_loss.py @@ -12,7 +12,7 @@ class ContrastiveLoss(nn.Module): The class for contrastive learning losses. Only InfoNCE loss is supported currently. \ Code Reference: https://github.com/rdevon/DIM. Paper Reference: https://arxiv.org/abs/1808.06670. Interfaces: - __init__, forward. + ``__init__``, ``forward``. """ def __init__( @@ -45,26 +45,43 @@ def __init__( self._type = loss_type.lower() self._encode_shape = encode_shape self._heads = heads - self._x_encoder = self._get_encoder(x_size, heads[0]) - self._y_encoder = self._get_encoder(y_size, heads[1]) + self._x_encoder = self._create_encoder(x_size, heads[0]) + self._y_encoder = self._create_encoder(y_size, heads[1]) self._temperature = temperature - def _get_encoder(self, obs: Union[int, SequenceType], heads: int) -> nn.Module: + def _create_encoder(self, obs_size: Union[int, SequenceType], heads: int) -> nn.Module: + """ + Overview: + Create the encoder for the input obs. + Arguments: + - obs_size (:obj:`Union[int, SequenceType]`): input shape for x, both the obs shape and the encoding shape \ + are supported. If the obs_size is an int, it means the obs is a 1D vector. If the obs_size is a list \ + such as [1, 16, 16], it means the obs is a 3D image with shape [1, 16, 16]. + - heads (:obj:`int`): The number of heads. + Returns: + - encoder (:obj:`nn.Module`): The encoder module. + Examples: + >>> obs_size = 16 + or + >>> obs_size = [1, 16, 16] + >>> heads = 1 + >>> encoder = self._create_encoder(obs_size, heads) + """ from ding.model import ConvEncoder, FCEncoder - if isinstance(obs, int): - obs = [obs] - assert len(obs) in [1, 3] + if isinstance(obs_size, int): + obs_size = [obs_size] + assert len(obs_size) in [1, 3] - if len(obs) == 1: + if len(obs_size) == 1: hidden_size_list = [128, 128, self._encode_shape * heads] - encoder = FCEncoder(obs[0], hidden_size_list) + encoder = FCEncoder(obs_size[0], hidden_size_list) else: hidden_size_list = [32, 64, 64, self._encode_shape * heads] - if obs[-1] >= 36: - encoder = ConvEncoder(obs, hidden_size_list) + if obs_size[-1] >= 36: + encoder = ConvEncoder(obs_size, hidden_size_list) else: - encoder = ConvEncoder(obs, hidden_size_list, kernel_size=[4, 3, 2], stride=[2, 1, 1]) + encoder = ConvEncoder(obs_size, hidden_size_list, kernel_size=[4, 3, 2], stride=[2, 1, 1]) return encoder def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: diff --git a/ding/torch_utils/loss/cross_entropy_loss.py b/ding/torch_utils/loss/cross_entropy_loss.py index d7ef4fda4f..3cdcb969d7 100644 --- a/ding/torch_utils/loss/cross_entropy_loss.py +++ b/ding/torch_utils/loss/cross_entropy_loss.py @@ -9,7 +9,7 @@ class LabelSmoothCELoss(nn.Module): Overview: Label smooth cross entropy loss. Interfaces: - __init__, forward. + ``__init__``, ``forward``. """ def __init__(self, ratio: float) -> None: @@ -46,7 +46,7 @@ class SoftFocalLoss(nn.Module): Overview: Soft focal loss. Interfaces: - __init__, forward. + ``__init__``, ``forward``. """ def __init__( @@ -72,7 +72,7 @@ def __init__( self.nll_loss = torch.nn.NLLLoss2d(weight, size_average, reduce=reduce) def forward(self, inputs: torch.Tensor, targets: torch.LongTensor) -> torch.Tensor: - r""" + """ Overview: Calculate soft focal loss. Arguments: @@ -89,7 +89,10 @@ def build_ce_criterion(cfg: dict) -> nn.Module: Overview: Get a cross entropy loss instance according to given config. Arguments: - - cfg (:obj:`dict`) + - cfg (:obj:`dict`) : Config dict. It contains: + - type (:obj:`str`): Type of loss function, now supports ['cross_entropy', 'label_smooth_ce', \ + 'soft_focal_loss']. + - kwargs (:obj:`dict`): Arguments for the corresponding loss function. Returns: - loss (:obj:`nn.Module`): loss function instance """ diff --git a/ding/torch_utils/loss/multi_logits_loss.py b/ding/torch_utils/loss/multi_logits_loss.py index 8f0e283e64..86716cd632 100644 --- a/ding/torch_utils/loss/multi_logits_loss.py +++ b/ding/torch_utils/loss/multi_logits_loss.py @@ -6,19 +6,12 @@ from ding.torch_utils.network import one_hot -def get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray: - nlx = np.broadcast_to(lx, [M, M]).T - nly = np.broadcast_to(ly, [M, M]) - nret = nlx + nly - mat - return nret - - class MultiLogitsLoss(nn.Module): """ Overview: Base class for supervised learning on linklink, including basic processes. - Interface: - __init__, forward. + Interfaces: + ``__init__``, ``forward``. """ def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None: @@ -38,6 +31,15 @@ def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None: self.ratio = smooth_ratio def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.LongTensor: + """ + Overview: + Process the label according to the criterion. + Arguments: + - logits (:obj:`torch.Tensor`): Predicted logits. + - labels (:obj:`torch.LongTensor`): Ground truth. + Returns: + - ret (:obj:`torch.LongTensor`): Processed label. + """ N = logits.shape[1] if self.criterion == 'cross_entropy': return one_hot(labels, num=N) @@ -48,10 +50,28 @@ def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torc return ret def _nll_loss(self, nlls: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor: + """ + Overview: + Calculate the negative log likelihood loss. + Arguments: + - nlls (:obj:`torch.Tensor`): Negative log likelihood loss. + - labels (:obj:`torch.LongTensor`): Ground truth. + Returns: + - ret (:obj:`torch.Tensor`): Calculated loss. + """ ret = (-nlls * (labels.detach())) return ret.sum(dim=1) def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor: + """ + Overview: + Calculate the metric matrix. + Arguments: + - logits (:obj:`torch.Tensor`): Predicted logits. + - labels (:obj:`torch.LongTensor`): Ground truth. + Returns: + - metric (:obj:`torch.Tensor`): Calculated metric matrix. + """ M, N = logits.shape labels = self._label_process(logits, labels) logits = F.log_softmax(logits, dim=1) @@ -63,6 +83,14 @@ def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> return torch.stack(metric, dim=0) def _match(self, matrix: torch.Tensor): + """ + Overview: + Match the metric matrix. + Arguments: + - matrix (:obj:`torch.Tensor`): Metric matrix. + Returns: + - index (:obj:`np.ndarray`): Matched index. + """ mat = matrix.clone().detach().to('cpu').numpy() mat = -mat # maximize M = mat.shape[0] @@ -88,7 +116,7 @@ def has_augmented_path(t, binary_distance_matrix): while True: visx.fill(False) visy.fill(False) - distance_matrix = get_distance_matrix(lx, ly, mat, M) + distance_matrix = self._get_distance_matrix(lx, ly, mat, M) binary_distance_matrix = np.abs(distance_matrix) < 1e-4 if has_augmented_path(i, binary_distance_matrix): break @@ -101,6 +129,22 @@ def has_augmented_path(t, binary_distance_matrix): ly[visy] += d return index + @staticmethod + def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray: + """ + Overview: + Get distance matrix. + Arguments: + - lx (:obj:`np.ndarray`): lx. + - ly (:obj:`np.ndarray`): ly. + - mat (:obj:`np.ndarray`): mat. + - M (:obj:`int`): M. + """ + nlx = np.broadcast_to(lx, [M, M]).T + nly = np.broadcast_to(ly, [M, M]) + nret = nlx + nly - mat + return nret + def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor: """ Overview: diff --git a/ding/torch_utils/lr_scheduler.py b/ding/torch_utils/lr_scheduler.py index 7c296ea180..09d341430a 100644 --- a/ding/torch_utils/lr_scheduler.py +++ b/ding/torch_utils/lr_scheduler.py @@ -6,6 +6,17 @@ def get_lr_ratio(epoch: int, warmup_epochs: int, learning_rate: float, lr_decay_epochs: int, min_lr: float) -> float: + """ + Overview: + Get learning rate ratio for each epoch. + Arguments: + - epoch (:obj:`int`): Current epoch. + - warmup_epochs (:obj:`int`): Warmup epochs. + - learning_rate (:obj:`float`): Learning rate. + - lr_decay_epochs (:obj:`int`): Learning rate decay epochs. + - min_lr (:obj:`float`): Minimum learning rate. + """ + # 1) linear warmup for warmup_epochs. if epoch < warmup_epochs: return epoch / warmup_epochs @@ -26,6 +37,17 @@ def cos_lr_scheduler( lr_decay_epochs: float = 100, min_lr: float = 6e-5 ) -> torch.optim.lr_scheduler.LambdaLR: + """ + Overview: + Cosine learning rate scheduler. + Arguments: + - optimizer (:obj:`torch.optim.Optimizer`): Optimizer. + - learning_rate (:obj:`float`): Learning rate. + - warmup_epochs (:obj:`float`): Warmup epochs. + - lr_decay_epochs (:obj:`float`): Learning rate decay epochs. + - min_lr (:obj:`float`): Minimum learning rate. + """ + return LambdaLR( optimizer, partial( diff --git a/ding/torch_utils/network/activation.py b/ding/torch_utils/network/activation.py index ed46a14905..b3c8fcda4c 100644 --- a/ding/torch_utils/network/activation.py +++ b/ding/torch_utils/network/activation.py @@ -9,6 +9,8 @@ class Lambda(nn.Module): """ Overview: A custom lambda module for constructing custom layers. + Interfaces: + ``__init__``, ``forward``. """ def __init__(self, f: Callable): @@ -21,7 +23,13 @@ def __init__(self, f: Callable): super(Lambda, self).__init__() self.f = f - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Compute the function of the input tensor. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + """ return self.f(x) @@ -31,7 +39,7 @@ class GLU(nn.Module): Gating Linear Unit (GLU), a specific type of activation function, which is first proposed in [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf). Interfaces: - ``forward``. + ``__init__``, ``forward``. """ def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None: @@ -75,6 +83,8 @@ class Swish(nn.Module): Overview: Swish activation function, which is a smooth, non-monotonic activation function. For more details, please refer to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf). + Interfaces: + ``__init__``, ``forward``. """ def __init__(self): @@ -102,10 +112,14 @@ class GELU(nn.Module): Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models like GPT, BERT. For more details, please refer to the original paper: https://arxiv.org/pdf/1606.08415.pdf. Interfaces: - ``forward`` + ``__init__``, ``forward``. """ def __init__(self): + """ + Overview: + Initialize the GELU module. + """ super(GELU, self).__init__() def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/ding/torch_utils/network/diffusion.py b/ding/torch_utils/network/diffusion.py index 8dfa9d3a14..deb95c9022 100755 --- a/ding/torch_utils/network/diffusion.py +++ b/ding/torch_utils/network/diffusion.py @@ -12,6 +12,10 @@ def extract(a, t, x_shape): """ Overview: extract output from a through index t. + Arguments: + - a (:obj:`torch.Tensor`): input tensor + - t (:obj:`torch.Tensor`): index tensor + - x_shape (:obj:`torch.Tensor`): shape of x """ b, *_ = t.shape out = a.gather(-1, t) @@ -23,6 +27,10 @@ def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32): Overview: cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + Arguments: + - timesteps (:obj:`int`): timesteps of diffusion step + - s (:obj:`float`): s + - dtype (:obj:`torch.dtype`): dtype of beta Return: Tensor of beta [timesteps,], computing by cosine. """ @@ -39,6 +47,10 @@ def apply_conditioning(x, conditions, action_dim): """ Overview: add condition into x + Arguments: + - x (:obj:`torch.Tensor`): input tensor + - conditions (:obj:`dict`): condition dict, key is timestep, value is condition + - action_dim (:obj:`int`): action dim """ for t, val in conditions.items(): x[:, t, action_dim:] = val.clone() @@ -46,6 +58,12 @@ def apply_conditioning(x, conditions, action_dim): class DiffusionConv1d(nn.Module): + """ + Overview: + Conv1d with activation and normalization for diffusion models. + Interfaces: + ``__init__``, ``forward`` + """ def __init__( self, @@ -72,10 +90,14 @@ def __init__( self.norm = nn.GroupNorm(n_groups, out_channels) self.act = activation - def forward(self, inputs): + def forward(self, inputs) -> torch.Tensor: """ Overview: compute conv1d for inputs. + Arguments: + - inputs (:obj:`torch.Tensor`): input tensor + Return: + - out (:obj:`torch.Tensor`): output tensor """ x = self.conv1(inputs) # [batch, channels, horizon] -> [batch, channels, 1, horizon] @@ -90,17 +112,32 @@ def forward(self, inputs): class SinusoidalPosEmb(nn.Module): """ Overview: - compute sin position embeding + class for computing sin position embeding + Interfaces: + ``__init__``, ``forward`` """ - def __init__( - self, - dim: int, - ) -> None: + def __init__(self, dim: int) -> None: + """ + Overview: + Initialization of SinusoidalPosEmb class + Arguments: + - dim (:obj:`int`): dimension of embeding + """ + super().__init__() self.dim = dim - def forward(self, x): + def forward(self, x) -> torch.Tensor: + """ + Overview: + compute sin position embeding + Arguments: + - x (:obj:`torch.Tensor`): input tensor + Return: + - emb (:obj:`torch.Tensor`): output tensor + """ + device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) @@ -111,27 +148,65 @@ def forward(self, x): class Residual(nn.Module): + """ + Overview: + Basic Residual block + Interfaces: + ``__init__``, ``forward`` + """ def __init__(self, fn): + """ + Overview: + Initialization of Residual class + Arguments: + - fn (:obj:`nn.Module`): function of residual block + """ + super().__init__() self.fn = fn def forward(self, x, *arg, **kwargs): + """ + Overview: + compute residual block + Arguments: + - x (:obj:`torch.Tensor`): input tensor + """ + return self.fn(x, *arg, **kwargs) + x class LayerNorm(nn.Module): """ - Overview: LayerNorm, compute dim = 1, because Temporal input x [batch, dim, horizon] + Overview: + LayerNorm, compute dim = 1, because Temporal input x [batch, dim, horizon] + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, dim, eps=1e-5) -> None: + """ + Overview: + Initialization of LayerNorm class + Arguments: + - dim (:obj:`int`): dimension of input + - eps (:obj:`float`): eps of LayerNorm + """ + super().__init__() self.eps = eps self.g = nn.Parameter(torch.ones(1, dim, 1)) self.b = nn.Parameter(torch.zeros(1, dim, 1)) def forward(self, x): + """ + Overview: + compute LayerNorm + Arguments: + - x (:obj:`torch.Tensor`): input tensor + """ + print('x.shape:', x.shape) var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) @@ -139,13 +214,33 @@ def forward(self, x): class PreNorm(nn.Module): + """ + Overview: + PreNorm, compute dim = 1, because Temporal input x [batch, dim, horizon] + Interfaces: + ``__init__``, ``forward`` + """ def __init__(self, dim, fn) -> None: + """ + Overview: + Initialization of PreNorm class + Arguments: + - dim (:obj:`int`): dimension of input + - fn (:obj:`nn.Module`): function of residual block + """ + super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x): + """ + Overview: + compute PreNorm + Arguments: + - x (:obj:`torch.Tensor`): input tensor + """ x = self.norm(x) return self.fn(x) @@ -154,13 +249,19 @@ class LinearAttention(nn.Module): """ Overview: Linear Attention head - Arguments: - - dim (:obj:'int'): dim of input - - heads (:obj:'int'): num of head - - dim_head (:obj:'int'): dim of head + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, dim, heads=4, dim_head=32) -> None: + """ + Overview: + Initialization of LinearAttention class + Arguments: + - dim (:obj:`int`): dimension of input + - heads (:obj:`int`): heads of attention + - dim_head (:obj:`int`): dim of head + """ super().__init__() self.scale = dim_head ** -0.5 self.heads = heads @@ -169,6 +270,12 @@ def __init__(self, dim, heads=4, dim_head=32) -> None: self.to_out = nn.Conv1d(hidden_dim, dim, 1) def forward(self, x): + """ + Overview: + compute LinearAttention + Arguments: + - x (:obj:`torch.Tensor`): input tensor + """ qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map(lambda t: t.reshape(t.shape[0], self.heads, -1, t.shape[-1]), qkv) q = q * self.scale @@ -184,17 +291,23 @@ class ResidualTemporalBlock(nn.Module): """ Overview: Residual block of temporal - Arguments: - - in_channels (:obj:'int'): dim of in_channels - - out_channels (:obj:'int'): dim of out_channels - - embed_dim (:obj:'int'): dim of embeding layer - - kernel_size (:obj:'int'): kernel_size of conv1d - - mish (:obj:'bool'): whether use mish as a activate function + Interfaces: + ``__init__``, ``forward`` """ def __init__( self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True ) -> None: + """ + Overview: + Initialization of ResidualTemporalBlock class + Arguments: + - in_channels (:obj:'int'): dim of in_channels + - out_channels (:obj:'int'): dim of out_channels + - embed_dim (:obj:'int'): dim of embeding layer + - kernel_size (:obj:'int'): kernel_size of conv1d + - mish (:obj:'bool'): whether use mish as a activate function + """ super().__init__() if mish: act = nn.Mish() @@ -214,12 +327,25 @@ def __init__( if in_channels != out_channels else nn.Identity() def forward(self, x, t): + """ + Overview: + compute residual block + Arguments: + - x (:obj:'tensor'): input tensor + - t (:obj:'tensor'): time tensor + """ out = self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1) out = self.blocks[1](out) return out + self.residual_conv(x) class DiffusionUNet1d(nn.Module): + """ + Overview: + Diffusion unet for 1d vector data + Interfaces: + ``__init__``, ``forward``, ``get_pred`` + """ def __init__( self, @@ -234,7 +360,7 @@ def __init__( ) -> None: """ Overview: - temporal net + Initialization of DiffusionUNet1d class Arguments: - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim - dim (:obj:'int'): dim of layer @@ -325,13 +451,15 @@ def __init__( def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False): """ + Overview: + compute diffusion unet forward Arguments: - x (:obj:'tensor'): noise trajectory - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 - time (:obj:'int'): timestep of diffusion step - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return - use_dropout (:obj:'bool'): Whether use returns condition mask - force_dropout (:obj:'bool'): Whether use returns condition + - x (:obj:'tensor'): noise trajectory + - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 + - time (:obj:'int'): timestep of diffusion step + - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return + - use_dropout (:obj:'bool'): Whether use returns condition mask + - force_dropout (:obj:'bool'): Whether use returns condition """ if self.cale_energy: x_inp = x @@ -383,6 +511,17 @@ def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_d return x def get_pred(self, x, cond, time, returns: bool = None, use_dropout: bool = True, force_dropout: bool = False): + """ + Overview: + compute diffusion unet forward + Arguments: + - x (:obj:'tensor'): noise trajectory + - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 + - time (:obj:'int'): timestep of diffusion step + - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return + - use_dropout (:obj:'bool'): Whether use returns condition mask + - force_dropout (:obj:'bool'): Whether use returns condition + """ # [batch, horizon, transition ] -> [batch, transition , horizon] x = x.transpose(1, 2) t = self.time_mlp(time) @@ -424,13 +563,8 @@ class TemporalValue(nn.Module): """ Overview: temporal net for value function - Arguments: - - horizon (:obj:'int'): horizon of trajectory - - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim - - dim (:obj:'int'): dim of layer - - time_dim (:obj:'): dim of time - - dim_mults (:obj:'SequenceType'): mults of dim - - kernel_size (:obj:'int'): kernel_size of conv1d + Interfaces: + ``__init__``, ``forward`` """ def __init__( @@ -443,6 +577,18 @@ def __init__( kernel_size: int = 5, dim_mults: SequenceType = [1, 2, 4, 8], ) -> None: + """ + Overview: + Initialization of TemporalValue class + Arguments: + - horizon (:obj:'int'): horizon of trajectory + - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim + - dim (:obj:'int'): dim of layer + - time_dim (:obj:'int'): dim of time + - out_dim (:obj:'int'): dim of output + - kernel_size (:obj:'int'): kernel_size of conv1d + - dim_mults (:obj:'SequenceType'): mults of dim + """ super().__init__() dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) @@ -489,6 +635,14 @@ def __init__( ) def forward(self, x, cond, time, *args): + """ + Overview: + compute temporal value forward + Arguments: + - x (:obj:'tensor'): noise trajectory + - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0 + - time (:obj:'int'): timestep of diffusion step + """ # [batch, horizon, transition ] -> [batch, transition , horizon] x = x.transpose(1, 2) t = self.time_mlp(time) diff --git a/ding/torch_utils/network/dreamer.py b/ding/torch_utils/network/dreamer.py index 6f48ce085d..f7c1597e54 100644 --- a/ding/torch_utils/network/dreamer.py +++ b/ding/torch_utils/network/dreamer.py @@ -10,11 +10,32 @@ class Conv2dSame(torch.nn.Conv2d): + """ + Overview: + Conv2dSame Network for dreamerv3. + Interfaces: + ``__init__``, ``forward`` + """ def calc_same_pad(self, i, k, s, d): + """ + Overview: + Calculate the same padding size. + Arguments: + - i (:obj:`int`): Input size. + - k (:obj:`int`): Kernel size. + - s (:obj:`int`): Stride size. + - d (:obj:`int`): Dilation size. + """ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) def forward(self, x): + """ + Overview: + compute the forward of Conv2dSame. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + """ ih, iw = x.size()[-2:] pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]) pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]) @@ -35,12 +56,33 @@ def forward(self, x): class DreamerLayerNorm(nn.Module): + """ + Overview: + DreamerLayerNorm Network for dreamerv3. + Interfaces: + ``__init__``, ``forward`` + """ def __init__(self, ch, eps=1e-03): + """ + Overview: + Init the DreamerLayerNorm class. + Arguments: + - ch (:obj:`int`): Input channel. + - eps (:obj:`float`): Epsilon. + """ + super(DreamerLayerNorm, self).__init__() self.norm = torch.nn.LayerNorm(ch, eps=eps) def forward(self, x): + """ + Overview: + compute the forward of DreamerLayerNorm. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + """ + x = x.permute(0, 2, 3, 1) x = self.norm(x) x = x.permute(0, 3, 1, 2) @@ -51,7 +93,7 @@ class DenseHead(nn.Module): """ Overview: DenseHead Network for value head, reward head, and discount head of dreamerv3. - Interface: + Interfaces: ``__init__``, ``forward`` """ @@ -68,6 +110,22 @@ def __init__( outscale=1.0, device='cpu', ): + """ + Overview: + Init the DenseHead class. + Arguments: + - inp_dim (:obj:`int`): Input dimension. + - shape (:obj:`tuple`): Output shape. + - layer_num (:obj:`int`): Number of layers. + - units (:obj:`int`): Number of units. + - act (:obj:`str`): Activation function. + - norm (:obj:`str`): Normalization function. + - dist (:obj:`str`): Distribution function. + - std (:obj:`float`): Standard deviation. + - outscale (:obj:`float`): Output scale. + - device (:obj:`str`): Device. + """ + super(DenseHead, self).__init__() self._shape = (shape, ) if isinstance(shape, int) else shape if len(self._shape) == 0: @@ -99,6 +157,13 @@ def __init__( self.std_layer.apply(uniform_weight_init(outscale)) def forward(self, features): + """ + Overview: + compute the forward of DenseHead. + Arguments: + - features (:obj:`torch.Tensor`): Input tensor. + """ + x = features out = self.mlp(x) # (batch, time, _units=512) mean = self.mean_layer(out) # (batch, time, 255) @@ -121,7 +186,7 @@ class ActionHead(nn.Module): """ Overview: ActionHead Network for action head of dreamerv3. - Interface: + Interfaces: ``__init__``, ``forward`` """ @@ -141,6 +206,24 @@ def __init__( outscale=1.0, unimix_ratio=0.01, ): + """ + Overview: + Initialize the ActionHead class. + Arguments: + - inp_dim (:obj:`int`): Input dimension. + - size (:obj:`int`): Output size. + - layers (:obj:`int`): Number of layers. + - units (:obj:`int`): Number of units. + - act (:obj:`str`): Activation function. + - norm (:obj:`str`): Normalization function. + - dist (:obj:`str`): Distribution function. + - init_std (:obj:`float`): Initial standard deviation. + - min_std (:obj:`float`): Minimum standard deviation. + - max_std (:obj:`float`): Maximum standard deviation. + - temp (:obj:`float`): Temperature. + - outscale (:obj:`float`): Output scale. + - unimix_ratio (:obj:`float`): Unimix ratio. + """ super(ActionHead, self).__init__() self._size = size self._layers = layers @@ -173,6 +256,13 @@ def __init__( self._dist_layer.apply(uniform_weight_init(outscale)) def forward(self, features): + """ + Overview: + compute the forward of ActionHead. + Arguments: + - features (:obj:`torch.Tensor`): Input tensor. + """ + x = features x = self._pre_layers(x) if self._dist == "tanh_normal": @@ -226,24 +316,47 @@ class SampleDist: """ Overview: A kind of sample Dist for ActionHead of dreamerv3. - Interface: + Interfaces: ``__init__``, ``mean``, ``mode``, ``entropy`` """ def __init__(self, dist, samples=100): + """ + Overview: + Initialize the SampleDist class. + Arguments: + - dist (:obj:`torch.Tensor`): Distribution. + - samples (:obj:`int`): Number of samples. + """ + self._dist = dist self._samples = samples def mean(self): + """ + Overview: + Calculate the mean of the distribution. + """ + samples = self._dist.sample(self._samples) return torch.mean(samples, 0) def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + sample = self._dist.sample(self._samples) logprob = self._dist.log_prob(sample) return sample[torch.argmax(logprob)][0] def entropy(self): + """ + Overview: + Calculate the entropy of the distribution. + """ + sample = self._dist.sample(self._samples) logprob = self.log_prob(sample) return -torch.mean(logprob, 0) @@ -253,11 +366,20 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical): """ Overview: A kind of onehot Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``mode``, ``sample`` """ def __init__(self, logits=None, probs=None, unimix_ratio=0.0): + """ + Overview: + Initialize the OneHotDist class. + Arguments: + - logits (:obj:`torch.Tensor`): Logits. + - probs (:obj:`torch.Tensor`): Probabilities. + - unimix_ratio (:obj:`float`): Unimix ratio. + """ + if logits is not None and unimix_ratio > 0.0: probs = F.softmax(logits, dim=-1) probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1] @@ -267,10 +389,23 @@ def __init__(self, logits=None, probs=None, unimix_ratio=0.0): super().__init__(logits=logits, probs=probs) def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1]) return _mode.detach() + super().logits - super().logits.detach() def sample(self, sample_shape=(), seed=None): + """ + Overview: + Sample from the distribution. + Arguments: + - sample_shape (:obj:`tuple`): Sample shape. + - seed (:obj:`int`): Seed. + """ + if seed is not None: raise ValueError('need to check') sample = super().sample(sample_shape) @@ -285,26 +420,53 @@ class TwoHotDistSymlog: """ Overview: A kind of twohotsymlog Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target`` """ def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): + """ + Overview: + Initialize the TwoHotDistSymlog class. + Arguments: + - logits (:obj:`torch.Tensor`): Logits. + - low (:obj:`float`): Low. + - high (:obj:`float`): High. + - device (:obj:`str`): Device. + """ + self.logits = logits self.probs = torch.softmax(logits, -1) self.buckets = torch.linspace(low, high, steps=255).to(device) self.width = (self.buckets[-1] - self.buckets[0]) / 255 def mean(self): + """ + Overview: + Calculate the mean of the distribution. + """ + _mean = self.probs * self.buckets return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True)) def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + _mode = self.probs * self.buckets return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True)) # Inside OneHotCategorical, log_prob is calculated using only max element in targets def log_prob(self, x): + """ + Overview: + Calculate the log probability of the distribution. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + """ + x = symlog(x) # x(time, batch, 1) below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1 @@ -328,6 +490,13 @@ def log_prob(self, x): return (target * log_pred).sum(-1) def log_prob_target(self, target): + """ + Overview: + Calculate the log probability of the target. + Arguments: + - target (:obj:`torch.Tensor`): Target tensor. + """ + log_pred = super().logits - torch.logsumexp(super().logits, -1, keepdim=True) return (target * log_pred).sum(-1) @@ -336,11 +505,21 @@ class SymlogDist: """ Overview: A kind of Symlog Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``entropy``, ``mode``, ``mean``, ``log_prob`` """ def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): + """ + Overview: + Initialize the SymlogDist class. + Arguments: + - mode (:obj:`torch.Tensor`): Mode. + - dist (:obj:`str`): Distribution function. + - aggregation (:obj:`str`): Aggregation function. + - tol (:obj:`float`): Tolerance. + - dim_to_reduce (:obj:`list`): Dimension to reduce. + """ self._mode = mode self._dist = dist self._aggregation = aggregation @@ -348,12 +527,29 @@ def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce= self._dim_to_reduce = dim_to_reduce def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + return inv_symlog(self._mode) def mean(self): + """ + Overview: + Calculate the mean of the distribution. + """ + return inv_symlog(self._mode) def log_prob(self, value): + """ + Overview: + Calculate the log probability of the distribution. + Arguments: + - value (:obj:`torch.Tensor`): Input tensor. + """ + assert self._mode.shape == value.shape if self._dist == 'mse': distance = (self._mode - symlog(value)) ** 2.0 @@ -376,25 +572,56 @@ class ContDist: """ Overview: A kind of ordinary Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob`` """ def __init__(self, dist=None): + """ + Overview: + Initialize the ContDist class. + Arguments: + - dist (:obj:`torch.Tensor`): Distribution. + """ + super().__init__() self._dist = dist self.mean = dist.mean def __getattr__(self, name): + """ + Overview: + Get attribute. + Arguments: + - name (:obj:`str`): Attribute name. + """ + return getattr(self._dist, name) def entropy(self): + """ + Overview: + Calculate the entropy of the distribution. + """ + return self._dist.entropy() def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + return self._dist.mean def sample(self, sample_shape=()): + """ + Overview: + Sample from the distribution. + Arguments: + - sample_shape (:obj:`tuple`): Sample shape. + """ + return self._dist.rsample(sample_shape) def log_prob(self, x): @@ -405,29 +632,66 @@ class Bernoulli: """ Overview: A kind of Bernoulli Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob`` """ def __init__(self, dist=None): + """ + Overview: + Initialize the Bernoulli distribution. + Arguments: + - dist (:obj:`torch.Tensor`): Distribution. + """ + super().__init__() self._dist = dist self.mean = dist.mean def __getattr__(self, name): + """ + Overview: + Get attribute. + Arguments: + - name (:obj:`str`): Attribute name. + """ + return getattr(self._dist, name) def entropy(self): + """ + Overview: + Calculate the entropy of the distribution. + """ return self._dist.entropy() def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + _mode = torch.round(self._dist.mean) return _mode.detach() + self._dist.mean - self._dist.mean.detach() def sample(self, sample_shape=()): + """ + Overview: + Sample from the distribution. + Arguments: + - sample_shape (:obj:`tuple`): Sample shape. + """ + return self._dist.rsample(sample_shape) def log_prob(self, x): + """ + Overview: + Calculate the log probability of the distribution. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + """ + _logits = self._dist.base_dist.logits log_probs0 = -F.softplus(_logits) log_probs1 = -F.softplus(-_logits) @@ -439,18 +703,38 @@ class UnnormalizedHuber(torchd.normal.Normal): """ Overview: A kind of UnnormalizedHuber Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``mode``, ``log_prob`` """ def __init__(self, loc, scale, threshold=1, **kwargs): + """ + Overview: + Initialize the UnnormalizedHuber class. + Arguments: + - loc (:obj:`torch.Tensor`): Location. + - scale (:obj:`torch.Tensor`): Scale. + - threshold (:obj:`float`): Threshold. + """ super().__init__(loc, scale, **kwargs) self._threshold = threshold def log_prob(self, event): + """ + Overview: + Calculate the log probability of the distribution. + Arguments: + - event (:obj:`torch.Tensor`): Event. + """ + return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold) def mode(self): + """ + Overview: + Calculate the mode of the distribution. + """ + return self.mean @@ -458,11 +742,23 @@ class SafeTruncatedNormal(torchd.normal.Normal): """ Overview: A kind of SafeTruncatedNormal Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``sample`` """ def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): + """ + Overview: + Initialize the SafeTruncatedNormal class. + Arguments: + - loc (:obj:`torch.Tensor`): Location. + - scale (:obj:`torch.Tensor`): Scale. + - low (:obj:`float`): Low. + - high (:obj:`float`): High. + - clip (:obj:`float`): Clip. + - mult (:obj:`float`): Mult. + """ + super().__init__(loc, scale) self._low = low self._high = high @@ -470,6 +766,13 @@ def __init__(self, loc, scale, low, high, clip=1e-6, mult=1): self._mult = mult def sample(self, sample_shape): + """ + Overview: + Sample from the distribution. + Arguments: + - sample_shape (:obj:`tuple`): Sample shape. + """ + event = super().sample(sample_shape) if self._clip: clipped = torch.clip(event, self._low + self._clip, self._high - self._clip) @@ -483,27 +786,65 @@ class TanhBijector(torchd.Transform): """ Overview: A kind of TanhBijector Dist for dreamerv3. - Interface: + Interfaces: ``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian`` """ def __init__(self, validate_args=False, name='tanh'): + """ + Overview: + Initialize the TanhBijector class. + Arguments: + - validate_args (:obj:`bool`): Validate arguments. + - name (:obj:`str`): Name. + """ + super().__init__() def _forward(self, x): + """ + Overview: + Calculate the forward of the distribution. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + """ + return torch.tanh(x) def _inverse(self, y): + """ + Overview: + Calculate the inverse of the distribution. + Arguments: + - y (:obj:`torch.Tensor`): Input tensor. + """ + y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y) y = torch.atanh(y) return y def _forward_log_det_jacobian(self, x): + """ + Overview: + Calculate the forward log det jacobian of the distribution. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor. + """ + log2 = torch.math.log(2.0) return 2.0 * (log2 - x - torch.softplus(-2.0 * x)) def static_scan(fn, inputs, start): + """ + Overview: + Static scan function. + Arguments: + - fn (:obj:`function`): Function. + - inputs (:obj:`tuple`): Inputs. + - start (:obj:`torch.Tensor`): Start tensor. + """ + last = start # {logit, stoch, deter:[batch_size, self._deter]} indices = range(inputs[0].shape[0]) flag = True @@ -541,7 +882,10 @@ def weight_init(m): """ Overview: weight_init for Linear, Conv2d, ConvTranspose2d, and LayerNorm. + Arguments: + - m (:obj:`torch.nn`): Module. """ + if isinstance(m, nn.Linear): in_num = m.in_features out_num = m.out_features @@ -571,6 +915,8 @@ def uniform_weight_init(given_scale): """ Overview: weight_init for Linear and LayerNorm. + Arguments: + - given_scale (:obj:`float`): Given scale. """ def f(m): diff --git a/ding/torch_utils/network/gtrxl.py b/ding/torch_utils/network/gtrxl.py index 7af8f20192..16ac7702c7 100644 --- a/ding/torch_utils/network/gtrxl.py +++ b/ding/torch_utils/network/gtrxl.py @@ -15,6 +15,8 @@ class PositionalEmbedding(nn.Module): """ Overview: The PositionalEmbedding module implements the positional embedding used in the vanilla Transformer model. + Interfaces: + ``__init__``, ``forward`` .. note:: This implementation is adapted from https://github.com/kimiyoung/transformer-xl/blob/ \ @@ -23,9 +25,12 @@ class PositionalEmbedding(nn.Module): def __init__(self, embedding_dim: int): """ + Overview: + Initialize the PositionalEmbedding module. Arguments: - embedding_dim: (:obj:`int`): The dimensionality of the embeddings. """ + super(PositionalEmbedding, self).__init__() self.embedding_dim = embedding_dim inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim)) # (embedding_dim / 2) @@ -42,6 +47,7 @@ def forward(self, pos_seq: torch.Tensor) -> torch.Tensor: - pos_embedding (:obj:`torch.Tensor`): The computed positional embeddings. \ The shape of the tensor is (seq_len, 1, embedding_dim). """ + sinusoid_inp = torch.outer(pos_seq, self.inv_freq) # For position embedding, the order of sin/cos is negligible. # This is because tokens are consumed by the matrix multiplication which is permutation-invariant. @@ -53,16 +59,21 @@ class GRUGatingUnit(torch.nn.Module): """ Overview: The GRUGatingUnit module implements the GRU gating mechanism used in the GTrXL model. + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, input_dim: int, bg: float = 2.): """ + Overview: + Initialize the GRUGatingUnit module. Arguments: - input_dim (:obj:`int`): The dimensionality of the input. - bg (:obj:`bg`): The gate bias. By setting bg > 0 we can explicitly initialize the gating mechanism to \ be close to the identity map. This can greatly improve the learning speed and stability since it \ initializes the agent close to a Markovian policy (ignore attention at the beginning). """ + super(GRUGatingUnit, self).__init__() self.Wr = torch.nn.Linear(input_dim, input_dim, bias=False) self.Ur = torch.nn.Linear(input_dim, input_dim, bias=False) @@ -86,6 +97,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): - g: (:obj:`torch.Tensor`): The output of the GRU gating mechanism. \ The shape of g matches the shapes of x and y. """ + r = self.sigmoid(self.Wr(y) + self.Ur(x)) z = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg) h = self.tanh(self.Wg(y) + self.Ug(torch.mul(r, x))) # element wise multiplication @@ -97,6 +109,8 @@ class Memory: """ Overview: A class that stores the context used to add memory to Transformer. + Interfaces: + ``__init__``, ``init``, ``update``, ``get``, ``to`` .. note:: For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860 @@ -111,6 +125,8 @@ def __init__( memory: Optional[torch.Tensor] = None ) -> None: """ + Overview: + Initialize the Memory module. Arguments: - memory_len (:obj:`int`): The dimension of memory, i.e., how many past observations to use as memory. - batch_size (:obj:`int`): The dimension of each batch. @@ -119,6 +135,7 @@ def __init__( - layer_num (:obj:`int`): The number of transformer layers. - memory (:obj:`Optional[torch.Tensor]`): The initial memory. Default is None. """ + super(Memory, self).__init__() self.embedding_dim = embedding_dim self.bs = batch_size @@ -136,6 +153,7 @@ def init(self, memory: Optional[torch.Tensor] = None): (layer_num, memory_len, bs, embedding_dim). Its shape is (layer_num, memory_len, bs, embedding_dim), \ where memory_len is length of memory, bs is batch size and embedding_dim is the dimension of embedding. """ + if memory is not None: self.memory = memory layer_num_plus1, self.memory_len, self.bs, self.embedding_dim = memory.shape @@ -163,6 +181,7 @@ def update(self, hidden_state: List[torch.Tensor]): - memory: (:obj:`Optional[torch.Tensor]`): The updated memory, with shape \ (layer_num, memory_len, bs, embedding_dim). """ + if self.memory is None or hidden_state is None: raise ValueError('Failed to update memory! Memory would be None') # TODO add support of no memory sequence_len = hidden_state[0].shape[0] @@ -187,6 +206,7 @@ def get(self): - memory: (:obj:`Optional[torch.Tensor]`): The current memory, \ with shape (layer_num, memory_len, bs, embedding_dim). """ + return self.memory def to(self, device: str = 'cpu'): @@ -196,6 +216,7 @@ def to(self, device: str = 'cpu'): Arguments: device (:obj:`str`): The device to move the memory to. Default is 'cpu'. """ + self.memory = self.memory.to(device) @@ -203,6 +224,8 @@ class AttentionXL(torch.nn.Module): """ Overview: An implementation of the Attention mechanism used in the TransformerXL model. + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, input_dim: int, head_dim: int, head_num: int, dropout: nn.Module) -> None: @@ -215,6 +238,7 @@ def __init__(self, input_dim: int, head_dim: int, head_num: int, dropout: nn.Mod - head_num (:obj:`int`): The number of attention heads. - dropout (:obj:`nn.Module`): The dropout layer to use """ + super(AttentionXL, self).__init__() self.head_num = head_num self.head_dim = head_dim @@ -250,6 +274,7 @@ def _rel_shift(self, x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor: - x (:obj:`torch.Tensor`): The input tensor after the relative shift operation, \ with shape (cur_seq, full_seq, bs, head_num). """ + x_padded = F.pad(x, [1, 0]) # step 1 x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) # step 2 x = x_padded[:, :, 1:].view_as(x) # step 3 @@ -282,6 +307,7 @@ def forward( Returns: - output (:obj:`torch.Tensor`): The output of the attention mechanism with shape (cur_seq, bs, input_dim). """ + bs, cur_seq, full_seq = inputs.shape[1], inputs.shape[0], full_input.shape[0] prev_seq = full_seq - cur_seq @@ -330,6 +356,8 @@ class GatedTransformerXLLayer(torch.nn.Module): """ Overview: This class implements the attention layer of GTrXL (Gated Transformer-XL). + Interfaces: + ``__init__``, ``forward`` """ def __init__( @@ -359,6 +387,7 @@ def __init__( residual connections. Default is True. - gru_bias (:obj:`float`, optional): The bias of the GRU gate. Default is 2. """ + super(GatedTransformerXLLayer, self).__init__() self.dropout = dropout self.gating = gru_gating @@ -406,6 +435,7 @@ def forward( Returns: - output (:obj:`torch.Tensor`): layer output of shape (cur_seq, bs, input_dim) """ + # concat memory with input across sequence dimension full_input = torch.cat([memory, inputs], dim=0) # full_seq x bs x input_dim x1 = self.layernorm1(full_input) @@ -423,6 +453,8 @@ class GTrXL(nn.Module): Overview: GTrXL Transformer implementation as described in "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764). + Interfaces: + ``__init__``, ``forward``, ``reset_memory``, ``get_memory`` """ def __init__( @@ -459,6 +491,7 @@ def __init__( Raises: - AssertionError: If `embedding_dim` is not an even number. """ + super(GTrXL, self).__init__() assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(input_dim) self.head_num = head_num @@ -505,6 +538,7 @@ def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.T - state (:obj:`Optional[torch.Tensor]`): The input memory with shape \ (layer_num, memory_len, bs, embedding_dim). Default is None. """ + self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim) if batch_size is not None: self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num) @@ -519,6 +553,7 @@ def get_memory(self): - memory (:obj:`Optional[torch.Tensor]`): The output memory or None if memory has not been initialized. \ The shape is (layer_num, memory_len, bs, embedding_dim). """ + if self.memory is None: return None else: @@ -538,6 +573,7 @@ def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = - x (:obj:`Dict[str, torch.Tensor]`): A dictionary containing the transformer output of shape \ (seq_len, bs, embedding_size) and memory of shape (layer_num, seq_len, bs, embedding_size). """ + if batch_first: x = torch.transpose(x, 1, 0) # bs x cur_seq x input_dim -> cur_seq x bs x input_dim cur_seq, bs = x.shape[:2] diff --git a/ding/torch_utils/network/gumbel_softmax.py b/ding/torch_utils/network/gumbel_softmax.py index 4866ad6f16..fea7612103 100644 --- a/ding/torch_utils/network/gumbel_softmax.py +++ b/ding/torch_utils/network/gumbel_softmax.py @@ -8,7 +8,7 @@ class GumbelSoftmax(nn.Module): Overview: An `nn.Module` that computes GumbelSoftmax. Interfaces: - __init__, forward, gumbel_softmax_sample + ``__init__``, ``forward``, ``gumbel_softmax_sample`` .. note:: For more information on GumbelSoftmax, refer to the paper [Categorical Reparameterization \ diff --git a/ding/torch_utils/network/merge.py b/ding/torch_utils/network/merge.py index b942a0b05a..25d89885dd 100644 --- a/ding/torch_utils/network/merge.py +++ b/ding/torch_utils/network/merge.py @@ -52,13 +52,20 @@ class BilinearGeneral(nn.Module): Overview: Bilinear implementation as in: Multiplicative Interactions and Where to Find Them, ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH. - Arguments: - - in1_features (:obj:`int`): The size of each first input sample. - - in2_features (:obj:`int`): The size of each second input sample. - - out_features (:obj:`int`): The size of each output sample. + Interfaces: + ``__init__``, ``forward`` """ - def __init__(self, in1_features, in2_features, out_features): + def __init__(self, in1_features: int, in2_features: int, out_features: int): + """ + Overview: + Initialize the Bilinear layer. + Arguments: + - in1_features (:obj:`int`): The size of each first input sample. + - in2_features (:obj:`int`): The size of each second input sample. + - out_features (:obj:`int`): The size of each output sample. + """ + super(BilinearGeneral, self).__init__() # Initialize the weight matrices W and U, and the bias vectors V and b self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features)) @@ -71,13 +78,26 @@ def __init__(self, in1_features, in2_features, out_features): self.reset_parameters() def reset_parameters(self): + """ + Overview: + Initialize the parameters of the Bilinear layer. + """ + stdv = 1. / np.sqrt(self.in1_features) self.W.data.uniform_(-stdv, stdv) self.U.data.uniform_(-stdv, stdv) self.V.data.uniform_(-stdv, stdv) self.b.data.uniform_(-stdv, stdv) - def forward(self, x, z): + def forward(self, x: torch.Tensor, z: torch.Tensor): + """ + Overview: + compute the bilinear function. + Arguments: + - x (:obj:`torch.Tensor`): The first input tensor. + - z (:obj:`torch.Tensor`): The second input tensor. + """ + # Compute the bilinear function # x^TWz out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z) @@ -94,13 +114,20 @@ class TorchBilinearCustomized(nn.Module): """ Overview: Customized Torch Bilinear implementation. - Arguments: - - in1_features (:obj:`int`): The size of each first input sample. - - in2_features (:obj:`int`): The size of each second input sample. - - out_features (:obj:`int`): The size of each output sample. + Interfaces: + ``__init__``, ``forward`` """ - def __init__(self, in1_features, in2_features, out_features): + def __init__(self, in1_features: int, in2_features: int, out_features: int): + """ + Overview: + Initialize the Bilinear layer. + Arguments: + - in1_features (:obj:`int`): The size of each first input sample. + - in2_features (:obj:`int`): The size of each second input sample. + - out_features (:obj:`int`): The size of each output sample. + """ + super(TorchBilinearCustomized, self).__init__() self.in1_features = in1_features self.in2_features = in2_features @@ -110,11 +137,24 @@ def __init__(self, in1_features, in2_features, out_features): self.reset_parameters() def reset_parameters(self): + """ + Overview: + Initialize the parameters of the Bilinear layer. + """ + bound = 1 / math.sqrt(self.in1_features) nn.init.uniform_(self.weight, -bound, bound) nn.init.uniform_(self.bias, -bound, bound) def forward(self, x, z): + """ + Overview: + Compute the bilinear function. + Arguments: + - x (:obj:`torch.Tensor`): The first input tensor. + - z (:obj:`torch.Tensor`): The second input tensor. + """ + # Using torch.einsum for the bilinear operation out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias return out.squeeze(-1) @@ -138,18 +178,25 @@ class FiLM(nn.Module): Overview: Feature-wise Linear Modulation (FiLM) Layer. This layer applies feature-wise affine transformation based on context. - Arguments: - - feature_dim (:obj:`int`). The dimension of the input feature vector. - - context_dim (:obj:`int`). The dimension of the input context vector. + Interfaces: + ``__init__``, ``forward`` """ - def __init__(self, feature_dim, context_dim): + def __init__(self, feature_dim: int, context_dim: int): + """ + Overview: + Initialize the FiLM layer. + Arguments: + - feature_dim (:obj:`int`). The dimension of the input feature vector. + - context_dim (:obj:`int`). The dimension of the input context vector. + """ + super(FiLM, self).__init__() # Define the fully connected layer for context # The output dimension is twice the feature dimension for gamma and beta self.context_layer = nn.Linear(context_dim, 2 * feature_dim) - def forward(self, feature, context): + def forward(self, feature: torch.Tensor, context: torch.Tensor): """ Overview: Forward propagation. @@ -159,6 +206,7 @@ def forward(self, feature, context): Returns: - conditioned_feature : torch.Tensor. The output feature after FiLM, shape (batch_size, feature_dim). """ + # Pass context through the fully connected layer out = self.context_layer(context) # Split the output into two parts: gamma and beta @@ -184,6 +232,8 @@ class SumMerge(nn.Module): Overview: A PyTorch module that merges a list of tensors by computing their sum. All input tensors must have the same size. This module can work with any type of tensor (vector, units or visual). + Interfaces: + ``__init__``, ``forward`` """ def forward(self, tensors: List[Tensor]) -> Tensor: @@ -209,12 +259,8 @@ class VectorMerge(nn.Module): Overview: Merges multiple vector streams. Streams are first transformed through layer normalization, relu, and linear layers, then summed. They don't need to have the same size. Gating can also be used before the sum. - Arguments: - - input_sizes (:obj:`Dict[str, int]`): A dictionary mapping input names to their size (a single \ - integer for 1d inputs, or None for 0d inputs). If an input size is None, we assume it's (). - - output_size (:obj:`int`): The size of the output vector. - - gating_type (:obj:`GatingType`): The type of gating mechanism to use. - - use_layer_norm (:obj:`bool`): Whether to use layer normalization. + Interfaces: + ``__init__``, ``encode``, ``_compute_gate``, ``forward`` .. note:: For more details about the gating types, please refer to the GatingType enum class. diff --git a/ding/torch_utils/network/nn_module.py b/ding/torch_utils/network/nn_module.py index 794cec8f79..64a21edfe4 100644 --- a/ding/torch_utils/network/nn_module.py +++ b/ding/torch_utils/network/nn_module.py @@ -199,7 +199,7 @@ def deconv2d_block( activation: int = None, norm_type: int = None ) -> nn.Sequential: - r""" + """ Overview: Create a 2-dimensional transpose convolution layer with activation and normalization. Arguments: @@ -457,6 +457,8 @@ class ChannelShuffle(nn.Module): Overview: Apply channel shuffle to the input tensor. For more details about the channel shuffle, please refer to the 'ShuffleNet' paper: https://arxiv.org/abs/1707.01083 + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, group_num: int) -> None: @@ -538,11 +540,11 @@ def one_hot(val: torch.LongTensor, num: int, num_first: bool = False) -> torch.F class NearestUpsample(nn.Module): - r""" + """ Overview: This module upsamples the input to the given scale_factor using the nearest mode. - Interface: - ``forward`` + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, scale_factor: Union[float, List[float]]) -> None: @@ -571,8 +573,8 @@ class BilinearUpsample(nn.Module): """ Overview: This module upsamples the input to the given scale_factor using the bilinear mode. - Interface: - ``forward`` + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, scale_factor: Union[float, List[float]]) -> None: @@ -625,11 +627,11 @@ def binary_encode(y: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor: class NoiseLinearLayer(nn.Module): - r""" + """ Overview: This is a linear layer with random noise. - Interface: - ``reset_noise``, ``reset_parameters``, ``forward`` + Interfaces: + ``__init__``, ``reset_noise``, ``reset_parameters``, ``forward`` """ def __init__(self, in_channels: int, out_channels: int, sigma0: int = 0.4) -> None: @@ -659,7 +661,10 @@ def _scale_noise(self, size: Union[int, Tuple]): """ Overview: Scale the noise. + Arguments: + - size (:obj:`Union[int, Tuple]`): The size of the noise. """ + x = torch.randn(size) x = x.sign().mul(x.abs().sqrt()) return x @@ -748,6 +753,8 @@ class NaiveFlatten(nn.Module): """ Overview: This module is a naive implementation of the flatten operation. + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: diff --git a/ding/torch_utils/network/popart.py b/ding/torch_utils/network/popart.py index 75a1c9886a..e01406a57a 100644 --- a/ding/torch_utils/network/popart.py +++ b/ding/torch_utils/network/popart.py @@ -21,7 +21,7 @@ class PopArt(nn.Module): PopArt normalization, which is a method to automatically adapt the contribution of each task to the agent's updates in multi-task learning, as described in the paper . - Interface: + Interfaces: ``__init__``, ``reset_parameters``, ``forward``, ``update_parameters`` """ diff --git a/ding/torch_utils/network/res_block.py b/ding/torch_utils/network/res_block.py index 86070c3336..14223f940c 100644 --- a/ding/torch_utils/network/res_block.py +++ b/ding/torch_utils/network/res_block.py @@ -24,7 +24,7 @@ class ResBlock(nn.Module): For more details, please refer to `Deep Residual Learning for Image Recognition `_. Interfaces: - ``forward`` + ``__init__``, ``forward`` """ def __init__( @@ -110,7 +110,7 @@ class ResFCBlock(nn.Module): \_____________________________________/+ Interfaces: - ``forward`` + ``__init__``, ``forward`` """ def __init__( diff --git a/ding/torch_utils/network/resnet.py b/ding/torch_utils/network/resnet.py index c4789d48e5..643f353355 100644 --- a/ding/torch_utils/network/resnet.py +++ b/ding/torch_utils/network/resnet.py @@ -93,6 +93,8 @@ class AvgPool2dSame(nn.AvgPool2d): """ Overview: Tensorflow-like 'SAME' wrapper for 2D average pooling. + Interfaces: + ``__init__``, ``forward`` """ def __init__( @@ -203,6 +205,8 @@ class ClassifierHead(nn.Module): """ Overview: Classifier head with configurable global pooling and dropout. + Interfaces: + ``__init__``, ``forward`` """ def __init__( @@ -280,6 +284,8 @@ class BasicBlock(nn.Module): The basic building block for models like ResNet. This class extends pytorch's Module class. It represents a standard block of layers including two convolutions, batch normalization, an optional attention mechanism, and activation functions. + Interfaces: + ``__init__``, ``forward``, ``zero_init_last_bn`` Properties: - expansion (:obj:int): Specifies the expansion factor for the planes of the conv layers. """ @@ -409,7 +415,7 @@ class Bottleneck(nn.Module): implementation of ResNet. This block is designed with several layers including a convolutional layer, normalization layer, activation layer, attention layer, anti-aliasing layer, and a dropout layer. Interfaces: - forward, zero_init_last_bn + ``__init__``, ``forward``, ``zero_init_last_bn`` Properties: expansion, inplanes, planes, stride, downsample, cardinality, base_width, reduce_first, dilation, \ first_dilation, act_layer, norm_layer, attn_layer, aa_layer, drop_block, drop_path @@ -723,32 +729,8 @@ class ResNet(nn.Module): Implements ResNet, ResNeXt, SE-ResNeXt, and SENet models. This implementation supports various modifications based on the v1c, v1d, v1e, and v1s variants included in the MXNet Gluon ResNetV1b model. For more details about the variants and options, please refer to the 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. - Arguments: - - block (:obj:`nn.Module`): Class for the residual block. Options are BasicBlockGl, BottleneckGl. - - layers (:obj:`list` of :obj:`int`): Numbers of layers in each block. - - num_classes (:obj:`int`, optional): Number of classification classes. Default is 1000. - - in_chans (:obj:`int`, optional): Number of input (color) channels. Default is 3. - - cardinality (:obj:`int`, optional): Number of convolution groups for 3x3 conv in Bottleneck. Default is 1. - - base_width (:obj:`int`, optional): Factor determining bottleneck channels. Default is 64. - - stem_width (:obj:`int`, optional): Number of channels in stem convolutions. Default is 64. - - stem_type (:obj:`str`, optional): The type of stem. Default is ''. - - replace_stem_pool (:obj:`bool`, optional): Whether to replace stem pooling. Default is False. - - output_stride (:obj:`int`, optional): Output stride of the network. Default is 32. - - block_reduce_first (:obj:`int`, optional): Reduction factor for first convolution output width of \ - residual blocks. - - down_kernel_size (:obj:`int`, optional): Kernel size of residual block downsampling path. Default is 1. - - avg_down (:obj:`bool`, optional): Whether to use average pooling for projection skip connection between \ - stages/downsample. Default is False. - - act_layer (:obj:`nn.Module`, optional): Activation layer. Default is nn.ReLU. - - norm_layer (:obj:`nn.Module`, optional): Normalization layer. Default is nn.BatchNorm2d. - - aa_layer (:obj:`nn.Module`, optional): Anti-aliasing layer. Default is None. - - drop_rate (:obj:`float`, optional): Dropout probability before classifier, for training. Default is 0.0. - - drop_path_rate (:obj:`float`, optional): Drop path rate. Default is 0.0. - - drop_block_rate (:obj:`float`, optional): Drop block rate. Default is 0.0. - - global_pool (:obj:`str`, optional): Global pooling type. Default is 'avg'. - - zero_init_last_bn (:obj:`bool`, optional): Whether to initialize last batch normalization with zero. \ - Default is True. - - block_args (:obj:`dict`, optional): Additional arguments for block. Default is None. + Interfaces: + ``__init__``, ``forward``, ``zero_init_last_bn``, ``get_classifier`` """ def __init__( diff --git a/ding/torch_utils/network/rnn.py b/ding/torch_utils/network/rnn.py index e7f4dada46..e24bd7e468 100644 --- a/ding/torch_utils/network/rnn.py +++ b/ding/torch_utils/network/rnn.py @@ -25,7 +25,7 @@ def is_sequence(data): def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor: - r""" + """ Overview: Generates a boolean mask for a batch of sequences with differing lengths. Arguments: @@ -132,8 +132,8 @@ class LSTM(nn.Module, LSTMForwardWrapper): """ Overview: Implementation of an LSTM cell with Layer Normalization (LN). - Interface: - ``forward`` + Interfaces: + ``__init__``, ``forward`` .. note:: @@ -178,6 +178,11 @@ def __init__( self._init() def _init(self): + """ + Overview: + Initialize the parameters of the LSTM cell. + """ + gain = math.sqrt(1. / self.hidden_size) for l in range(self.num_layers): torch.nn.init.uniform_(self.wx[l], -gain, gain) @@ -239,7 +244,7 @@ class PytorchLSTM(nn.LSTM, LSTMForwardWrapper): Overview: Wrapper class for PyTorch's nn.LSTM, formats the input and output. For more details on nn.LSTM, refer to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM - Interface: + Interfaces: ``forward`` """ @@ -270,8 +275,8 @@ class GRU(nn.GRUCell, LSTMForwardWrapper): Overview: This class extends the `torch.nn.GRUCell` and `LSTMForwardWrapper` classes, and formats inputs and outputs accordingly. - Interface: - ``forward`` + Interfaces: + ``__init__``, ``forward`` Properties: hidden_size, num_layers diff --git a/ding/torch_utils/network/scatter_connection.py b/ding/torch_utils/network/scatter_connection.py index 4826a9bb5d..d596f3aa1c 100644 --- a/ding/torch_utils/network/scatter_connection.py +++ b/ding/torch_utils/network/scatter_connection.py @@ -32,6 +32,8 @@ class ScatterConnection(nn.Module): Overview: Scatter feature to its corresponding location. In AlphaStar, each entity is embedded into a tensor, and these tensors are scattered into a feature map with map size. + Interfaces: + ``__init__``, ``forward``, ``xy_forward`` """ def __init__(self, scatter_type: str) -> None: diff --git a/ding/torch_utils/network/soft_argmax.py b/ding/torch_utils/network/soft_argmax.py index 9a44de5ebc..166d0bb8f6 100644 --- a/ding/torch_utils/network/soft_argmax.py +++ b/ding/torch_utils/network/soft_argmax.py @@ -9,7 +9,7 @@ class SoftArgmax(nn.Module): A neural network module that computes the SoftArgmax operation (essentially a 2-dimensional spatial softmax), which is often used for location regression tasks. It converts a feature map (such as a heatmap) into precise coordinate locations. - Interface: + Interfaces: ``__init__``, ``forward`` .. note:: diff --git a/ding/torch_utils/network/transformer.py b/ding/torch_utils/network/transformer.py index cc4deb44e7..7a508b3909 100644 --- a/ding/torch_utils/network/transformer.py +++ b/ding/torch_utils/network/transformer.py @@ -12,7 +12,7 @@ class Attention(nn.Module): Overview: For each entry embedding, compute individual attention across all entries, add them up to get output attention. Interfaces: - ``split``, ``forward`` + ``__init__``, ``split``, ``forward`` """ def __init__(self, input_dim: int, head_dim: int, output_dim: int, head_num: int, dropout: nn.Module) -> None: @@ -87,7 +87,7 @@ class TransformerLayer(nn.Module): Overview: In transformer layer, first computes entries's attention and applies a feedforward layer. Interfaces: - ``forward`` + ``__init__``, ``forward`` """ def __init__( @@ -148,7 +148,7 @@ class Transformer(nn.Module): .. note:: For more details, refer to "Attention is All You Need": http://arxiv.org/abs/1706.03762. Interfaces: - ``forward`` + ``__init__``, ``forward`` """ def __init__( @@ -217,6 +217,8 @@ class ScaledDotProductAttention(nn.Module): Implementation of Scaled Dot Product Attention, a key component of Transformer models. This class performs the dot product of the query, key and value tensors, scales it with the square root of the dimension of the key vector (d_k) and applies dropout for regularization. + Interfaces: + ``__init__``, ``forward`` """ def __init__(self, d_k: int, dropout: float = 0.0) -> None: diff --git a/ding/torch_utils/optimizer_helper.py b/ding/torch_utils/optimizer_helper.py index 239b193c1b..d4d351cca2 100644 --- a/ding/torch_utils/optimizer_helper.py +++ b/ding/torch_utils/optimizer_helper.py @@ -14,7 +14,7 @@ def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float: - r""" + """ Overview: calculate grad norm of the parameters whose grad norms are not None in the model. Arguments: @@ -38,7 +38,7 @@ def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float: def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float: - r""" + """ Overview: calculate grad norm of the parameters whose grad norms are not None in the model. Arguments: @@ -54,6 +54,14 @@ def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float: def grad_ignore_norm(parameters, max_norm, norm_type=2): + """ + Overview: + Clip the gradient norm of an iterable of parameters. + Arguments: + - parameters (:obj:`Iterable`): an iterable of torch.Tensor + - max_norm (:obj:`float`): the max norm of the gradients + - norm_type (:obj:`float`): 2.0 means use norm2 to clip + """ if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) @@ -75,6 +83,13 @@ def grad_ignore_norm(parameters, max_norm, norm_type=2): def grad_ignore_value(parameters, clip_value): + """ + Overview: + Clip the gradient value of an iterable of parameters. + Arguments: + - parameters (:obj:`Iterable`): an iterable of torch.Tensor + - clip_value (:obj:`float`): the value to start clipping + """ if isinstance(parameters, torch.Tensor): parameters = [parameters] clip_value = float(clip_value) @@ -90,11 +105,11 @@ def grad_ignore_value(parameters, clip_value): class Adam(torch.optim.Adam): - r""" + """ Overview: Rewrited Adam optimizer to support more features. - Interface: - __init__, step + Interfaces: + ``__init__``, ``step``, ``_state_init``, ``get_grad`` """ def __init__( @@ -118,7 +133,7 @@ def __init__( ignore_norm_type: float = 2.0, ignore_momentum_timestep: int = 100, ): - r""" + """ Overview: init method of refactored Adam class Arguments: @@ -189,6 +204,14 @@ def __init__( ) def _state_init(self, p, amsgrad): + """ + Overview: + Initialize the state of the optimizer + Arguments: + - p (:obj:`torch.Tensor`): the parameter to be optimized + - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper\ + On the Convergence of Adam and Beyond + """ state = self.state[p] state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device) # others @@ -208,7 +231,7 @@ def _state_init(self, p, amsgrad): state['max_exp_avg_sq'] = torch.zeros_like(p.data) def step(self, closure: Union[Callable, None] = None): - r""" + """ Overview: Performs a single optimization step Arguments: @@ -373,8 +396,8 @@ class RMSprop(torch.optim.RMSprop): r""" Overview: Rewrited RMSprop optimizer to support more features. - Interface: - __init__, step + Interfaces: + ``__init__``, ``step``, ``_state_init``, ``get_grad`` """ def __init__( @@ -398,7 +421,7 @@ def __init__( ignore_norm_type: float = 2.0, ignore_momentum_timestep: int = 100, ): - r""" + """ Overview: init method of refactored Adam class Arguments: @@ -455,6 +478,16 @@ def __init__( ) def _state_init(self, p, momentum, centered): + """ + Overview: + Initialize the state of the optimizer + Arguments: + - p (:obj:`torch.Tensor`): the parameter to be optimized + - momentum (:obj:`float`): the momentum coefficient + - centered (:obj:`bool`): if True, compute the centered RMSprop, \ + the gradient is normalized by an estimation of its variance + """ + state = self.state[p] state['step'] = 0 state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device) @@ -465,7 +498,7 @@ def _state_init(self, p, momentum, centered): state['grad_avg'] = torch.zeros_like(p.data, device=p.data.device) def step(self, closure: Union[Callable, None] = None): - r""" + """ Overview: Performs a single optimization step Arguments: @@ -595,6 +628,11 @@ def step(self, closure: Union[Callable, None] = None): return super().step(closure=closure) def get_grad(self) -> float: + """ + Overview: + calculate grad norm of the parameters whose grad norms are not None in the model. + """ + total_norm = 0. params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None] for p in params: @@ -608,35 +646,55 @@ class PCGrad(): Overview: PCGrad optimizer to support multi-task. you can view the paper in the following link https://arxiv.org/pdf/2001.06782.pdf + Interfaces: + ``__init__``, ``zero_grad``, ``step``, ``pc_backward`` + Properties: + - optimizer (:obj:`torch.optim`): the optimizer to be used """ def __init__(self, optimizer, reduction='mean'): + """ + Overview: + Initialization of PCGrad optimizer + Arguments: + - optimizer (:obj:`torch.optim`): the optimizer to be used + - reduction (:obj:`str`): the reduction method, support ['mean', 'sum'] + """ + self._optim, self._reduction = optimizer, reduction @property def optimizer(self): + """ + Overview: + get the optimizer + """ + return self._optim def zero_grad(self): - ''' - clear the gradient of the parameters - ''' + """ + Overview: + clear the gradient of the parameters + """ return self._optim.zero_grad(set_to_none=True) def step(self): - ''' - update the parameters with the gradient - ''' + """ + Overview: + update the parameters with the gradient + """ return self._optim.step() def pc_backward(self, objectives): - ''' - calculate the gradient of the parameters + """ + Overview: + calculate the gradient of the parameters Arguments: - objectives: a list of objectives - ''' + """ grads, shapes, has_grads = self._pack_grad(objectives) pc_grad = self._project_conflicting(grads, has_grads) @@ -645,6 +703,15 @@ def pc_backward(self, objectives): return def _project_conflicting(self, grads, has_grads, shapes=None): + """ + Overview: + project the conflicting gradient to the orthogonal space + Arguments: + - grads (:obj:`list`): a list of the gradient of the parameters + - has_grads (:obj:`list`): a list of mask represent whether the parameter has gradient + - shapes (:obj:`list`): a list of the shape of the parameters + """ + shared = torch.stack(has_grads).prod(0).bool() pc_grad, num_task = copy.deepcopy(grads), len(grads) for g_i in pc_grad: @@ -665,9 +732,12 @@ def _project_conflicting(self, grads, has_grads, shapes=None): return merged_grad def _set_grad(self, grads): - ''' - set the modified gradients to the network - ''' + """ + Overview: + set the modified gradients to the network + Arguments: + - grads (:obj:`list`): a list of the gradient of the parameters + """ idx = 0 for group in self._optim.param_groups: @@ -678,13 +748,16 @@ def _set_grad(self, grads): return def _pack_grad(self, objectives): - ''' - pack the gradient of the parameters of the network for each objective + """ + Overview: + pack the gradient of the parameters of the network for each objective + Arguments: + - objectives: a list of objectives Returns: - grad: a list of the gradient of the parameters - shape: a list of the shape of the parameters - has_grad: a list of mask represent whether the parameter has gradient - ''' + """ grads, shapes, has_grads = [], [], [] for obj in objectives: @@ -697,6 +770,14 @@ def _pack_grad(self, objectives): return grads, shapes, has_grads def _unflatten_grad(self, grads, shapes): + """ + Overview: + unflatten the gradient of the parameters of the network + Arguments: + - grads (:obj:`list`): a list of the gradient of the parameters + - shapes (:obj:`list`): a list of the shape of the parameters + """ + unflatten_grad, idx = [], 0 for shape in shapes: length = np.prod(shape) @@ -705,17 +786,26 @@ def _unflatten_grad(self, grads, shapes): return unflatten_grad def _flatten_grad(self, grads, shapes): + """ + Overview: + flatten the gradient of the parameters of the network + Arguments: + - grads (:obj:`list`): a list of the gradient of the parameters + - shapes (:obj:`list`): a list of the shape of the parameters + """ + flatten_grad = torch.cat([g.flatten() for g in grads]) return flatten_grad def _retrieve_grad(self): - ''' - get the gradient of the parameters of the network with specific objective + """ + Overview: + get the gradient of the parameters of the network with specific objective Returns: - grad: a list of the gradient of the parameters - shape: a list of the shape of the parameters - has_grad: a list of mask represent whether the parameter has gradient - ''' + """ grad, shape, has_grad = [], [], [] for group in self._optim.param_groups: @@ -734,7 +824,7 @@ def _retrieve_grad(self): def configure_weight_decay(model: nn.Module, weight_decay: float) -> List: - r""" + """ Overview: Separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layer-norm or embedding weights). diff --git a/ding/torch_utils/parameter.py b/ding/torch_utils/parameter.py index 9126bf7fd3..08da7feb76 100644 --- a/ding/torch_utils/parameter.py +++ b/ding/torch_utils/parameter.py @@ -9,7 +9,7 @@ class NonegativeParameter(nn.Module): Overview: This module will output a non-negative parameter during the forward process. Interfaces: - __init__, forward, set_data. + ``__init__``, ``forward``, ``set_data``. """ def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8): @@ -29,7 +29,7 @@ def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = Tr def forward(self) -> torch.Tensor: """ - Overview: + Overview: Output the non-negative parameter during the forward process. Returns: parameter (:obj:`torch.Tensor`): The generated parameter. @@ -51,7 +51,7 @@ class TanhParameter(nn.Module): Overview: This module will output a tanh parameter during the forward process. Interfaces: - __init__, forward, set_data. + ``__init__``, ``forward``, ``set_data``. """ def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True): @@ -72,7 +72,7 @@ def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = Tr def forward(self) -> torch.Tensor: """ - Overview: + Overview: Output the tanh parameter during the forward process. Returns: parameter (:obj:`torch.Tensor`): The generated parameter. @@ -81,7 +81,7 @@ def forward(self) -> torch.Tensor: def set_data(self, data: torch.Tensor) -> None: """ - Overview: + Overview: Set the value of the tanh parameter. Arguments: data (:obj:`torch.Tensor`): The new value of the tanh parameter. diff --git a/ding/torch_utils/reshape_helper.py b/ding/torch_utils/reshape_helper.py index e9e3f60676..49a3e3b8d2 100644 --- a/ding/torch_utils/reshape_helper.py +++ b/ding/torch_utils/reshape_helper.py @@ -4,7 +4,7 @@ def fold_batch(x: Tensor, nonbatch_ndims: int = 1) -> Tuple[Tensor, Size]: - r""" + """ Overview: :math:`(T, B, X) \leftarrow (T*B, X)`\ Fold the first (ndim - nonbatch_ndims) dimensions of a tensor as batch dimension.\ @@ -39,7 +39,7 @@ def fold_batch(x: Tensor, nonbatch_ndims: int = 1) -> Tuple[Tensor, Size]: def unfold_batch(x: Tensor, batch_dims: Union[Size, Tuple]) -> Tensor: - r""" + """ Overview: Unfold the batch dimension of a tensor. @@ -62,7 +62,7 @@ def unfold_batch(x: Tensor, batch_dims: Union[Size, Tuple]) -> Tensor: def unsqueeze_repeat(x: Tensor, repeat_times: int, unsqueeze_dim: int = 0) -> Tensor: - r""" + """ Overview: Squeeze the tensor on `unsqueeze_dim` and then repeat in this dimension for `repeat_times` times.\ This is useful for preproprocessing the input to an model ensemble. diff --git a/ding/utils/autolog/data.py b/ding/utils/autolog/data.py index 9b960e4fa0..e611b97f43 100644 --- a/ding/utils/autolog/data.py +++ b/ding/utils/autolog/data.py @@ -10,8 +10,24 @@ class RangedData(metaclass=ABCMeta): + """ + Overview: + A data structure that can store data for a period of time. + Interfaces: + ``__init__``, ``append``, ``extend``, ``current``, ``history``, ``expire``, ``__bool__``, ``_get_time``. + Properties: + - expire (:obj:`float`): The expire time. + """ def __init__(self, expire: float, use_pickle: bool = False): + """ + Overview: + Initialize the RangedData object. + Arguments: + - expire (:obj:`float`): The expire time of the data. + - use_pickle (:obj:`bool`): Whether to use pickle to serialize the data. + """ + self.__expire = expire self.__use_pickle = use_pickle self.__check_expire() @@ -25,6 +41,11 @@ def __init__(self, expire: float, use_pickle: bool = False): self.__lock = Lock() def __check_expire(self): + """ + Overview: + Check the expire time. + """ + if isinstance(self.__expire, (int, float)): if self.__expire <= 0: raise ValueError( @@ -36,6 +57,13 @@ def __check_expire(self): ) def __registry_data_item(self, data: _Tp) -> int: + """ + Overview: + Registry the data item. + Arguments: + - data (:obj:`_Tp`): The data item. + """ + with self.__data_lock: self.__data_max_id += 1 if self.__use_pickle: @@ -46,6 +74,13 @@ def __registry_data_item(self, data: _Tp) -> int: return self.__data_max_id def __get_data_item(self, data_id: int) -> _Tp: + """ + Overview: + Get the data item. + Arguments: + - data_id (:obj:`int`): The data id. + """ + with self.__data_lock: if self.__use_pickle: return pickle.loads(self.__data_items[data_id]) @@ -53,10 +88,24 @@ def __get_data_item(self, data_id: int) -> _Tp: return self.__data_items[data_id] def __remove_data_item(self, data_id: int): + """ + Overview: + Remove the data item. + Arguments: + - data_id (:obj:`int`): The data id. + """ + with self.__data_lock: del self.__data_items[data_id] def __check_time(self, time_: float): + """ + Overview: + Check the time. + Arguments: + - time_ (:obj:`float`): The time. + """ + if self.__queue: _time, _ = self.__queue[-1] if time_ < _time: @@ -67,9 +116,22 @@ def __check_time(self, time_: float): ) def __append_item(self, time_: float, data: _Tp): + """ + Overview: + Append the data item. + Arguments: + - time_ (:obj:`float`): The time. + - data (:obj:`_Tp`): The data item. + """ + self.__queue.append((time_, self.__registry_data_item(data))) def __flush_history(self): + """ + Overview: + Flush the history data. + """ + _time = self._get_time() _limit_time = _time - self.__expire while self.__queue: @@ -85,11 +147,21 @@ def __flush_history(self): self.__last_item = (_head_time, _head_id) def __append(self, time_: float, data: _Tp): + """ + Overview: + Append the data. + """ + self.__check_time(time_) self.__append_item(time_, data) self.__flush_history() def __current(self): + """ + Overview: + Get the current data. + """ + if self.__queue: _tail_time, _tail_id = self.__queue.pop() self.__queue.append((_tail_time, _tail_id)) @@ -101,6 +173,11 @@ def __current(self): raise ValueError("This range is empty.") def __history_yield(self): + """ + Overview: + Yield the history data. + """ + _time = self._get_time() _limit_time = _time - self.__expire _latest_time, _latest_id = None, None @@ -117,9 +194,19 @@ def __history_yield(self): yield _time, self.__get_data_item(_latest_id) def __history(self): + """ + Overview: + Get the history data. + """ + return list(self.__history_yield()) def append(self, data: _Tp): + """ + Overview: + Append the data. + """ + with self.__lock: self.__flush_history() _time = self._get_time() @@ -127,6 +214,11 @@ def append(self, data: _Tp): return self def extend(self, iter_: Iterable[_Tp]): + """ + Overview: + Extend the data. + """ + with self.__lock: self.__flush_history() _time = self._get_time() @@ -135,40 +227,92 @@ def extend(self, iter_: Iterable[_Tp]): return self def current(self) -> _Tp: + """ + Overview: + Get the current data. + """ + with self.__lock: self.__flush_history() return self.__current() def history(self) -> List[Tuple[Union[int, float], _Tp]]: + """ + Overview: + Get the history data. + """ + with self.__lock: self.__flush_history() return self.__history() @property def expire(self) -> float: + """ + Overview: + Get the expire time. + """ + with self.__lock: self.__flush_history() return self.__expire def __bool__(self): + """ + Overview: + Check whether the range is empty. + """ + with self.__lock: self.__flush_history() return not not (self.__queue or self.__last_item) @abstractmethod def _get_time(self) -> float: + """ + Overview: + Get the current time. + """ + raise NotImplementedError class TimeRangedData(RangedData): + """ + Overview: + A data structure that can store data for a period of time. + Interfaces: + ``__init__``, ``_get_time``, ``append``, ``extend``, ``current``, ``history``, ``expire``, ``__bool__``. + Properties: + - time (:obj:`BaseTime`): The time. + - expire (:obj:`float`): The expire time. + """ def __init__(self, time_: BaseTime, expire: float): + """ + Overview: + Initialize the TimeRangedData object. + Arguments: + - time_ (:obj:`BaseTime`): The time. + - expire (:obj:`float`): The expire time. + """ + RangedData.__init__(self, expire) self.__time = time_ def _get_time(self) -> float: + """ + Overview: + Get the current time. + """ + return self.__time.time() @property def time(self): + """ + Overview: + Get the time. + """ + return self.__time diff --git a/ding/utils/autolog/model.py b/ding/utils/autolog/model.py index 041c710caf..5c58bb6544 100644 --- a/ding/utils/autolog/model.py +++ b/ding/utils/autolog/model.py @@ -11,6 +11,12 @@ class _LoggedModelMeta(ABCMeta): + """ + Overview: + Metaclass of LoggedModel, used to find all LoggedValue properties and register them. + Interfaces: + ``__init__`` + """ def __init__(cls, name: str, bases: tuple, namespace: dict): @@ -75,14 +81,24 @@ class LoggedModel(metaclass=_LoggedModelMeta): >>> print(ll.range_values['value'](TimeMode.ABSOLUTE)) # use absolute time >>> print(ll.avg['value']()) # average value of last 10 secs - Interface: - __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__ + Interfaces: + ``__init__``, ``time``, ``expire``, ``fixed_time``, ``current_time``, ``freeze``, ``unfreeze``, \ + ``register_attribute_value``, ``__getattr__``, ``get_property_attribute`` Property: - time, expire + - time (:obj:`BaseTime`): The time. + - expire (:obj:`float`): The expire time. """ def __init__(self, time_: _TimeObjectType, expire: _TimeType): + """ + Overview: + Initialize the LoggedModel object using the given arguments. + Arguments: + - time_ (:obj:`BaseTime`): The time. + - expire (:obj:`float`): The expire time. + """ + self.__time = time_ self.__time_proxy = TimeProxy(self.__time, frozen=False) self.__init_time = self.__time_proxy.time() @@ -96,12 +112,29 @@ def __init__(self, time_: _TimeObjectType, expire: _TimeType): @property def __properties(self) -> List[str]: + """ + Overview: + Get all property names. + """ + return getattr(self, _LOGGED_MODEL__PROPERTIES) def __get_property_ranged_data(self, name: str) -> TimeRangedData: + """ + Overview: + Get ranged data of one property. + Arguments: + - name (:obj:`str`): The property name. + """ + return getattr(self, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX + name) def __init_properties(self): + """ + Overview: + Initialize all properties. + """ + for name in self.__properties: setattr( self, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX + name, @@ -109,6 +142,12 @@ def __init_properties(self): ) def __get_range_values_func(self, name: str): + """ + Overview: + Get range_values function of one property. + Arguments: + - name (:obj:`str`): The property name. + """ def _func(mode: TimeMode = TimeMode.RELATIVE_LIFECYCLE): _current_time = self.__time_proxy.time() @@ -130,6 +169,11 @@ def _func(mode: TimeMode = TimeMode.RELATIVE_LIFECYCLE): return _func def __register_default_funcs(self): + """ + Overview: + Register default functions. + """ + for name in self.__properties: self.register_attribute_value('range_values', name, self.__get_range_values_func(name)) @@ -196,6 +240,10 @@ def register_attribute_value(self, attribute_name: str, property_name: str, valu """ Overview: Register a new attribute for one of the values. Example can be found in overview of class. + Arguments: + - attribute_name (:obj:`str`): name of attribute + - property_name (:obj:`str`): name of property + - value (:obj:`Any`): value of attribute """ self.__methods[attribute_name] = self.__methods.get(attribute_name, {}) self.__methods[attribute_name][property_name] = value @@ -210,7 +258,7 @@ def __getattr__(self, attribute_name: str) -> Any: Overview: Support all methods registered. - Args: + Arguments: attribute_name (str): name of attribute Return: diff --git a/ding/utils/autolog/time_ctl.py b/ding/utils/autolog/time_ctl.py index e350dd92e9..110753e4cf 100644 --- a/ding/utils/autolog/time_ctl.py +++ b/ding/utils/autolog/time_ctl.py @@ -9,6 +9,8 @@ class BaseTime(metaclass=ABCMeta): """ Overview: Abstract time interface + Interfaces: + ``time`` """ @abstractmethod @@ -27,7 +29,8 @@ class NaturalTime(BaseTime): """ Overview: Natural time object - + Interfaces: + ``__init__``, ``time`` Example: >>> from ding.utils.autolog.time_ctl import NaturalTime >>> time_ = NaturalTime() @@ -62,7 +65,8 @@ class TickTime(BaseTime): """ Overview: Tick time object - + Interfaces: + ``__init__``, ``step``, ``time`` Example: >>> from ding.utils.autolog.time_ctl import TickTime >>> time_ = TickTime() @@ -73,8 +77,8 @@ def __init__(self, init: int = 0): Overview: Constructor of TickTime - Args: - init (int, optional): init tick time, default is 1 + Arguments: + - init (:obj:`int`): initial time, default is 0 """ self.__tick_time = init @@ -83,11 +87,11 @@ def step(self, delta: int = 1) -> int: Overview Step the time forward for this TickTime - Args: - delta (int, optional): steps to step forward, default is 1 + Arguments: + - delta (:obj:`int`): steps to step forward, default is 1 Returns: - int: new time after stepping + - time (:obj:`int`): new time after stepping Example: >>> from ding.utils.autolog.time_ctl import TickTime @@ -128,7 +132,8 @@ class TimeProxy(BaseTime): Overview: Proxy of time object, it can freeze time, sometimes useful when reproducing. This object is thread-safe, and also freeze and unfreeze operation is strictly ordered. - + Interfaces: + ``__init__``, ``freeze``, ``unfreeze``, ``time``, ``current_time`` Example: >>> from ding.utils.autolog.time_ctl import TickTime, TimeProxy >>> tick_time_ = TickTime() @@ -150,10 +155,10 @@ def __init__(self, time_: BaseTime, frozen: bool = False, lock_type: LockContext Overview: Constructor for Time proxy - Args: - time_ (BaseTime): another time object it based on - frozen (bool, optional): this object will be frozen immediately if true, otherwise not, default is False - lock_type (LockContextType, optional): type of the lock, default is THREAD_LOCK + Arguments: + - time_ (:obj:`BaseTime`): another time object it based on + - frozen (:obj:`bool`): this object will be frozen immediately if true, otherwise not, default is False + - lock_type (:obj:`LockContextType`): type of the lock, default is THREAD_LOCK """ self.__time = time_ self.__current_time = self.__time.time() diff --git a/ding/utils/autolog/value.py b/ding/utils/autolog/value.py index 07b7609b33..98510a036a 100644 --- a/ding/utils/autolog/value.py +++ b/ding/utils/autolog/value.py @@ -11,22 +11,61 @@ class LoggedValue: This class's instances will be associated with their owner LoggedModel instance, all the LoggedValue of one LoggedModel will shared the only one time object (defined in time_ctl), so that timeline can be managed properly. + Interfaces: + ``__init__``, ``__get__``, ``__set__`` + Properties: + - __property_name (:obj:`str`): The name of the property. """ def __init__(self, type_: Type[_ValueType] = object): + """ + Overview: + Initialize the LoggedValue object. + Interfaces: + ``__init__`` + """ + self.__type = type_ @property def __property_name(self): + """ + Overview: + Get the name of the property. + """ + return getattr(self, _LOGGED_VALUE__PROPERTY_NAME) def __get_ranged_data(self, instance) -> TimeRangedData: + """ + Overview: + Get the ranged data. + Interfaces: + ``__get_ranged_data`` + """ + return getattr(instance, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX + self.__property_name) def __get__(self, instance, owner): + """ + Overview: + Get the value. + Arguments: + - instance (:obj:`LoggedModel`): The owner LoggedModel instance. + - owner (:obj:`type`): The owner class. + """ + return self.__get_ranged_data(instance).current() def __set__(self, instance, value: _ValueType): + """ + Overview: + Set the value. + Arguments: + - instance (:obj:`LoggedModel`): The owner LoggedModel instance. + - value (:obj:`_ValueType`): The value to set. + """ + if isinstance(value, self.__type): return self.__get_ranged_data(instance).append(value) else: diff --git a/ding/utils/collection_helper.py b/ding/utils/collection_helper.py index ea8fa715a8..7c8caed6b4 100644 --- a/ding/utils/collection_helper.py +++ b/ding/utils/collection_helper.py @@ -5,7 +5,7 @@ def iter_mapping(iter_: Iterable[_IterType], mapping: Callable[[_IterType], _IterTargetType]): - r""" + """ Overview: Map a list of iterable elements to input iteration callable Arguments: diff --git a/ding/utils/compression_helper.py b/ding/utils/compression_helper.py index 5cfa00de71..71eeef25b0 100644 --- a/ding/utils/compression_helper.py +++ b/ding/utils/compression_helper.py @@ -10,7 +10,7 @@ class CloudPickleWrapper: Overview: CloudPickleWrapper can be able to pickle more python object(e.g: an object with lambda expression). Interfaces: - __init__. + ``__init__``, ``__getstate__``, ``__setstate__``. """ def __init__(self, data: Any) -> None: @@ -23,9 +23,23 @@ def __init__(self, data: Any) -> None: self.data = data def __getstate__(self) -> bytes: + """ + Overview: + Get the state of the CloudPickleWrapper. + Returns: + - data (:obj:`bytes`): The dumped byte-like result. + """ + return cloudpickle.dumps(self.data) def __setstate__(self, data: bytes) -> None: + """ + Overview: + Set the state of the CloudPickleWrapper. + Arguments: + - data (:obj:`bytes`): The dumped byte-like result. + """ + if isinstance(data, (tuple, list, np.ndarray)): # pickle is faster self.data = pickle.loads(data) else: @@ -60,7 +74,7 @@ def zlib_data_compressor(data: Any) -> bytes: def lz4_data_compressor(data: Any) -> bytes: - r""" + """ Overview: Return the compressed original data (lz4 compressor).The compressor outputs in binary format. Arguments: diff --git a/ding/utils/data/base_dataloader.py b/ding/utils/data/base_dataloader.py index da93d8f24d..d19bd9fcde 100644 --- a/ding/utils/data/base_dataloader.py +++ b/ding/utils/data/base_dataloader.py @@ -4,7 +4,10 @@ def example_get_data_fn() -> Any: """ - Note: staticmethod or static function, all the operation is on CPU + Overview: + Get data from file or other middleware + .. note:: + staticmethod or static function, all the operation is on CPU """ # 1. read data from file or other middleware # 2. data post-processing(e.g.: normalization, to tensor) @@ -13,12 +16,20 @@ def example_get_data_fn() -> Any: class IDataLoader: + """ + Overview: + Base class of data loader + Interfaces: + ``__init__``, ``__next__``, ``__iter__``, ``_get_data``, ``close`` + """ def __next__(self, batch_size: Optional[int] = None) -> torch.Tensor: """ + Overview: + Get one batch data Arguments: - batch_size: sometimes, batch_size is specified by each iteration, if batch_size is None, - use default batch_size value + - batch_size (:obj:`Optional[int]`): sometimes, batch_size is specified by each iteration, \ + if batch_size is None, use default batch_size value """ # get one batch train data if batch_size is None: @@ -27,11 +38,29 @@ def __next__(self, batch_size: Optional[int] = None) -> torch.Tensor: return self._collate_fn(data) def __iter__(self) -> Iterable: + """ + Overview: + Get data iterator + """ + return self def _get_data(self, batch_size: Optional[int] = None) -> List[torch.Tensor]: + """ + Overview: + Get one batch data + Arguments: + - batch_size (:obj:`Optional[int]`): sometimes, batch_size is specified by each iteration, \ + if batch_size is None, use default batch_size value + """ + raise NotImplementedError def close(self) -> None: + """ + Overview: + Close data loader + """ + # release resource pass diff --git a/ding/utils/data/collate_fn.py b/ding/utils/data/collate_fn.py index 70afcd18aa..5397a9c450 100644 --- a/ding/utils/data/collate_fn.py +++ b/ding/utils/data/collate_fn.py @@ -176,7 +176,7 @@ def timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Union[torch.Tenso Each timestepped data field is represented as a tensor with shape [T, B, any_dims], where T is the length \ of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep. - Args: + Arguments: - batch(:obj:`List[Dict[str, Any]]`): A list of dictionaries with length B, where each dictionary represents \ a timestepped data field. Each dictionary contains a key-value pair, where the key is the name of the \ data field and the value is a sequence of torch.Tensor objects with any shape. diff --git a/ding/utils/data/dataloader.py b/ding/utils/data/dataloader.py index 0670d0db7f..0b6ffeec83 100644 --- a/ding/utils/data/dataloader.py +++ b/ding/utils/data/dataloader.py @@ -13,11 +13,12 @@ class AsyncDataLoader(IDataLoader): - r""" + """ Overview: An asynchronous dataloader. - Interface: - __init__, __iter__, __next__, close + Interfaces: + ``__init__``, ``__iter__``, ``__next__``, ``_get_data``, ``_async_loop``, ``_worker_loop``, ``_cuda_loop``, \ + ``_get_data``, ``close`` """ def __init__( @@ -338,6 +339,10 @@ def __next__(self) -> Any: self.async_train_queue.join_thread() def __del__(self) -> None: + """ + Overview: + Delete this dataloader. + """ self.close() def close(self) -> None: diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index c38d0e0a8a..ad9c645bae 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -17,6 +17,10 @@ @dataclass class DatasetStatistics: + """ + Overview: + Dataset statistics. + """ mean: np.ndarray # obs std: np.ndarray # obs action_bounds: np.ndarray @@ -24,8 +28,21 @@ class DatasetStatistics: @DATASET_REGISTRY.register('naive') class NaiveRLDataset(Dataset): + """ + Overview: + Naive RL dataset, which is used for offline RL algorithms. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ def __init__(self, cfg) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`dict`): Config dict. + """ + assert type(cfg) in [str, EasyDict], "invalid cfg type: {}".format(type(cfg)) if isinstance(cfg, EasyDict): self._data_path = cfg.policy.collect.data_path @@ -35,16 +52,44 @@ def __init__(self, cfg) -> None: self._data: List[Dict[str, torch.Tensor]] = pickle.load(f) def __len__(self) -> int: + """ + Overview: + Get the length of the dataset. + """ + return len(self._data) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Overview: + Get the item of the dataset. + """ + return self._data[idx] @DATASET_REGISTRY.register('d4rl') class D4RLDataset(Dataset): + """ + Overview: + D4RL dataset, which is used for offline RL algorithms. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + Properties: + - mean (:obj:`np.ndarray`): Mean of the dataset. + - std (:obj:`np.ndarray`): Std of the dataset. + - action_bounds (:obj:`np.ndarray`): Action bounds of the dataset. + - statistics (:obj:`dict`): Statistics of the dataset. + """ def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`dict`): Config dict. + """ + import gym try: import d4rl # register d4rl enviroments with open ai gym @@ -73,12 +118,29 @@ def __init__(self, cfg: dict) -> None: self._load_d4rl(dataset) def __len__(self) -> int: + """ + Overview: + Get the length of the dataset. + """ + return len(self._data) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Overview: + Get the item of the dataset. + """ + return self._data[idx] def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None: + """ + Overview: + Load the d4rl dataset. + Arguments: + - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. + """ + for i in range(len(dataset['observations'])): trans_data = {} trans_data['obs'] = torch.from_numpy(dataset['observations'][i]) @@ -89,6 +151,15 @@ def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None: self._data.append(trans_data) def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True): + """ + Overview: + Calculate the statistics of the dataset. + Arguments: + - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. + - env (:obj:`gym.Env`): The environment. + - eps (:obj:`float`): Epsilon. + """ + self._mean = dataset['observations'].mean(0) self._std = dataset['observations'].std(0) + eps action_max = dataset['actions'].max(0) @@ -100,31 +171,78 @@ def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True): self._action_bounds = np.stack([action_min, action_max], axis=0) def _normalize_states(self, dataset): + """ + Overview: + Normalize the states. + Arguments: + - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. + """ + dataset['observations'] = (dataset['observations'] - self._mean) / self._std dataset['next_observations'] = (dataset['next_observations'] - self._mean) / self._std return dataset @property def mean(self): + """ + Overview: + Get the mean of the dataset. + """ + return self._mean @property def std(self): + """ + Overview: + Get the std of the dataset. + """ + return self._std @property def action_bounds(self) -> np.ndarray: + """ + Overview: + Get the action bounds of the dataset. + """ + return self._action_bounds @property def statistics(self) -> dict: + """ + Overview: + Get the statistics of the dataset. + """ + return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds) @DATASET_REGISTRY.register('hdf5') class HDF5Dataset(Dataset): + """ + Overview: + HDF5 dataset is saved in hdf5 format, which is used for offline RL algorithms. + The hdf5 format is a common format for storing large numerical arrays in Python. + For more details, please refer to https://support.hdfgroup.org/HDF5/. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + Properties: + - mean (:obj:`np.ndarray`): Mean of the dataset. + - std (:obj:`np.ndarray`): Std of the dataset. + - action_bounds (:obj:`np.ndarray`): Action bounds of the dataset. + - statistics (:obj:`dict`): Statistics of the dataset. + """ def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`dict`): Config dict. + """ + try: import h5py except ImportError: @@ -147,9 +265,21 @@ def __init__(self, cfg: dict) -> None: pass def __len__(self) -> int: + """ + Overview: + Get the length of the dataset. + """ + return len(self._data['obs']) - self.context_len def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """ + Overview: + Get the item of the dataset. + Arguments: + - idx (:obj:`int`): The index of the dataset. + """ + if self.context_len == 0: # for other offline RL algorithms return {k: self._data[k][idx] for k in self._data.keys()} else: # for decision transformer @@ -166,12 +296,26 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: return timesteps, states, actions, rtgs, traj_mask def _load_data(self, dataset: Dict[str, np.ndarray]) -> None: + """ + Overview: + Load the dataset. + Arguments: + - dataset (:obj:`Dict[str, np.ndarray]`): The dataset. + """ + self._data = {} for k in dataset.keys(): logging.info(f'Load {k} data.') self._data[k] = dataset[k][:] - def _cal_statistics(self, eps=1e-3): + def _cal_statistics(self, eps: float = 1e-3): + """ + Overview: + Calculate the statistics of the dataset. + Arguments: + - eps (:obj:`float`): Epsilon. + """ + self._mean = self._data['obs'].mean(0) self._std = self._data['obs'].std(0) + eps action_max = self._data['action'].max(0) @@ -182,28 +326,59 @@ def _cal_statistics(self, eps=1e-3): self._action_bounds = np.stack([action_min, action_max], axis=0) def _normalize_states(self): + """ + Overview: + Normalize the states. + """ + self._data['obs'] = (self._data['obs'] - self._mean) / self._std self._data['next_obs'] = (self._data['next_obs'] - self._mean) / self._std @property def mean(self): + """ + Overview: + Get the mean of the dataset. + """ + return self._mean @property def std(self): + """ + Overview: + Get the std of the dataset. + """ + return self._std @property def action_bounds(self) -> np.ndarray: + """ + Overview: + Get the action bounds of the dataset. + """ + return self._action_bounds @property def statistics(self) -> dict: + """ + Overview: + Get the statistics of the dataset. + """ + return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds) @DATASET_REGISTRY.register('d4rl_trajectory') class D4RLTrajectoryDataset(Dataset): + """ + Overview: + D4RL trajectory dataset, which is used for offline RL algorithms. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ # from infos.py from official d4rl github repo REF_MIN_SCORE = { @@ -346,6 +521,13 @@ class D4RLTrajectoryDataset(Dataset): } def __init__(self, cfg: dict) -> None: + """ + Overview: + Initialization method. + Arguments: + - cfg (:obj:`dict`): Config dict. + """ + dataset_path = cfg.dataset.data_dir_prefix rtg_scale = cfg.dataset.rtg_scale self.context_len = cfg.dataset.context_len @@ -549,21 +731,50 @@ def __init__(self, cfg: dict) -> None: # return obss, actions, returns, done_idxs, rtg, timesteps def get_max_timestep(self) -> int: + """ + Overview: + Get the max timestep of the dataset. + """ + return max(self.timesteps) def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]: + """ + Overview: + Get the state mean and std of the dataset. + """ + return deepcopy(self.state_mean), deepcopy(self.state_std) def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]: + """ + Overview: + Get the d4rl dataset stats. + Arguments: + - env_d4rl_name (:obj:`str`): The d4rl env name. + """ + return self.D4RL_DATASET_STATS[env_d4rl_name] def __len__(self) -> int: + """ + Overview: + Get the length of the dataset. + """ + if self.env_type != 'atari': return len(self.trajectories) else: return len(self.obss) - self.context_len def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Overview: + Get the item of the dataset. + Arguments: + - idx (:obj:`int`): The index of the dataset. + """ + if self.env_type != 'atari': traj = self.trajectories[idx] traj_len = traj['observations'].shape[0] @@ -631,8 +842,22 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso @DATASET_REGISTRY.register('d4rl_diffuser') class D4RLDiffuserDataset(Dataset): + """ + Overview: + D4RL diffuser dataset, which is used for offline RL algorithms. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None: + """ + Overview: + Initialization method of D4RLDiffuserDataset. + Arguments: + - dataset_path (:obj:`str`): The dataset path. + - context_len (:obj:`int`): The length of the context. + - rtg_scale (:obj:`float`): The scale of the returns to go. + """ self.context_len = context_len @@ -677,17 +902,26 @@ def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> Non class FixedReplayBuffer(object): - """Object composed of a list of OutofGraphReplayBuffers.""" + """ + Overview: + Object composed of a list of OutofGraphReplayBuffers. + Interfaces: + ``__init__``, ``get_transition_elements``, ``sample_transition_batch`` + """ def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - """Initialize the FixedReplayBuffer class. - Args: - data_dir: str, log Directory from which to load the replay buffer. - replay_suffix: int, If not None, then only load the replay buffer - corresponding to the specific suffix in data directory. - *args: Arbitrary extra arguments. - **kwargs: Arbitrary keyword arguments. - """ + """ + Overview: + Initialize the FixedReplayBuffer class. + Arguments: + - data_dir (:obj:`str`): log Directory from which to load the replay buffer. + - replay_suffix (:obj:`int`): If not None, then only load the replay buffer \ + corresponding to the specific suffix in data directory. + - args (:obj:`list`): Arbitrary extra arguments. + - kwargs (:obj:`dict`): Arbitrary keyword arguments. + + """ + self._args = args self._kwargs = kwargs self._data_dir = data_dir @@ -703,7 +937,13 @@ def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable # self._load_replay_buffers(num_buffers=50) def load_single_buffer(self, suffix): - """Load a single replay buffer.""" + """ + Overview: + Load a single replay buffer. + Arguments: + - suffix (:obj:`int`): The suffix of the replay buffer. + """ + replay_buffer = self._load_buffer(suffix) if replay_buffer is not None: self._replay_buffers = [replay_buffer] @@ -712,7 +952,13 @@ def load_single_buffer(self, suffix): self._loaded_buffers = True def _load_buffer(self, suffix): - """Loads a OutOfGraphReplayBuffer replay buffer.""" + """ + Overview: + Loads a OutOfGraphReplayBuffer replay buffer. + Arguments: + - suffix (:obj:`int`): The suffix of the replay buffer. + """ + try: from dopamine.replay_memory import circular_replay_buffer STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX @@ -726,26 +972,72 @@ def _load_buffer(self, suffix): raise ('can not load') def get_transition_elements(self): + """ + Overview: + Returns the transition elements. + """ + return self._replay_buffers[0].get_transition_elements() def sample_transition_batch(self, batch_size=None, indices=None): + """ + Overview: + Returns a batch of transitions (including any extra contents). + Arguments: + - batch_size (:obj:`int`): The batch size. + - indices (:obj:`list`): The indices of the batch. + """ + buffer_index = np.random.randint(self._num_replay_buffers) return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices) class PCDataset(Dataset): + """ + Overview: + Dataset for Procedure Cloning. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ def __init__(self, all_data): + """ + Overview: + Initialization method of PCDataset. + Arguments: + - all_data (:obj:`tuple`): The tuple of all data. + """ + self._data = all_data def __getitem__(self, item): + """ + Overview: + Get the item of the dataset. + Arguments: + - item (:obj:`int`): The index of the dataset. + """ + return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]} def __len__(self): + """ + Overview: + Get the length of the dataset. + """ + return self._data[0].shape[0] def load_bfs_datasets(train_seeds=1, test_seeds=5): + """ + Overview: + Load BFS datasets. + Arguments: + - train_seeds (:obj:`int`): The number of train seeds. + - test_seeds (:obj:`int`): The number of test seeds. + """ + from dizoo.maze.envs import Maze def load_env(seed): @@ -807,32 +1099,83 @@ def load_env(seed): @DATASET_REGISTRY.register('bco') class BCODataset(Dataset): + """ + Overview: + Dataset for Behavioral Cloning from Observation. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + Properties: + - obs (:obj:`np.ndarray`): The observation array. + - action (:obj:`np.ndarray`): The action array. + """ def __init__(self, data=None): + """ + Overview: + Initialization method of BCODataset. + Arguments: + - data (:obj:`dict`): The data dict. + """ + if data is None: raise ValueError('Dataset can not be empty!') else: self._data = data def __len__(self): + """ + Overview: + Get the length of the dataset. + """ + return len(self._data['obs']) def __getitem__(self, idx): + """ + Overview: + Get the item of the dataset. + Arguments: + - idx (:obj:`int`): The index of the dataset. + """ + return {k: self._data[k][idx] for k in self._data.keys()} @property def obs(self): + """ + Overview: + Get the observation array. + """ + return self._data['obs'] @property def action(self): + """ + Overview: + Get the action array. + """ + return self._data['action'] @DATASET_REGISTRY.register('diffuser_traj') class SequenceDataset(torch.utils.data.Dataset): + """ + Overview: + Dataset for diffuser. + Interfaces: + ``__init__``, ``__len__``, ``__getitem__`` + """ def __init__(self, cfg): + """ + Overview: + Initialization method of SequenceDataset. + Arguments: + - cfg (:obj:`dict`): The config dict. + """ + import gym env_id = cfg.env.env_id @@ -896,6 +1239,13 @@ def __init__(self, cfg): # print(f'[ datasets/mujoco ] Dataset fields: {shapes}') def sequence_dataset(self, env, dataset=None): + """ + Overview: + Sequence the dataset. + Arguments: + - env (:obj:`gym.Env`): The gym env. + """ + import collections N = dataset['rewards'].shape[0] if 'maze2d' in env.spec.id: @@ -932,6 +1282,14 @@ def sequence_dataset(self, env, dataset=None): episode_step += 1 def maze2d_set_terminals(self, env, dataset): + """ + Overview: + Set the terminals for maze2d. + Arguments: + - env (:obj:`gym.Env`): The gym env. + - dataset (:obj:`dict`): The dataset dict. + """ + goal = env.get_target() threshold = 0.5 @@ -957,9 +1315,13 @@ def maze2d_set_terminals(self, env, dataset): return dataset def process_maze2d_episode(self, episode): - ''' - adds in `next_observations` field to episode - ''' + """ + Overview: + Process the maze2d episode, adds in `next_observations` field to episode. + Arguments: + - episode (:obj:`dict`): The episode dict. + """ + assert 'next_observations' not in episode length = len(episode['observations']) next_observations = episode['observations'][1:].copy() @@ -969,19 +1331,27 @@ def process_maze2d_episode(self, episode): return episode def normalize(self, keys=['observations', 'actions']): - ''' - normalize fields that will be predicted by the diffusion model - ''' + """ + Overview: + Normalize the dataset, normalize fields that will be predicted by the diffusion model + Arguments: + - keys (:obj:`list`): The list of keys. + """ + for key in keys: array = self.fields[key].reshape(self.n_episodes * self.max_path_length, -1) normed = self.normalizer.normalize(array, key) self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1) def make_indices(self, path_lengths, horizon): - ''' - makes indices for sampling from dataset; - each index maps to a datapoint - ''' + """ + Overview: + Make indices for sampling from dataset. Each index maps to a datapoint. + Arguments: + - path_lengths (:obj:`np.ndarray`): The path length array. + - horizon (:obj:`int`): The horizon. + """ + indices = [] for i, path_length in enumerate(path_lengths): max_start = min(path_length - 1, self.max_path_length - horizon) @@ -994,18 +1364,32 @@ def make_indices(self, path_lengths, horizon): return indices def get_conditions(self, observations): - ''' - condition on current observation for planning - ''' + """ + Overview: + Get the conditions on current observation for planning. + Arguments: + - observations (:obj:`np.ndarray`): The observation array. + """ + if 'maze2d' in self.env_id: return {'condition_id': [0, self.horizon - 1], 'condition_val': [observations[0], observations[-1]]} else: return {'condition_id': [0], 'condition_val': [observations[0]]} def __len__(self): + """ + Overview: + Get the length of the dataset. + """ + return len(self.indices) def _get_bounds(self): + """ + Overview: + Get the bounds of the dataset. + """ + print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True) vmin = np.inf vmax = -np.inf @@ -1017,6 +1401,13 @@ def _get_bounds(self): return vmin, vmax def normalize_value(self, value): + """ + Overview: + Normalize the value. + Arguments: + - value (:obj:`np.ndarray`): The value array. + """ + # [0, 1] normed = (value - self.vmin) / (self.vmax - self.vmin) # [-1, 1] @@ -1024,6 +1415,14 @@ def normalize_value(self, value): return normed def __getitem__(self, idx, eps=1e-4): + """ + Overview: + Get the item of the dataset. + Arguments: + - idx (:obj:`int`): The index of the dataset. + - eps (:obj:`float`): The epsilon. + """ + path_ind, start, end = self.indices[idx] observations = self.fields['normed_observations'][path_ind, start:end] @@ -1058,6 +1457,11 @@ def __getitem__(self, idx, eps=1e-4): def hdf5_save(exp_data, expert_data_path): + """ + Overview: + Save the data to hdf5. + """ + try: import h5py except ImportError: @@ -1073,15 +1477,30 @@ def hdf5_save(exp_data, expert_data_path): def naive_save(exp_data, expert_data_path): + """ + Overview: + Save the data to pickle. + """ + with open(expert_data_path, 'wb') as f: pickle.dump(exp_data, f) def offline_data_save_type(exp_data, expert_data_path, data_type='naive'): + """ + Overview: + Save the offline data. + """ + globals()[data_type + '_save'](exp_data, expert_data_path) def create_dataset(cfg, **kwargs) -> Dataset: + """ + Overview: + Create dataset. + """ + cfg = EasyDict(cfg) import_module(cfg.get('import_names', [])) return DATASET_REGISTRY.build(cfg.policy.collect.data_type, cfg=cfg, **kwargs) diff --git a/ding/utils/data/structure/cache.py b/ding/utils/data/structure/cache.py index a7235220cc..836261e615 100644 --- a/ding/utils/data/structure/cache.py +++ b/ding/utils/data/structure/cache.py @@ -7,17 +7,17 @@ class Cache: - r""" + """ Overview: Data cache for reducing concurrent pressure, with timeout and full queue eject mechanism - Interface: - __init__, push_data, get_cached_data_iter, run, close + Interfaces: + ``__init__``, ``push_data``, ``get_cached_data_iter``, ``run``, ``close`` Property: remain_data_count """ def __init__(self, maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False) -> None: - r""" + """ Overview: Initialize the cache object. Arguments: @@ -40,7 +40,7 @@ def __init__(self, maxlen: int, timeout: float, monitor_interval: float = 1.0, _ self._timeout_thread_flag = True def push_data(self, data: Any) -> None: - r""" + """ Overview: Push data into receive queue, if the receive queue is full(after push), then push all the data in receive queue into send queue. @@ -60,7 +60,7 @@ def push_data(self, data: Any) -> None: self.send_queue.put(self.receive_queue.get()[0]) def get_cached_data_iter(self) -> 'callable_iterator': # noqa - r""" + """ Overview: Get the iterator of the send queue. Once a data is pushed into send queue, it can be accessed by this iterator. 'STOP' is the end flag of this iterator. @@ -70,7 +70,7 @@ def get_cached_data_iter(self) -> 'callable_iterator': # noqa return iter(self.send_queue.get, 'STOP') def _timeout_monitor(self) -> None: - r""" + """ Overview: The workflow of the timeout monitor thread. """ @@ -88,7 +88,7 @@ def _timeout_monitor(self) -> None: break def _warn_if_timeout(self) -> bool: - r""" + """ Overview: Return whether is timeout. Returns @@ -107,14 +107,14 @@ def _warn_if_timeout(self) -> bool: return False def run(self) -> None: - r""" + """ Overview: Launch the cache internal thread, e.g. timeout monitor thread. """ self._timeout_thread.start() def close(self) -> None: - r""" + """ Overview: Shut down the cache internal thread and send the end flag to send queue's iterator. """ @@ -122,7 +122,7 @@ def close(self) -> None: self.send_queue.put('STOP') def dprint(self, s: str) -> None: - r""" + """ Overview: In debug mode, print debug str. Arguments: @@ -133,7 +133,7 @@ def dprint(self, s: str) -> None: @property def remain_data_count(self) -> int: - r""" + """ Overview: Return receive queue's remain data count Returns: diff --git a/ding/utils/data/structure/lifo_deque.py b/ding/utils/data/structure/lifo_deque.py index 00d9221e5c..b18c4a0608 100644 --- a/ding/utils/data/structure/lifo_deque.py +++ b/ding/utils/data/structure/lifo_deque.py @@ -4,7 +4,10 @@ class LifoDeque(LifoQueue): """ - Like LifoQueue, but automatically replaces the oldest data when the queue is full. + Overview: + Like LifoQueue, but automatically replaces the oldest data when the queue is full. + Interfaces: + ``_init``, ``_put``, ``_get`` """ def _init(self, maxsize): diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py index d76b6936f3..1881ca6cc0 100644 --- a/ding/utils/default_helper.py +++ b/ding/utils/default_helper.py @@ -42,7 +42,7 @@ def lists_to_dicts( data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]], recursive: bool = False, ) -> Union[Mapping[object, object], NamedTuple]: - r""" + """ Overview: Transform a list of dicts to a dict of lists. Arguments: @@ -77,7 +77,7 @@ def lists_to_dicts( def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]: - r""" + """ Overview: Transform a dict of lists to a list of dicts. @@ -121,6 +121,8 @@ def squeeze(data: object) -> object: """ Overview: Squeeze data from tuple, list or dict to single object + Arguments: + - data (:obj:`object`): data to be squeezed Example: >>> a = (4, ) >>> a = squeeze(a) @@ -148,7 +150,7 @@ def default_get( default_fn: Optional[Callable] = None, judge_fn: Optional[Callable] = None ) -> Any: - r""" + """ Overview: Getting the value by input, checks generically on the inputs with \ at least ``data`` and ``name``. If ``name`` exists in ``data``, \ @@ -180,7 +182,7 @@ def default_get( def list_split(data: list, step: int) -> List[list]: - r""" + """ Overview: Split list of data by step. Arguments: @@ -210,7 +212,7 @@ def list_split(data: list, step: int) -> List[list]: def error_wrapper(fn, default_ret, warning_msg=""): - r""" + """ Overview: wrap the function, so that any Exception in the function will be catched and return the default_ret Arguments: @@ -239,10 +241,10 @@ def wrapper(*args, **kwargs): class LimitedSpaceContainer: - r""" + """ Overview: A space simulator. - Interface: + Interfaces: ``__init__``, ``get_residual_space``, ``release_space`` """ @@ -438,10 +440,27 @@ def set_pkg_seed(seed: int, use_cuda: bool = True) -> None: @lru_cache() def one_time_warning(warning_msg: str) -> None: + """ + Overview: + Print warning message only once. + Arguments: + - warning_msg (:obj:`str`): Warning message. + """ + logging.warning(warning_msg) def split_fn(data, indices, start, end): + """ + Overview: + Split data by indices + Arguments: + - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): data to be analysed + - indices (:obj:`np.ndarray`): indices to split + - start (:obj:`int`): start index + - end (:obj:`int`): end index + """ + if data is None: return None elif isinstance(data, list): @@ -455,6 +474,15 @@ def split_fn(data, indices, start, end): def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict: + """ + Overview: + Split data into batches + Arguments: + - data (:obj:`dict`): data to be analysed + - split_size (:obj:`int`): split size + - shuffle (:obj:`bool`): whether shuffle + """ + assert isinstance(data, dict), type(data) length = [] for k, v in data.items(): @@ -493,7 +521,7 @@ class RunningMeanStd(object): """ Overview: Wrapper to update new variable, new mean, and new count - Interface: + Interfaces: ``__init__``, ``update``, ``reset``, ``new_shape`` Properties: - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count`` diff --git a/ding/utils/design_helper.py b/ding/utils/design_helper.py index 0850728d1a..24805218ff 100644 --- a/ding/utils/design_helper.py +++ b/ding/utils/design_helper.py @@ -4,15 +4,20 @@ # ABCMeta is a subclass of type, extending ABCMeta makes this metaclass is compatible with some classes # which extends ABC class SingletonMetaclass(ABCMeta): - r""" + """ Overview: Returns the given type instance in input class - Interface: + Interfaces: ``__call__`` """ instances = {} def __call__(cls: type, *args, **kwargs) -> object: + """ + Overview: + Returns the given type instance in input class + """ + if cls not in SingletonMetaclass.instances: SingletonMetaclass.instances[cls] = super(SingletonMetaclass, cls).__call__(*args, **kwargs) cls.instance = SingletonMetaclass.instances[cls] diff --git a/ding/utils/fake_linklink.py b/ding/utils/fake_linklink.py index a8faf69e0a..5998030b36 100644 --- a/ding/utils/fake_linklink.py +++ b/ding/utils/fake_linklink.py @@ -2,16 +2,30 @@ class FakeClass: + """ + Overview: + Fake class. + """ def __init__(self, *args, **kwargs): pass class FakeNN: + """ + Overview: + Fake nn class. + """ + SyncBatchNorm2d = FakeClass class FakeLink: + """ + Overview: + Fake link class. + """ + nn = FakeNN() syncbnVarMode_t = namedtuple("syncbnVarMode_t", "L2")(L2=None) allreduceOp_t = namedtuple("allreduceOp_t", ['Sum', 'Max']) diff --git a/ding/utils/fast_copy.py b/ding/utils/fast_copy.py index 0f1d53fde6..cf4185ecbd 100644 --- a/ding/utils/fast_copy.py +++ b/ding/utils/fast_copy.py @@ -5,13 +5,21 @@ class _FastCopy: """ - The idea of this class comes from this article - https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list. - We use recursive calls to copy each object that needs to be copied, which will be 5x faster - than copy.deepcopy. + Overview: + The idea of this class comes from this article \ + https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list. + We use recursive calls to copy each object that needs to be copied, which will be 5x faster \ + than copy.deepcopy. + Interfaces: + ``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``. """ def __init__(self): + """ + Overview: + Initialize the _FastCopy object. + """ + dispatch = {} dispatch[list] = self._copy_list dispatch[dict] = self._copy_dict @@ -20,6 +28,13 @@ def __init__(self): self.dispatch = dispatch def _copy_list(self, l: List) -> dict: + """ + Overview: + Copy the list. + Arguments: + - l (:obj:`List`): The list to be copied. + """ + ret = l.copy() for idx, item in enumerate(ret): cp = self.dispatch.get(type(item)) @@ -28,6 +43,13 @@ def _copy_list(self, l: List) -> dict: return ret def _copy_dict(self, d: dict) -> dict: + """ + Overview: + Copy the dict. + Arguments: + - d (:obj:`dict`): The dict to be copied. + """ + ret = d.copy() for key, value in ret.items(): cp = self.dispatch.get(type(value)) @@ -37,12 +59,33 @@ def _copy_dict(self, d: dict) -> dict: return ret def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor: + """ + Overview: + Copy the tensor. + Arguments: + - t (:obj:`torch.Tensor`): The tensor to be copied. + """ + return t.clone() def _copy_ndarray(self, a: np.ndarray) -> np.ndarray: + """ + Overview: + Copy the ndarray. + Arguments: + - a (:obj:`np.ndarray`): The ndarray to be copied. + """ + return np.copy(a) def copy(self, sth: Any) -> Any: + """ + Overview: + Copy the object. + Arguments: + - sth (:obj:`Any`): The object to be copied. + """ + cp = self.dispatch.get(type(sth)) if cp is None: return sth diff --git a/ding/utils/file_helper.py b/ding/utils/file_helper.py index 94a70d0925..b14c42de79 100644 --- a/ding/utils/file_helper.py +++ b/ding/utils/file_helper.py @@ -266,7 +266,7 @@ def save_file_rediscluster(path, data): def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object: - r""" + """ Overview: Read file from path Arguments: @@ -296,7 +296,7 @@ def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = Fals def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None: - r""" + """ Overview: Save data to file of path Arguments: @@ -327,7 +327,7 @@ def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_loc def remove_file(path: str, fs_type: Union[None, str] = None) -> None: - r""" + """ Overview: Remove file Arguments: diff --git a/ding/utils/k8s_helper.py b/ding/utils/k8s_helper.py index 9a3572a729..e30bba497d 100644 --- a/ding/utils/k8s_helper.py +++ b/ding/utils/k8s_helper.py @@ -20,14 +20,15 @@ def get_operator_server_kwargs(cfg: EasyDict) -> dict: - r''' + """ Overview: Get kwarg dict from config file Arguments: - cfg (:obj:`EasyDict`) System config Returns: - result (:obj:`dict`) Containing ``api_version``, ``namespace``, ``name``, ``port``, ``host``. - ''' + """ + namespace = os.environ.get('KUBERNETES_POD_NAMESPACE', DEFAULT_NAMESPACE) name = os.environ.get('KUBERNETES_POD_NAME', DEFAULT_POD_NAME) url = cfg.get('system_addr', None) or os.environ.get('KUBERNETES_SERVER_URL', None) @@ -49,10 +50,24 @@ def get_operator_server_kwargs(cfg: EasyDict) -> dict: def exist_operator_server() -> bool: + """ + Overview: + Check if the 'KUBERNETES_SERVER_URL' environment variable exists. + """ + return 'KUBERNETES_SERVER_URL' in os.environ def pod_exec_command(kubeconfig: str, name: str, namespace: str, cmd: str) -> Tuple[int, str]: + """ + Overview: + Execute command in pod + Arguments: + - kubeconfig (:obj:`str`) The path of kubeconfig file + - name (:obj:`str`) The name of pod + - namespace (:obj:`str`) The namespace of pod + """ + try: from kubernetes import config from kubernetes.client import CoreV1Api @@ -102,10 +117,20 @@ class K8sType(Enum): class K8sLauncher(object): """ - Overview: object to manage the K8s cluster + Overview: + object to manage the K8s cluster + Interfaces: + ``__init__``, ``_load``, ``create_cluster``, ``_check_k3d_tools``, ``delete_cluster``, ``preload_images`` """ def __init__(self, config_path: str) -> None: + """ + Overview: + Initialize the K8sLauncher object. + Arguments: + - config_path (:obj:`str`): The path of the config file. + """ + self.name = None self.servers = 1 self.agents = 0 @@ -116,6 +141,13 @@ def __init__(self, config_path: str) -> None: self._check_k3d_tools() def _load(self, config_path: str) -> None: + """ + Overview: + Load the config file. + Arguments: + - config_path (:obj:`str`): The path of the config file. + """ + with open(config_path, 'r') as f: data = yaml.safe_load(f) self.name = data.get('name') if data.get('name') else self.name @@ -140,6 +172,11 @@ def _load(self, config_path: str) -> None: self._images = data.get('preload_images') def _check_k3d_tools(self) -> None: + """ + Overview: + Check if the k3d tools exist. + """ + if self.type != K8sType.K3s: return args = ['which', 'k3d'] @@ -151,6 +188,11 @@ def _check_k3d_tools(self) -> None: ) def create_cluster(self) -> None: + """ + Overview: + Create the k8s cluster. + """ + print('Creating k8s cluster...') if self.type != K8sType.K3s: return @@ -168,6 +210,11 @@ def create_cluster(self) -> None: self.preload_images(self._images) def delete_cluster(self) -> None: + """ + Overview: + Delete the k8s cluster. + """ + print('Deleting k8s cluster...') if self.type != K8sType.K3s: return @@ -180,6 +227,11 @@ def delete_cluster(self) -> None: raise RuntimeError(f'Failed to delete cluster {self.name}: {err_str}') def preload_images(self, images: list) -> None: + """ + Overview: + Preload images. + """ + if self.type != K8sType.K3s or len(images) == 0: return args = ['k3d', 'image', 'import', f'--cluster={self.name}'] diff --git a/ding/utils/linklink_dist_helper.py b/ding/utils/linklink_dist_helper.py index a886c3cbff..36fffa19a0 100644 --- a/ding/utils/linklink_dist_helper.py +++ b/ding/utils/linklink_dist_helper.py @@ -20,7 +20,7 @@ def is_fake_link(): def get_rank() -> int: - r""" + """ Overview: Get the rank of ``linklink`` model, return 0 if use ``FakeLink``. @@ -33,7 +33,7 @@ def get_rank() -> int: def get_world_size() -> int: - r""" + """ Overview: Get the ``world_size`` of ``linklink model``, return 0 if use ``FakeLink``. @@ -46,7 +46,7 @@ def get_world_size() -> int: def broadcast(value: torch.Tensor, rank: int) -> None: - r""" + """ Overview: Use ``linklink.broadcast`` and raise error when using ``FakeLink`` Arguments: @@ -59,7 +59,7 @@ def broadcast(value: torch.Tensor, rank: int) -> None: def allreduce(data: torch.Tensor, op: str = 'sum') -> None: - r""" + """ Overview: Call ``linklink.allreduce`` on the data Arguments: @@ -79,7 +79,7 @@ def allreduce(data: torch.Tensor, op: str = 'sum') -> None: def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None: - r""" + """ Overview: Call ``linklink.allreduce_async`` on the data Arguments: @@ -99,7 +99,7 @@ def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None: def get_group(group_size: int) -> List: - r""" + """ Overview: Get the group segmentation of ``group_size`` each group Arguments: @@ -114,9 +114,11 @@ def get_group(group_size: int) -> List: def dist_mode(func: Callable) -> Callable: - r""" + """ Overview: Wrap the function so that in can init and finalize automatically before each call + Arguments: + - func (:obj:`Callable`): the function to wrap """ def wrapper(*args, **kwargs): @@ -128,7 +130,7 @@ def wrapper(*args, **kwargs): def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]: - r""" + """ Overview: Init the distribution Arguments: @@ -152,7 +154,7 @@ def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]: def dist_finalize() -> None: - r""" + """ Overview: Finalize ``linklink``, see ``linklink.finalize()`` """ @@ -160,25 +162,53 @@ def dist_finalize() -> None: class DistContext: + """ + Overview: + A context manager for ``linklink`` distribution + Interfaces: + ``__init__``, ``__enter__``, ``__exit__`` + """ def __init__(self) -> None: + """ + Overview: + Initialize the ``DistContext`` + """ + pass def __enter__(self) -> None: + """ + Overview: + Initialize ``linklink`` distribution + """ + dist_init() def __exit__(self, *args, **kwargs) -> Any: + """ + Overview: + Finalize ``linklink`` distribution + Arugments: + - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function. + - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function. + """ + dist_finalize() def simple_group_split(world_size: int, rank: int, num_groups: int) -> List: - r""" + """ Overview: Split the group according to ``worldsize``, ``rank`` and ``num_groups`` - + Arguments: + - world_size (:obj:`int`): The world size + - rank (:obj:`int`): The rank + - num_groups (:obj:`int`): The number of groups .. note:: With faulty input, raise ``array split does not result in an equal division`` """ + groups = [] rank_list = np.split(np.arange(world_size), num_groups) rank_list = [list(map(int, x)) for x in rank_list] @@ -189,4 +219,9 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List: def synchronize(): + """ + Overview: + Synchronize the process + """ + get_link().synchronize() diff --git a/ding/utils/loader/base.py b/ding/utils/loader/base.py index cd6f9ec390..fd55bc8b62 100644 --- a/ding/utils/loader/base.py +++ b/ding/utils/loader/base.py @@ -6,6 +6,13 @@ def _to_exception(exception) -> Callable[[Any], Exception]: + """ + Overview: + Convert exception to callable exception. + Arguments: + - exception (:obj:`Exception`): The exception to be converted. + """ + if hasattr(exception, '__call__'): return exception elif isinstance(exception, Exception): @@ -21,6 +28,13 @@ def _to_exception(exception) -> Callable[[Any], Exception]: def _to_loader(value) -> 'ILoaderClass': + """ + Overview: + Convert value to loader. + Arguments: + - value (:obj:`Any`): The value to be converted. + """ + if isinstance(value, ILoaderClass): return value elif isinstance(value, tuple): @@ -78,6 +92,11 @@ def _load(self, value_): def _reset_exception(loader, eg: Callable[[Any, Exception], Exception]): + """ + Overview: + Reset exception of loader. + """ + loader = Loader(loader) def _load(value): @@ -90,15 +109,42 @@ def _load(value): class ILoaderClass: + """ + Overview: + Base class of loader. + Interfaces: + ``__init__``, ``_load``, ``load``, ``check``, ``__call__``, ``__and__``, ``__or__``, ``__rshift__`` + """ @abstractmethod def _load(self, value: _ValueType) -> _ValueType: + """ + Overview: + Load the value. + Arguments: + - value (:obj:`_ValueType`): The value to be loaded. + """ + raise NotImplementedError def __load(self, value: _ValueType) -> _ValueType: + """ + Overview: + Load the value. + Arguments: + - value (:obj:`_ValueType`): The value to be loaded. + """ + return self._load(value) def __check(self, value: _ValueType) -> bool: + """ + Overview: + Check whether the value is valid. + Arguments: + - value (:obj:`_ValueType`): The value to be checked. + """ + try: self._load(value) except CAPTURE_EXCEPTIONS: @@ -107,15 +153,42 @@ def __check(self, value: _ValueType) -> bool: return True def load(self, value: _ValueType) -> _ValueType: + """ + Overview: + Load the value. + Arguments: + - value (:obj:`_ValueType`): The value to be loaded. + """ + return self.__load(value) def check(self, value: _ValueType) -> bool: + """ + Overview: + Check whether the value is valid. + Arguments: + - value (:obj:`_ValueType`): The value to be checked. + """ + return self.__check(value) def __call__(self, value: _ValueType) -> _ValueType: + """ + Overview: + Load the value. + Arguments: + - value (:obj:`_ValueType`): The value to be loaded. + """ + return self.__load(value) def __and__(self, other) -> 'ILoaderClass': + """ + Overview: + Combine two loaders. + Arguments: + - other (:obj:`ILoaderClass`): The other loader. + """ def _load(value: _ValueType) -> _ValueType: self.load(value) @@ -124,9 +197,22 @@ def _load(value: _ValueType) -> _ValueType: return Loader(_load) def __rand__(self, other) -> 'ILoaderClass': + """ + Overview: + Combine two loaders. + Arguments: + - other (:obj:`ILoaderClass`): The other loader. + """ + return Loader(other) & self def __or__(self, other) -> 'ILoaderClass': + """ + Overview: + Combine two loaders. + Arguments: + - other (:obj:`ILoaderClass`): The other loader. + """ def _load(value: _ValueType) -> _ValueType: try: @@ -137,9 +223,22 @@ def _load(value: _ValueType) -> _ValueType: return Loader(_load) def __ror__(self, other) -> 'ILoaderClass': + """ + Overview: + Combine two loaders. + Arguments: + - other (:obj:`ILoaderClass`): The other loader. + """ + return Loader(other) | self def __rshift__(self, other) -> 'ILoaderClass': + """ + Overview: + Combine two loaders. + Arguments: + - other (:obj:`ILoaderClass`): The other loader. + """ def _load(value: _ValueType) -> _ValueType: _return_value = self.load(value) @@ -148,4 +247,11 @@ def _load(value: _ValueType) -> _ValueType: return Loader(_load) def __rrshift__(self, other) -> 'ILoaderClass': + """ + Overview: + Combine two loaders. + Arguments: + - other (:obj:`ILoaderClass`): The other loader. + """ + return Loader(other) >> self diff --git a/ding/utils/loader/collection.py b/ding/utils/loader/collection.py index cbac490df4..770e6c6c64 100644 --- a/ding/utils/loader/collection.py +++ b/ding/utils/loader/collection.py @@ -9,8 +9,23 @@ class CollectionError(CompositeStructureError): + """ + Overview: + Collection error. + Interfaces: + ``__init__``, ``errors`` + Properties: + ``errors`` + """ def __init__(self, errors: COLLECTION_ERRORS): + """ + Overview: + Initialize the CollectionError. + Arguments: + - errors (:obj:`COLLECTION_ERRORS`): The errors. + """ + self.__errors = list(errors or []) CompositeStructureError.__init__( self, '{count} error(s) found in collection.'.format(count=repr(list(self.__errors))) @@ -18,10 +33,23 @@ def __init__(self, errors: COLLECTION_ERRORS): @property def errors(self) -> COLLECTION_ERRORS: + """ + Overview: + Get the errors. + """ + return self.__errors def collection(loader, type_back: bool = True) -> ILoaderClass: + """ + Overview: + Create a collection loader. + Arguments: + - loader (:obj:`ILoaderClass`): The loader. + - type_back (:obj:`bool`): Whether to convert the type back. + """ + loader = Loader(loader) def _load(value): @@ -47,6 +75,13 @@ def _load(value): def tuple_(*loaders) -> ILoaderClass: + """ + Overview: + Create a tuple loader. + Arguments: + - loaders (:obj:`tuple`): The loaders. + """ + loaders = [Loader(loader) for loader in loaders] def _load(value: tuple): @@ -56,6 +91,13 @@ def _load(value: tuple): def length(min_length: Optional[int] = None, max_length: Optional[int] = None) -> ILoaderClass: + """ + Overview: + Create a length loader. + Arguments: + - min_length (:obj:`int`): The minimum length. + - max_length (:obj:`int`): The maximum length. + """ def _load(value): _length = len(value) @@ -74,10 +116,23 @@ def _load(value): def length_is(length_: int) -> ILoaderClass: + """ + Overview: + Create a length loader. + Arguments: + - length_ (:obj:`int`): The length. + """ + return length(min_length=length_, max_length=length_) def contains(content) -> ILoaderClass: + """ + Overview: + Create a contains loader. + Arguments: + - content (:obj:`Any`): The content. + """ def _load(value): if content not in value: @@ -89,6 +144,13 @@ def _load(value): def cofilter(checker: Callable[[Any], bool], type_back: bool = True) -> ILoaderClass: + """ + Overview: + Create a cofilter loader. + Arguments: + - checker (:obj:`Callable[[Any], bool]`): The checker. + - type_back (:obj:`bool`): Whether to convert the type back. + """ def _load(value): _result = [item for item in value if checker(item)] @@ -100,6 +162,12 @@ def _load(value): def tpselector(*indices) -> ILoaderClass: + """ + Overview: + Create a tuple selector loader. + Arguments: + - indices (:obj:`tuple`): The indices. + """ def _load(value: tuple): return tuple([value[index] for index in indices]) diff --git a/ding/utils/loader/dict.py b/ding/utils/loader/dict.py index 9bda68a46a..a14d3ff9f8 100644 --- a/ding/utils/loader/dict.py +++ b/ding/utils/loader/dict.py @@ -7,16 +7,43 @@ class DictError(CompositeStructureError): + """ + Overview: + Dict error. + Interfaces: + ``__init__``, ``errors`` + Properties: + ``errors`` + """ def __init__(self, errors: DICT_ERRORS): + """ + Overview: + Initialize the DictError. + Arguments: + - errors (:obj:`DICT_ERRORS`): The errors. + """ + self.__error = errors @property def errors(self) -> DICT_ERRORS: + """ + Overview: + Get the errors. + """ + return self.__error def dict_(**kwargs) -> ILoaderClass: + """ + Overview: + Create a dict loader. + Arguments: + - kwargs (:obj:`Mapping[str, ILoaderClass]`): The loaders. + """ + kwargs = [(k, Loader(v)) for k, v in kwargs.items()] def _load(value): diff --git a/ding/utils/loader/exception.py b/ding/utils/loader/exception.py index 96f2b53ad5..9358f1c85e 100644 --- a/ding/utils/loader/exception.py +++ b/ding/utils/loader/exception.py @@ -7,8 +7,21 @@ class CompositeStructureError(ValueError, metaclass=ABCMeta): + """ + Overview: + Composite structure error. + Interfaces: + ``__init__``, ``errors`` + Properties: + ``errors`` + """ @property @abstractmethod def errors(self) -> ERROR_ITEMS: + """ + Overview: + Get the errors. + """ + raise NotImplementedError diff --git a/ding/utils/loader/mapping.py b/ding/utils/loader/mapping.py index 0bb9e4e85a..c3993c2366 100644 --- a/ding/utils/loader/mapping.py +++ b/ding/utils/loader/mapping.py @@ -10,23 +10,61 @@ class MappingError(CompositeStructureError): + """ + Overview: + Mapping error. + Interfaces: + ``__init__``, ``errors`` + """ def __init__(self, key_errors: MAPPING_ERRORS, value_errors: MAPPING_ERRORS): + """ + Overview: + Initialize the MappingError. + Arguments: + - key_errors (:obj:`MAPPING_ERRORS`): The key errors. + - value_errors (:obj:`MAPPING_ERRORS`): The value errors. + """ + self.__key_errors = list(key_errors or []) self.__value_errors = list(value_errors or []) self.__errors = self.__key_errors + self.__value_errors def key_errors(self) -> MAPPING_ERRORS: + """ + Overview: + Get the key errors. + """ + return self.__key_errors def value_errors(self) -> MAPPING_ERRORS: + """ + Overview: + Get the value errors. + """ + return self.__value_errors def errors(self) -> MAPPING_ERRORS: + """ + Overview: + Get the errors. + """ + return self.__errors def mapping(key_loader, value_loader, type_back: bool = True) -> ILoaderClass: + """ + Overview: + Create a mapping loader. + Arguments: + - key_loader (:obj:`ILoaderClass`): The key loader. + - value_loader (:obj:`ILoaderClass`): The value loader. + - type_back (:obj:`bool`): Whether to convert the type back. + """ + key_loader = Loader(key_loader) value_loader = Loader(value_loader) @@ -67,6 +105,13 @@ def _load(value): def mpfilter(check: Callable[[Any, Any], bool], type_back: bool = True) -> ILoaderClass: + """ + Overview: + Create a mapping filter loader. + Arguments: + - check (:obj:`Callable[[Any, Any], bool]`): The check function. + - type_back (:obj:`bool`): Whether to convert the type back. + """ def _load(value): _result = {key_: value_ for key_, value_ in value.items() if check(key_, value_)} @@ -79,14 +124,29 @@ def _load(value): def mpkeys() -> ILoaderClass: + """ + Overview: + Create a mapping keys loader. + """ + return method('items') & method('keys') & Loader(lambda v: set(v.keys())) def mpvalues() -> ILoaderClass: + """ + Overview: + Create a mapping values loader. + """ + return method('items') & method('values') & Loader(lambda v: set(v.values())) def mpitems() -> ILoaderClass: + """ + Overview: + Create a mapping items loader. + """ + return method('items') & Loader(lambda v: set([(key, value) for key, value in v.items()])) @@ -94,10 +154,25 @@ def mpitems() -> ILoaderClass: def item(key) -> ILoaderClass: + """ + Overview: + Create a item loader. + Arguments: + - key (:obj:`Any`): The key. + """ + return _INDEX_PRECHECK & Loader( (lambda v: key in v.keys(), lambda v: v[key], KeyError('key {key} not found'.format(key=repr(key)))) ) def item_or(key, default) -> ILoaderClass: + """ + Overview: + Create a item or loader. + Arguments: + - key (:obj:`Any`): The key. + - default (:obj:`Any`): The default value. + """ + return _INDEX_PRECHECK & (item(key) | raw(default)) diff --git a/ding/utils/loader/norm.py b/ding/utils/loader/norm.py index ea99692a93..af142ed4e6 100644 --- a/ding/utils/loader/norm.py +++ b/ding/utils/loader/norm.py @@ -7,6 +7,12 @@ def _callable_to_norm(func: Callable[[Any], Any]) -> 'INormClass': + """ + Overview: + Convert callable to norm. + Arguments: + - func (:obj:`Callable[[Any], Any]`): The callable to be converted. + """ class _Norm(INormClass): @@ -17,6 +23,13 @@ def _call(self, value): def norm(value) -> 'INormClass': + """ + Overview: + Convert value to norm. + Arguments: + - value (:obj:`Any`): The value to be converted. + """ + if isinstance(value, INormClass): return value elif isinstance(value, ILoaderClass): @@ -26,6 +39,12 @@ def norm(value) -> 'INormClass': def normfunc(func): + """ + Overview: + Convert function to norm function. + Arguments: + - func (:obj:`Callable[[Any], Any]`): The function to be converted. + """ @wraps(func) def _new_func(*args_norm, **kwargs_norm): @@ -47,14 +66,37 @@ def _callable(v): def _unary(a: 'INormClass', func: UNARY_FUNC) -> 'INormClass': + """ + Overview: + Create a unary norm. + Arguments: + - a (:obj:`INormClass`): The norm. + - func (:obj:`UNARY_FUNC`): The function. + """ + return _callable_to_norm(lambda v: func(a(v))) def _binary(a: 'INormClass', b: 'INormClass', func: BINARY_FUNC) -> 'INormClass': + """ + Overview: + Create a binary norm. + Arguments: + - a (:obj:`INormClass`): The first norm. + - b (:obj:`INormClass`): The second norm. + - func (:obj:`BINARY_FUNC`): The function. + """ return _callable_to_norm(lambda v: func(a(v), b(v))) def _binary_reducing(func: BINARY_FUNC, zero): + """ + Overview: + Create a binary reducing norm. + Arguments: + - func (:obj:`BINARY_FUNC`): The function. + - zero (:obj:`Any`): The zero value. + """ @wraps(func) def _new_func(*args) -> 'INormClass': @@ -67,118 +109,382 @@ def _new_func(*args) -> 'INormClass': class INormClass: + """ + Overview: + The norm class. + Interfaces: + ``__call__``, ``__add__``, ``__radd__``, ``__sub__``, ``__rsub__``, ``__mul__``, ``__rmul__``, ``__matmul__``, + ``__rmatmul__``, ``__truediv__``, ``__rtruediv__``, ``__floordiv__``, ``__rfloordiv__``, ``__mod__``, + ``__rmod__``, ``__pow__``, ``__rpow__``, ``__lshift__``, ``__rlshift__``, ``__rshift__``, ``__rrshift__``, + ``__and__``, ``__rand__``, ``__or__``, ``__ror__``, ``__xor__``, ``__rxor__``, ``__invert__``, ``__pos__``, + ``__neg__``, ``__eq__``, ``__ne__``, ``__lt__``, ``__le__``, ``__gt__``, ``__ge__`` + """ @abstractmethod def _call(self, value): + """ + Overview: + Call the norm. + Arguments: + - value (:obj:`Any`): The value to be normalized. + """ + raise NotImplementedError def __call__(self, value): + """ + Overview: + Call the norm. + Arguments: + - value (:obj:`Any`): The value to be normalized. + """ + return self._call(value) def __add__(self, other): + """ + Overview: + Add the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__add__) def __radd__(self, other): + """ + Overview: + Add the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) + self def __sub__(self, other): + """ + Overview: + Subtract the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__sub__) def __rsub__(self, other): + """ + Overview: + Subtract the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) - self def __mul__(self, other): + """ + Overview: + Multiply the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__mul__) def __rmul__(self, other): + """ + Overview: + Multiply the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) * self def __matmul__(self, other): + """ + Overview: + Matrix multiply the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__matmul__) def __rmatmul__(self, other): + """ + Overview: + Matrix multiply the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) @ self def __truediv__(self, other): + """ + Overview: + Divide the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__truediv__) def __rtruediv__(self, other): + """ + Overview: + Divide the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) / self def __floordiv__(self, other): + """ + Overview: + Floor divide the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__floordiv__) def __rfloordiv__(self, other): + """ + Overview: + Floor divide the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) // self def __mod__(self, other): + """ + Overview: + Mod the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__mod__) def __rmod__(self, other): + """ + Overview: + Mod the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) % self def __pow__(self, power, modulo=None): + """ + Overview: + Power the norm. + Arguments: + - power (:obj:`Any`): The power. + - modulo (:obj:`Any`): The modulo. + """ + return _binary(self, norm(power), operator.__pow__) def __rpow__(self, other): + """ + Overview: + Power the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) ** self def __lshift__(self, other): + """ + Overview: + Lshift the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__lshift__) def __rlshift__(self, other): + """ + Overview: + Lshift the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) << self def __rshift__(self, other): + """ + Overview: + Rshift the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__rshift__) def __rrshift__(self, other): + """ + Overview: + Rshift the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) >> self def __and__(self, other): + """ + Overview: + And operation the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__and__) def __rand__(self, other): + """ + Overview: + And operation the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) & self def __or__(self, other): + """ + Overview: + Or operation the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__or__) def __ror__(self, other): + """ + Overview: + Or operation the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) | self def __xor__(self, other): + """ + Overview: + Xor operation the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__xor__) def __rxor__(self, other): + """ + Overview: + Xor operation the norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return norm(other) ^ self def __invert__(self): + """ + Overview: + Invert the norm. + """ + return _unary(self, operator.__invert__) def __pos__(self): + """ + Overview: + Positive the norm. + """ + return _unary(self, operator.__pos__) def __neg__(self): + """ + Overview: + Negative the norm. + """ + return _unary(self, operator.__neg__) # Attention: DO NOT USE LINKING COMPARE OPERATORS, IT WILL CAUSE ERROR. def __eq__(self, other): + """ + Overview: + Compare the norm if they are equal. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__eq__) def __ne__(self, other): + """ + Overview: + Compare the norm if they are not equal. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__ne__) def __lt__(self, other): + """ + Overview: + Compare the norm if it is less than the other norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__lt__) def __le__(self, other): + """ + Overview: + Compare the norm if it is less than or equal to the other norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__le__) def __gt__(self, other): + """ + Overview: + Compare the norm if it is greater than the other norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__gt__) def __ge__(self, other): + """ + Overview: + Compare the norm if it is greater than or equal to the other norm. + Arguments: + - other (:obj:`Any`): The other norm. + """ + return _binary(self, norm(other), operator.__ge__) @@ -204,6 +510,14 @@ def __ge__(self, other): @normfunc def lcmp(first, *items): + """ + Overview: + Compare the items. + Arguments: + - first (:obj:`Any`): The first item. + - items (:obj:`Any`): The other items. + """ + if len(items) % 2 == 1: raise ValueError('Count of items should be odd number but {number} found.'.format(number=len(items) + 1)) diff --git a/ding/utils/loader/number.py b/ding/utils/loader/number.py index 61b4e97225..a9fdc59a16 100644 --- a/ding/utils/loader/number.py +++ b/ding/utils/loader/number.py @@ -10,6 +10,15 @@ def numeric(int_ok: bool = True, float_ok: bool = True, inf_ok: bool = True) -> ILoaderClass: + """ + Overview: + Create a numeric loader. + Arguments: + - int_ok (:obj:`bool`): Whether int is allowed. + - float_ok (:obj:`bool`): Whether float is allowed. + - inf_ok (:obj:`bool`): Whether inf is allowed. + """ + if not int_ok and not float_ok: raise ValueError('Either int or float should be allowed.') @@ -42,6 +51,17 @@ def interval( right_ok: bool = True, eps=0.0 ) -> ILoaderClass: + """ + Overview: + Create a interval loader. + Arguments: + - left (:obj:`Optional[NUMBER_TYPING]`): The left bound. + - right (:obj:`Optional[NUMBER_TYPING]`): The right bound. + - left_ok (:obj:`bool`): Whether left bound is allowed. + - right_ok (:obj:`bool`): Whether right bound is allowed. + - eps (:obj:`float`): The epsilon. + """ + if left is None: left = -math.inf if right is None: @@ -93,70 +113,170 @@ def _load(value) -> NUMBER_TYPING: def is_negative() -> ILoaderClass: + """ + Overview: + Create a negative loader. + """ + return Loader((lambda x: x < 0, lambda x: ValueError('negative required but {value} found'.format(value=repr(x))))) def is_positive() -> ILoaderClass: + """ + Overview: + Create a positive loader. + """ + return Loader((lambda x: x > 0, lambda x: ValueError('positive required but {value} found'.format(value=repr(x))))) def non_negative() -> ILoaderClass: + """ + Overview: + Create a non-negative loader. + """ + return Loader( (lambda x: x >= 0, lambda x: ValueError('non-negative required but {value} found'.format(value=repr(x)))) ) def non_positive() -> ILoaderClass: + """ + Overview: + Create a non-positive loader. + """ + return Loader( (lambda x: x <= 0, lambda x: ValueError('non-positive required but {value} found'.format(value=repr(x)))) ) def negative() -> ILoaderClass: + """ + Overview: + Create a negative loader. + """ + return Loader(lambda x: -x) def positive() -> ILoaderClass: + """ + Overview: + Create a positive loader. + """ + return Loader(lambda x: +x) def _math_binary(func: Callable[[Any, Any], Any], attachment) -> ILoaderClass: + """ + Overview: + Create a math binary loader. + Arguments: + - func (:obj:`Callable[[Any, Any], Any]`): The function. + - attachment (:obj:`Any`): The attachment. + """ + return Loader(lambda x: func(x, Loader(attachment)(x))) def plus(addend) -> ILoaderClass: + """ + Overview: + Create a plus loader. + Arguments: + - addend (:obj:`Any`): The addend. + """ + return _math_binary(lambda x, y: x + y, addend) def minus(subtrahend) -> ILoaderClass: + """ + Overview: + Create a minus loader. + Arguments: + - subtrahend (:obj:`Any`): The subtrahend. + """ + return _math_binary(lambda x, y: x - y, subtrahend) def minus_with(minuend) -> ILoaderClass: + """ + Overview: + Create a minus loader. + Arguments: + - minuend (:obj:`Any`): The minuend. + """ + return _math_binary(lambda x, y: y - x, minuend) def multi(multiplier) -> ILoaderClass: + """ + Overview: + Create a multi loader. + Arguments: + - multiplier (:obj:`Any`): The multiplier. + """ + return _math_binary(lambda x, y: x * y, multiplier) def divide(divisor) -> ILoaderClass: + """ + Overview: + Create a divide loader. + Arguments: + - divisor (:obj:`Any`): The divisor. + """ + return _math_binary(lambda x, y: x / y, divisor) def divide_with(dividend) -> ILoaderClass: + """ + Overview: + Create a divide loader. + Arguments: + - dividend (:obj:`Any`): The dividend. + """ + return _math_binary(lambda x, y: y / x, dividend) def power(index) -> ILoaderClass: + """ + Overview: + Create a power loader. + Arguments: + - index (:obj:`Any`): The index. + """ + return _math_binary(lambda x, y: x ** y, index) def power_with(base) -> ILoaderClass: + """ + Overview: + Create a power loader. + Arguments: + - base (:obj:`Any`): The base. + """ + return _math_binary(lambda x, y: y ** x, base) def msum(*items) -> ILoaderClass: + """ + Overview: + Create a sum loader. + Arguments: + - items (:obj:`tuple`): The items. + """ def _load(value): return sum([item(value) for item in items]) @@ -165,6 +285,12 @@ def _load(value): def mmulti(*items) -> ILoaderClass: + """ + Overview: + Create a multi loader. + Arguments: + - items (:obj:`tuple`): The items. + """ def _load(value): _result = 1 @@ -186,6 +312,15 @@ def _load(value): def _msinglecmp(first, op, second) -> ILoaderClass: + """ + Overview: + Create a single compare loader. + Arguments: + - first (:obj:`Any`): The first item. + - op (:obj:`str`): The operator. + - second (:obj:`Any`): The second item. + """ + first = Loader(first) second = Loader(second) @@ -206,6 +341,14 @@ def _msinglecmp(first, op, second) -> ILoaderClass: def mcmp(first, *items) -> ILoaderClass: + """ + Overview: + Create a multi compare loader. + Arguments: + - first (:obj:`Any`): The first item. + - items (:obj:`tuple`): The items. + """ + if len(items) % 2 == 1: raise ValueError('Count of items should be odd number but {number} found.'.format(number=len(items) + 1)) diff --git a/ding/utils/loader/string.py b/ding/utils/loader/string.py index 4e7c9a8da3..16d4827cc4 100644 --- a/ding/utils/loader/string.py +++ b/ding/utils/loader/string.py @@ -9,6 +9,13 @@ def enum(*items, case_sensitive: bool = True) -> ILoaderClass: + """ + Overview: + Create an enum loader. + Arguments: + - items (:obj:`Iterable[str]`): The items. + - case_sensitive (:obj:`bool`): Whether case sensitive. + """ def _case_sensitive(func: STRING_PROCESSOR) -> STRING_PROCESSOR: if case_sensitive: @@ -38,6 +45,13 @@ def _load(value: str): def _to_regexp(regexp) -> Pattern: + """ + Overview: + Convert regexp to re.Pattern. + Arguments: + - regexp (:obj:`Union[str, re.Pattern]`): The regexp. + """ + if isinstance(regexp, Pattern): return regexp elif isinstance(regexp, str): @@ -49,6 +63,13 @@ def _to_regexp(regexp) -> Pattern: def rematch(regexp: Union[str, Pattern]) -> ILoaderClass: + """ + Overview: + Create a rematch loader. + Arguments: + - regexp (:obj:`Union[str, re.Pattern]`): The regexp. + """ + regexp = _to_regexp(regexp) def _load(value: str): @@ -66,6 +87,14 @@ def _load(value: str): def regrep(regexp: Union[str, Pattern], group: int = 0) -> ILoaderClass: + """ + Overview: + Create a regrep loader. + Arguments: + - regexp (:obj:`Union[str, re.Pattern]`): The regexp. + - group (:obj:`int`): The group. + """ + regexp = _to_regexp(regexp) def _load(value: str): diff --git a/ding/utils/loader/types.py b/ding/utils/loader/types.py index b868b05032..6039395ca6 100644 --- a/ding/utils/loader/types.py +++ b/ding/utils/loader/types.py @@ -5,6 +5,13 @@ def is_type(type_: type) -> ILoaderClass: + """ + Overview: + Create a type loader. + Arguments: + - type_ (:obj:`type`): The type. + """ + if isinstance(type_, type): return Loader(type_) else: @@ -12,10 +19,22 @@ def is_type(type_: type) -> ILoaderClass: def to_type(type_: type) -> ILoaderClass: + """ + Overview: + Create a type loader. + Arguments: + - type_ (:obj:`type`): The type. + """ + return Loader(lambda v: type_(v)) def is_callable() -> ILoaderClass: + """ + Overview: + Create a callable loader. + """ + return _reset_exception( check_only(prop('__call__')), lambda v, e: TypeError('callable expected but {func} not found'.format(func=repr('__call__'))) @@ -23,6 +42,13 @@ def is_callable() -> ILoaderClass: def prop(attr_name: str) -> ILoaderClass: + """ + Overview: + Create a attribute loader. + Arguments: + - attr_name (:obj:`str`): The attribute name. + """ + return Loader( ( lambda v: hasattr(v, attr_name), lambda v: getattr(v, attr_name), @@ -32,6 +58,13 @@ def prop(attr_name: str) -> ILoaderClass: def method(method_name: str) -> ILoaderClass: + """ + Overview: + Create a method loader. + Arguments: + - method_name (:obj:`str`): The method name. + """ + return _reset_exception( prop(method_name) >> is_callable(), lambda v, e: TypeError('type {type} not support function {func}'.format(type=repr(type(v).__name__), func=repr('__iter__'))) @@ -39,8 +72,24 @@ def method(method_name: str) -> ILoaderClass: def fcall(*args, **kwargs) -> ILoaderClass: + """ + Overview: + Create a function loader. + Arguments: + - args (:obj:`Tuple[Any]`): The args. + - kwargs (:obj:`Dict[str, Any]`): The kwargs. + """ + return Loader(lambda v: v(*args, **kwargs)) def fpartial(*args, **kwargs) -> ILoaderClass: + """ + Overview: + Create a partial function loader. + Arguments: + - args (:obj:`Tuple[Any]`): The args. + - kwargs (:obj:`Dict[str, Any]`): The kwargs. + """ + return Loader(lambda v: partial(v, *args, **kwargs)) diff --git a/ding/utils/loader/utils.py b/ding/utils/loader/utils.py index ac1998e26c..140bbf033d 100644 --- a/ding/utils/loader/utils.py +++ b/ding/utils/loader/utils.py @@ -2,20 +2,51 @@ def keep() -> ILoaderClass: + """ + Overview: + Create a keep loader. + """ + return Loader(lambda v: v) def raw(value) -> ILoaderClass: + """ + Overview: + Create a raw loader. + """ + return Loader(lambda v: value) def optional(loader) -> ILoaderClass: + """ + Overview: + Create a optional loader. + Arguments: + - loader (:obj:`ILoaderClass`): The loader. + """ + return Loader(loader) | None def check_only(loader) -> ILoaderClass: + """ + Overview: + Create a check only loader. + Arguments: + - loader (:obj:`ILoaderClass`): The loader. + """ + return Loader(loader) & keep() def check(loader) -> ILoaderClass: + """ + Overview: + Create a check loader. + Arguments: + - loader (:obj:`ILoaderClass`): The loader. + """ + return Loader(lambda x: Loader(loader).check(x)) diff --git a/ding/utils/lock_helper.py b/ding/utils/lock_helper.py index 67586b0767..368cac35f8 100644 --- a/ding/utils/lock_helper.py +++ b/ding/utils/lock_helper.py @@ -46,7 +46,7 @@ def __init__(self, type_: LockContextType = LockContextType.THREAD_LOCK): Init the lock according to the given type. Arguments: - type_ (:obj:`LockContextType`): The type of lock to be used. Defaults to LockContextType.THREAD_LOCK. + - type_ (:obj:`LockContextType`): The type of lock to be used. Defaults to LockContextType.THREAD_LOCK. """ self.lock = _LOCK_TYPE_MAPPING[type_]() @@ -75,6 +75,9 @@ def __exit__(self, *args, **kwargs): """ Overview: Exits the context and releases the lock. + Arguments: + - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function. + - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function. """ self.lock.release() @@ -120,7 +123,7 @@ class FcntlContext: Example: >>> lock_path = "/path/to/lock/file" - >>>with FcntlContext(lock_path) as lock: + >>> with FcntlContext(lock_path) as lock: >>> # Perform operations while the lock is held """ @@ -150,6 +153,9 @@ def __exit__(self, *args, **kwargs) -> None: """ Overview: Closes the file and releases any resources used by the lock_helper object. + Arguments: + - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function. + - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function. """ self.f.close() self.f = None diff --git a/ding/utils/log_helper.py b/ding/utils/log_helper.py index 3c83e5242f..0a966532ac 100644 --- a/ding/utils/log_helper.py +++ b/ding/utils/log_helper.py @@ -19,7 +19,7 @@ def build_logger( need_text: bool = True, text_level: Union[int, str] = logging.INFO ) -> Tuple[Optional[logging.Logger], Optional['SummaryWriter']]: # noqa - r""" + """ Overview: Build text logger and tensorboard logger. Arguments: @@ -41,6 +41,15 @@ def build_logger( class TBLoggerFactory(object): + """ + Overview: + TBLoggerFactory is a factory class for ``SummaryWriter``. + Interfaces: + ``create_logger`` + Properties: + - ``tb_loggers`` (:obj:`Dict[str, SummaryWriter]`): A dict that stores ``SummaryWriter`` instances. + """ + tb_loggers = {} @classmethod @@ -53,10 +62,16 @@ def create_logger(cls: type, logdir: str) -> DistributedWriter: class LoggerFactory(object): + """ + Overview: + LoggerFactory is a factory class for ``logging.Logger``. + Interfaces: + ``create_logger``, ``get_tabulate_vars``, ``get_tabulate_vars_hor`` + """ @classmethod def create_logger(cls, path: str, name: str = 'default', level: Union[int, str] = logging.INFO) -> logging.Logger: - r""" + """ Overview: Create logger using logging Arguments: @@ -80,7 +95,7 @@ def create_logger(cls, path: str, name: str = 'default', level: Union[int, str] @staticmethod def get_tabulate_vars(variables: Dict[str, Any]) -> str: - r""" + """ Overview: Get the text description in tabular form of all vars Arguments: @@ -97,6 +112,13 @@ def get_tabulate_vars(variables: Dict[str, Any]) -> str: @staticmethod def get_tabulate_vars_hor(variables: Dict[str, Any]) -> str: + """ + Overview: + Get the text description in tabular form of all vars + Arguments: + - variables (:obj:`List[str]`): Names of the vars to query. + """ + column_to_divide = 5 # which includes the header "Name & Value" datak = [] @@ -131,7 +153,7 @@ def get_tabulate_vars_hor(variables: Dict[str, Any]) -> str: def pretty_print(result: dict, direct_print: bool = True) -> str: - r""" + """ Overview: Print a dict ``result`` in a pretty way Arguments: diff --git a/ding/utils/log_writer_helper.py b/ding/utils/log_writer_helper.py index 7efbc32416..0f8a1c5115 100644 --- a/ding/utils/log_writer_helper.py +++ b/ding/utils/log_writer_helper.py @@ -16,10 +16,22 @@ class DistributedWriter(SummaryWriter): A simple subclass of SummaryWriter that supports writing to one process in multi-process mode. The best way is to use it in conjunction with the ``router`` to take advantage of the message \ and event components of the router (see ``writer.plugin``). + Interfaces: + ``get_instance``, ``plugin``, ``initialize``, ``__del__`` """ root = None def __init__(self, *args, **kwargs): + """ + Overview: + Initialize the DistributedWriter object. + Arguments: + - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \ + SummaryWriter. + - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \ + SummaryWriter. + """ + self._default_writer_to_disk = kwargs.get("write_to_disk") if "write_to_disk" in kwargs else True # We need to write data to files lazily, so we should not use file writer in __init__, # On the contrary, we will initialize the file writer when the user calls the @@ -37,6 +49,11 @@ def get_instance(cls, *args, **kwargs) -> "DistributedWriter": Overview: Get instance and set the root level instance on the first called. If args and kwargs is none, this method will return root instance. + Arguments: + - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \ + SummaryWriter. + - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \ + SummaryWriter. """ if args or kwargs: ins = cls(*args, **kwargs) @@ -52,6 +69,9 @@ def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWri Plugin ``router``, so when using this writer with active router, it will automatically send requests\ to the main writer instead of writing it to the disk. So we can collect data from multiple processes\ and write them into one file. + Arguments: + - router (:obj:`Parallel`): The router to be plugged in. + - is_writer (:obj:`bool`): Whether this writer is the main writer. Examples: >>> DistributedWriter().plugin(router, is_writer=True) """ @@ -66,20 +86,44 @@ def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWri return self def _on_distributed_writer(self, fn_name: str, *args, **kwargs): + """ + Overview: + This method is called when the router receives a request to write data. + Arguments: + - fn_name (:obj:`str`): The name of the function to be called. + - args (:obj:`Tuple`): The arguments passed to the function to be called. + - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called. + """ + if self._is_writer: getattr(self, fn_name)(*args, **kwargs) def initialize(self): + """ + Overview: + Initialize the file writer. + """ self.close() self._write_to_disk = self._default_writer_to_disk self._get_file_writer() self._lazy_initialized = True def __del__(self): + """ + Overview: + Close the file writer. + """ self.close() def enable_parallel(fn_name, fn): + """ + Overview: + Decorator to enable parallel writing. + Arguments: + - fn_name (:obj:`str`): The name of the function to be called. + - fn (:obj:`Callable`): The function to be called. + """ def _parallel_fn(self: DistributedWriter, *args, **kwargs): if not self._lazy_initialized: diff --git a/ding/utils/normalizer_helper.py b/ding/utils/normalizer_helper.py index 1b502ca5a9..0fc914f30e 100755 --- a/ding/utils/normalizer_helper.py +++ b/ding/utils/normalizer_helper.py @@ -7,7 +7,7 @@ class DatasetNormalizer: The `DatasetNormalizer` class provides functionality to normalize and unnormalize data in a dataset. It takes a dataset as input and applies a normalizer function to each key in the dataset. - Interface: + Interfaces: ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``. """ @@ -109,25 +109,55 @@ class Normalizer: Overview: Parent class, subclass by defining the `normalize` and `unnormalize` methods - Interface: + Interfaces: ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``. """ def __init__(self, X): + """ + Overview: + Initialize the Normalizer object. + Arguments: + - X (:obj:`np.ndarray`): The data to be normalized. + """ + self.X = X.astype(np.float32) self.mins = X.min(axis=0) self.maxs = X.max(axis=0) - def __repr__(self): + def __repr__(self) -> str: + """ + Overview: + Returns a string representation of the Normalizer object. + Returns: + - ret (:obj:`str`): A string representation of the Normalizer object. + """ + return ( f"""[ Normalizer ] dim: {self.mins.size}\n -: """ f"""{np.round(self.mins, 2)}\n +: {np.round(self.maxs, 2)}\n""" ) def normalize(self, *args, **kwargs): + """ + Overview: + Normalize the input data. + Arguments: + - args (:obj:`list`): The arguments passed to the ``normalize`` function. + - kwargs (:obj:`dict`): The keyword arguments passed to the ``normalize`` function. + """ + raise NotImplementedError() def unnormalize(self, *args, **kwargs): + """ + Overview: + Unnormalize the input data. + Arguments: + - args (:obj:`list`): The arguments passed to the ``unnormalize`` function. + - kwargs (:obj:`dict`): The keyword arguments passed to the ``unnormalize`` function. + """ + raise NotImplementedError() @@ -136,17 +166,34 @@ class GaussianNormalizer(Normalizer): Overview: A class that normalizes data to zero mean and unit variance. - Interface: + Interfaces: ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``. """ def __init__(self, *args, **kwargs): + """ + Overview: + Initialize the GaussianNormalizer object. + Arguments: + - args (:obj:`list`): The arguments passed to the ``__init__`` function of the parent class, \ + i.e., the Normalizer class. + - kwargs (:obj:`dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \ + i.e., the Normalizer class. + """ + super().__init__(*args, **kwargs) self.means = self.X.mean(axis=0) self.stds = self.X.std(axis=0) self.z = 1 - def __repr__(self): + def __repr__(self) -> str: + """ + Overview: + Returns a string representation of the GaussianNormalizer object. + Returns: + - ret (:obj:`str`): A string representation of the GaussianNormalizer object. + """ + return ( f"""[ Normalizer ] dim: {self.mins.size}\n """ f"""means: {np.round(self.means, 2)}\n """ @@ -185,16 +232,30 @@ class CDFNormalizer(Normalizer): Overview: A class that makes training data uniform (over each dimension) by transforming it with marginal CDFs. - Interface: + Interfaces: ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``. """ def __init__(self, X): + """ + Overview: + Initialize the CDFNormalizer object. + Arguments: + - X (:obj:`np.ndarray`): The data to be normalized. + """ + super().__init__(atleast_2d(X)) self.dim = self.X.shape[1] self.cdfs = [CDFNormalizer1d(self.X[:, i]) for i in range(self.dim)] - def __repr__(self): + def __repr__(self) -> str: + """ + Overview: + Returns a string representation of the CDFNormalizer object. + Returns: + - ret (:obj:`str`): A string representation of the CDFNormalizer object. + """ + return f'[ CDFNormalizer ] dim: {self.mins.size}\n' + ' | '.join( f'{i:3d}: {cdf}' for i, cdf in enumerate(self.cdfs) ) @@ -252,11 +313,18 @@ class CDFNormalizer1d: Overview: CDF normalizer for a single dimension. This class provides methods to normalize and unnormalize data \ using the Cumulative Distribution Function (CDF) approach. - Interface: + Interfaces: ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``. """ def __init__(self, X: np.ndarray): + """ + Overview: + Initialize the CDFNormalizer1d object. + Arguments: + - X (:obj:`np.ndarray`): The data to be normalized. + """ + import scipy.interpolate as interpolate assert X.ndim == 1 self.X = X.astype(np.float32) @@ -272,6 +340,11 @@ def __init__(self, X: np.ndarray): self.ymin, self.ymax = cumprob.min(), cumprob.max() def __repr__(self) -> str: + """ + Overview: + Returns a string representation of the CDFNormalizer1d object. + """ + return (f'[{np.round(self.xmin, 2):.4f}, {np.round(self.xmax, 2):.4f}') def normalize(self, x: np.ndarray) -> np.ndarray: @@ -375,7 +448,7 @@ class LimitsNormalizer(Normalizer): A class that normalizes and unnormalizes values within specified limits. \ This class maps values within the range [xmin, xmax] to the range [-1, 1]. - Interface: + Interfaces: ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``. """ diff --git a/ding/utils/orchestrator_launcher.py b/ding/utils/orchestrator_launcher.py index 4b18195e72..69324ecc08 100644 --- a/ding/utils/orchestrator_launcher.py +++ b/ding/utils/orchestrator_launcher.py @@ -6,7 +6,10 @@ class OrchestratorLauncher(object): """ - Overview: object to manage di-orchestrator in existing k8s cluster + Overview: + Object to manage di-orchestrator in existing k8s cluster + Interfaces: + ``__init__``, ``create_orchestrator``, ``delete_orchestrator`` """ def __init__( @@ -18,6 +21,18 @@ def __init__( cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack' ) -> None: + """ + Overview: + Initialize the OrchestratorLauncher object. + Arguments: + - version (:obj:`str`): The version of di-orchestrator. + - name (:obj:`str`): The name of di-orchestrator. + - cluster (:obj:`K8sLauncher`): The k8s cluster to deploy di-orchestrator. + - registry (:obj:`str`): The docker registry to pull images. + - cert_manager_version (:obj:`str`): The version of cert-manager. + - cert_manager_registry (:obj:`str`): The docker registry to pull cert-manager images. + """ + self.name = name self.version = version self.cluster = cluster @@ -47,6 +62,11 @@ def __init__( self._check_kubectl_tools() def _check_kubectl_tools(self) -> None: + """ + Overview: + Check if kubectl tools is installed. + """ + args = ['which', 'kubectl'] proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, _ = proc.communicate() @@ -56,6 +76,11 @@ def _check_kubectl_tools(self) -> None: ) def create_orchestrator(self) -> None: + """ + Overview: + Create di-orchestrator in k8s cluster. + """ + print('Creating orchestrator...') if self.cluster is not None: self.cluster.preload_images(self._images) @@ -69,6 +94,11 @@ def create_orchestrator(self) -> None: wait_to_be_ready(self._namespace, self._webhook) def delete_orchestrator(self) -> None: + """ + Overview: + Delete di-orchestrator in k8s cluster. + """ + print('Deleting orchestrator...') for item in [self.cert_manager, self.installer]: args = ['kubectl', 'delete', '-f', f'{item}'] @@ -81,6 +111,13 @@ def delete_orchestrator(self) -> None: def create_components_from_config(config: str) -> None: + """ + Overview: + Create components from config file. + Arguments: + - config (:obj:`str`): The config file. + """ + args = ['kubectl', 'create', '-f', f'{config}'] proc = subprocess.Popen(args, stderr=subprocess.PIPE) _, err = proc.communicate() @@ -93,6 +130,15 @@ def create_components_from_config(config: str) -> None: def wait_to_be_ready(namespace: str, component: str, timeout: int = 120) -> None: + """ + Overview: + Wait for the component to be ready. + Arguments: + - namespace (:obj:`str`): The namespace of the component. + - component (:obj:`str`): The name of the component. + - timeout (:obj:`int`): The timeout of waiting. + """ + try: from kubernetes import config, client, watch except ModuleNotFoundError: diff --git a/ding/utils/profiler_helper.py b/ding/utils/profiler_helper.py index 1a61e9ea0f..96c2a1a076 100644 --- a/ding/utils/profiler_helper.py +++ b/ding/utils/profiler_helper.py @@ -14,11 +14,16 @@ class Profiler: Overview: A class for profiling code execution. It can be used as a context manager or a decorator. - Interface: + Interfaces: ``__init__``, ``mkdir``, ``write_profile``, ``profile``. """ def __init__(self): + """ + Overview: + Initialize the Profiler object. + """ + self.pr = cProfile.Profile() def mkdir(self, directory: str): diff --git a/ding/utils/pytorch_ddp_dist_helper.py b/ding/utils/pytorch_ddp_dist_helper.py index 43398450aa..13d9e1e299 100644 --- a/ding/utils/pytorch_ddp_dist_helper.py +++ b/ding/utils/pytorch_ddp_dist_helper.py @@ -12,7 +12,7 @@ def get_rank() -> int: - r""" + """ Overview: Get the rank of current process in total world_size """ @@ -21,7 +21,7 @@ def get_rank() -> int: def get_world_size() -> int: - r""" + """ Overview: Get the world_size(total process number in data parallel training) """ @@ -35,16 +35,39 @@ def get_world_size() -> int: def allreduce(x: torch.Tensor) -> None: + """ + Overview: + All reduce the tensor ``x`` in the world + Arguments: + - x (:obj:`torch.Tensor`): the tensor to be reduced + """ + dist.all_reduce(x) x.div_(get_world_size()) def allreduce_async(name: str, x: torch.Tensor) -> None: + """ + Overview: + All reduce the tensor ``x`` in the world asynchronously + Arguments: + - name (:obj:`str`): the name of the tensor + - x (:obj:`torch.Tensor`): the tensor to be reduced + """ + x.div_(get_world_size()) dist.all_reduce(x, async_op=True) def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]: + """ + Overview: + Reduce the tensor ``x`` to the destination process ``dst`` + Arguments: + - x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced + - dst (:obj:`int`): the destination process + """ + if np.isscalar(x): x_tensor = torch.as_tensor([x]).cuda() dist.reduce(x_tensor, dst) @@ -57,6 +80,14 @@ def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, floa def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]: + """ + Overview: + All reduce the tensor ``x`` in the world + Arguments: + - x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced + - op (:obj:`str`): the operation to perform on data, support ``['sum', 'avg']`` + """ + assert op in ['sum', 'avg'], op if np.isscalar(x): x_tensor = torch.as_tensor([x]).cuda() @@ -77,7 +108,7 @@ def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, fl def get_group(group_size: int) -> List: - r""" + """ Overview: Get the group segmentation of ``group_size`` each group Arguments: @@ -92,9 +123,11 @@ def get_group(group_size: int) -> List: def dist_mode(func: Callable) -> Callable: - r""" + """ Overview: Wrap the function so that in can init and finalize automatically before each call + Arguments: + - func (:obj:`Callable`): the function to be wrapped """ def wrapper(*args, **kwargs): @@ -110,10 +143,17 @@ def dist_init(backend: str = 'nccl', port: str = None, rank: int = None, world_size: int = None) -> Tuple[int, int]: - r""" + """ Overview: - Init the distributed training setting + Initialize the distributed training setting + Arguments: + - backend (:obj:`str`): The backend of the distributed training, support ``['nccl', 'gloo']`` + - addr (:obj:`str`): The address of the master node + - port (:obj:`str`): The port of the master node + - rank (:obj:`int`): The rank of current process + - world_size (:obj:`int`): The total number of processes """ + assert backend in ['nccl', 'gloo'], backend os.environ['MASTER_ADDR'] = addr or os.environ.get('MASTER_ADDR', "localhost") os.environ['MASTER_PORT'] = port or os.environ.get('MASTER_PORT', "10314") # hard-code @@ -141,7 +181,7 @@ def dist_init(backend: str = 'nccl', def dist_finalize() -> None: - r""" + """ Overview: Finalize distributed training resources """ @@ -151,21 +191,46 @@ def dist_finalize() -> None: class DDPContext: + """ + Overview: + A context manager for ``linklink`` distribution + Interfaces: + ``__init__``, ``__enter__``, ``__exit__`` + """ def __init__(self) -> None: + """ + Overview: + Initialize the ``DDPContext`` + """ + pass def __enter__(self) -> None: + """ + Overview: + Initialize ``linklink`` distribution + """ + dist_init() def __exit__(self, *args, **kwargs) -> Any: + """ + Overview: + Finalize ``linklink`` distribution + """ + dist_finalize() def simple_group_split(world_size: int, rank: int, num_groups: int) -> List: - r""" + """ Overview: Split the group according to ``worldsize``, ``rank`` and ``num_groups`` + Arguments: + - world_size (:obj:`int`): The world size + - rank (:obj:`int`): The rank + - num_groups (:obj:`int`): The number of groups .. note:: With faulty input, raise ``array split does not result in an equal division`` @@ -180,6 +245,13 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List: def to_ddp_config(cfg: EasyDict) -> EasyDict: + """ + Overview: + Convert the config to ddp config + Arguments: + - cfg (:obj:`EasyDict`): The config to be converted + """ + w = get_world_size() if 'batch_size' in cfg.policy: cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w)) diff --git a/ding/utils/registry.py b/ding/utils/registry.py index e0c98d29f7..1d55041ffb 100644 --- a/ding/utils/registry.py +++ b/ding/utils/registry.py @@ -10,28 +10,41 @@ class Registry(dict): """ - A helper class for managing registering modules, it extends a dictionary - and provides a register functions. - - Eg. creeting a registry: - some_registry = Registry({"default": default_module}) - - There're two ways of registering new modules: - 1): normal way is just calling register function: - def foo(): - ... - some_registry.register("foo_module", foo) - 2): used as decorator when declaring the module: - @some_registry.register("foo_module") - @some_registry.register("foo_modeul_nickname") - def foo(): - ... - - Access of module is just like using a dictionary, eg: - f = some_registry["foo_module"] + Overview: + A helper class for managing registering modules, it extends a dictionary + and provides a register functions. + Interfaces: + ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details`` + Examples: + creeting a registry: + >>> some_registry = Registry({"default": default_module}) + + There're two ways of registering new modules: + 1): normal way is just calling register function: + >>> def foo(): + >>> ... + some_registry.register("foo_module", foo) + 2): used as decorator when declaring the module: + >>> @some_registry.register("foo_module") + >>> @some_registry.register("foo_modeul_nickname") + >>> def foo(): + >>> ... + + Access of module is just like using a dictionary, eg: + >>> f = some_registry["foo_module"] """ def __init__(self, *args, **kwargs) -> None: + """ + Overview: + Initialize the Registry object. + Arguments: + - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \ + dict. + - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \ + dict. + """ + super(Registry, self).__init__(*args, **kwargs) self.__trace__ = dict() @@ -41,6 +54,15 @@ def register( module: Optional[Callable] = None, force_overwrite: bool = False ) -> Callable: + """ + Overview: + Register the module. + Arguments: + - module_name (:obj:`Optional[str]`): The name of the module. + - module (:obj:`Optional[Callable]`): The module to be registered. + - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name. + """ + if _DI_ENGINE_REG_TRACE_IS_ON: frame = inspect.stack()[1][0] info = inspect.getframeinfo(frame) @@ -69,14 +91,40 @@ def register_fn(fn: Callable) -> Callable: @staticmethod def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None: + """ + Overview: + Register the module. + Arguments: + - module_dict (:obj:`dict`): The dict to store the module. + - module_name (:obj:`str`): The name of the module. + - module (:obj:`Callable`): The module to be registered. + - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name. + """ + if not force_overwrite: assert module_name not in module_dict, module_name module_dict[module_name] = module def get(self, module_name: str) -> Callable: + """ + Overview: + Get the module. + Arguments: + - module_name (:obj:`str`): The name of the module. + """ + return self[module_name] def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object: + """ + Overview: + Build the object. + Arguments: + - obj_type (:obj:`str`): The type of the object. + - obj_args (:obj:`Tuple`): The arguments passed to the object. + - obj_kwargs (:obj:`Dict`): The keyword arguments passed to the object. + """ + try: build_fn = self[obj_type] return build_fn(*obj_args, **obj_kwargs) @@ -96,9 +144,21 @@ def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object: raise e def query(self) -> Iterable: + """ + Overview: + all registered module names. + """ + return self.keys() def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict: + """ + Overview: + Get the details of the registered modules. + Arguments: + - aliases (:obj:`Optional[Iterable]`): The aliases of the modules. + """ + assert _DI_ENGINE_REG_TRACE_IS_ON, "please exec 'export DIENGINEREGTRACE=ON' first" if aliases is None: aliases = self.keys() diff --git a/ding/utils/render_helper.py b/ding/utils/render_helper.py index 679263eca1..11aed75941 100644 --- a/ding/utils/render_helper.py +++ b/ding/utils/render_helper.py @@ -6,7 +6,7 @@ def render_env(env, render_mode: Optional[str] = 'rgb_array') -> "ndarray": - ''' + """ Overview: Render the environment's current frame. Arguments: @@ -14,7 +14,7 @@ def render_env(env, render_mode: Optional[str] = 'rgb_array') -> "ndarray": - render_mode (:obj:`str`): Render mode. Returns: - frame (:obj:`numpy.ndarray`): [H * W * C] - ''' + """ if hasattr(env, 'sim'): # mujoco: mujoco frame is unside-down by default return env.sim.render(camera_name='track', height=128, width=128)[::-1] @@ -24,7 +24,7 @@ def render_env(env, render_mode: Optional[str] = 'rgb_array') -> "ndarray": def render(env: "BaseEnv", render_mode: Optional[str] = 'rgb_array') -> "ndarray": - ''' + """ Overview: Render the environment's current frame. Arguments: @@ -32,20 +32,20 @@ def render(env: "BaseEnv", render_mode: Optional[str] = 'rgb_array') -> "ndarray - render_mode (:obj:`str`): Render mode. Returns: - frame (:obj:`numpy.ndarray`): [H * W * C] - ''' + """ gym_env = env._env return render_env(gym_env, render_mode=render_mode) def get_env_fps(env) -> "int": - ''' + """ Overview: Get the environment's fps. Arguments: - env (:obj:`gym.Env`): DI-engine env instance. Returns: - fps (:obj:`int`). - ''' + """ if hasattr(env, 'model'): # mujoco @@ -60,14 +60,14 @@ def get_env_fps(env) -> "int": def fps(env_manager: "BaseEnvManager") -> "int": - ''' + """ Overview: Render the environment's fps. Arguments: - env (:obj:`BaseEnvManager`): DI-engine env manager instance. Returns: - fps (:obj:`int`). - ''' + """ try: # env_ref is a ding gym environment gym_env = env_manager.env_ref._env diff --git a/ding/utils/scheduler_helper.py b/ding/utils/scheduler_helper.py index 9b1c2600a4..d37ce97c52 100644 --- a/ding/utils/scheduler_helper.py +++ b/ding/utils/scheduler_helper.py @@ -9,7 +9,7 @@ class Scheduler(object): For example, models often benefits from reducing entropy weight once the learning process stagnates. This scheduler reads a metrics quantity and if no improvement is seen for a 'patience' number of epochs, the corresponding parameter is increased or decreased, which decides on the 'schedule_mode'. - Args: + Arguments: - schedule_flag (:obj:`bool`): Indicates whether to use scheduler in training pipeline. Default: False - schedule_mode (:obj:`str`): One of 'reduce', 'add','multi','div'. The schecule_mode @@ -48,13 +48,13 @@ class Scheduler(object): ) def __init__(self, merged_scheduler_config: EasyDict) -> None: - ''' + """ Overview: Initialize the scheduler. - Args: + Arguments: - merged_scheduler_config (:obj:`EasyDict`): the scheduler config, which merges the user config and defaul config - ''' + """ schedule_mode = merged_scheduler_config.schedule_mode factor = merged_scheduler_config.factor @@ -100,7 +100,7 @@ def __init__(self, merged_scheduler_config: EasyDict) -> None: self.bad_epochs_num = 0 def step(self, metrics: float, param: float) -> float: - ''' + """ Overview: Decides whether to update the scheduled parameter Args: @@ -108,7 +108,7 @@ def step(self, metrics: float, param: float) -> float: - param (:obj:`float`): parameter need to be updated Returns: - step_param (:obj:`float`): parameter after one step - ''' + """ assert isinstance(metrics, float), 'The metrics should be converted to a float number' cur_metrics = metrics @@ -129,14 +129,14 @@ def step(self, metrics: float, param: float) -> float: return param def update_param(self, param: float) -> float: - ''' + """ Overview: update the scheduling parameter Args: - param (:obj:`float`): parameter need to be updated Returns: - updated param (:obj:`float`): parameter after updating - ''' + """ schedule_fn = { 'reduce': lambda x, y, z: max(x - y, z[0]), 'add': lambda x, y, z: min(x + y, z[1]), @@ -153,20 +153,20 @@ def update_param(self, param: float) -> float: @property def in_cooldown(self) -> bool: - ''' + """ Overview: Checks whether the scheduler is in cooldown peried. If in cooldown, the scheduler will ignore any bad epochs. - ''' + """ return self.cooldown_counter > 0 def is_better(self, cur: float) -> bool: - ''' + """ Overview: Checks whether the current metrics is better than last matric with respect to threshold. Args: - cur (:obj:`float`): current metrics - ''' + """ if self.last_metrics is None: return True diff --git a/ding/utils/segment_tree.py b/ding/utils/segment_tree.py index b92dbed742..5c87280ab4 100644 --- a/ding/utils/segment_tree.py +++ b/ding/utils/segment_tree.py @@ -9,6 +9,11 @@ @lru_cache() def njit(): + """ + Overview: + Decorator to compile a function using numba. + """ + try: if ding.enable_numba: import numba @@ -34,7 +39,7 @@ class SegmentTree: Overview: Segment tree data structure, implemented by the tree-like array. Only the leaf nodes are real value, non-leaf nodes are to do some operations on its left and right child. - Interface: + Interfaces: ``__init__``, ``reduce``, ``__setitem__``, ``__getitem__`` """ @@ -111,6 +116,11 @@ def __getitem__(self, idx: int) -> float: return self.value[idx + self.capacity] def _compile(self) -> None: + """ + Overview: + Compile the functions using numba. + """ + f64 = np.array([0, 1], dtype=np.float64) f32 = np.array([0, 1], dtype=np.float32) i64 = np.array([0, 1], dtype=np.int64) @@ -121,11 +131,19 @@ def _compile(self) -> None: class SumSegmentTree(SegmentTree): + """ + Overview: + Sum segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='sum'``. + Interfaces: + ``__init__``, ``find_prefixsum_idx`` + """ def __init__(self, capacity: int) -> None: """ Overview: Init sum segment tree by passing ``operation='sum'`` + Arguments: + - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes). """ super(SumSegmentTree, self).__init__(capacity, operation='sum') @@ -148,17 +166,35 @@ def find_prefixsum_idx(self, prefixsum: float, trust_caller: bool = True) -> int class MinSegmentTree(SegmentTree): + """ + Overview: + Min segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='min'``. + Interfaces: + ``__init__`` + """ def __init__(self, capacity: int) -> None: """ Overview: - Init sum segment tree by passing ``operation='min'`` + Initialize sum segment tree by passing ``operation='min'`` + Arguments: + - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes). """ super(MinSegmentTree, self).__init__(capacity, operation='min') @njit() def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None: + """ + Overview: + Set ``tree[idx] = val``; Then update the related nodes. + Arguments: + - tree (:obj:`np.ndarray`): The tree array. + - idx (:obj:`int`): The index of the leaf node. + - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``. + - operation (:obj:`str`): The operation function to construct the tree, e.g. sum, max, min, etc. + """ + tree[idx] = val # Update from specified node to the root node while idx > 1: @@ -172,6 +208,18 @@ def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None: @njit() def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, operation: str) -> float: + """ + Overview: + Reduce the tree in range ``[start, end)`` + Arguments: + - tree (:obj:`np.ndarray`): The tree array. + - start (:obj:`int`): Start index(relative index, the first leaf node is 0). + - end (:obj:`int`): End index(relative index). + - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \ + all nodes value in the tree. + - operation (:obj:`str`): The operation function to construct the tree, e.g. sum, max, min, etc. + """ + # Nodes in 【start, end) will be aggregated result = neutral_element while start < end: @@ -197,6 +245,18 @@ def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, oper @njit() def _find_prefixsum_idx(tree: np.ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int: + """ + Overview: + Find the highest non-zero index i, sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i) + and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1) + Arguments: + - tree (:obj:`np.ndarray`): The tree array. + - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes). + - prefixsum (:obj:`float`): The target prefixsum. + - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \ + all nodes value in the tree. + """ + # The function is to find a non-leaf node's index which satisfies: # self.value[idx] > input prefixsum and self.value[idx + 1] <= input prefixsum # In other words, we can assume that there are intervals: [num_0, num_1), [num_1, num_2), ... [num_k, num_k+1), diff --git a/ding/utils/slurm_helper.py b/ding/utils/slurm_helper.py index 4a1c35f4d2..c03b3e9463 100644 --- a/ding/utils/slurm_helper.py +++ b/ding/utils/slurm_helper.py @@ -14,6 +14,11 @@ def get_ip() -> str: + """ + Overview: + Get the ip of the current node + """ + assert os.environ.get('SLURMD_NODENAME'), 'not found SLURMD_NODENAME env variable' # expecting nodename to be like: 'SH-IDC1-10-5-36-64' nodename = os.environ.get('SLURMD_NODENAME', '') @@ -22,9 +27,11 @@ def get_ip() -> str: def get_manager_node_ip(node_ip: Optional[str] = None) -> str: - r""" + """ Overview: Look up the manager node of the slurm cluster and return the node ip + Arguments: + - node_ip (:obj:`Optional[str]`): The ip of the current node """ if 'SLURM_JOB_ID' not in os.environ: from ditk import logging @@ -44,6 +51,11 @@ def get_manager_node_ip(node_ip: Optional[str] = None) -> str: # get all info of cluster def get_cls_info() -> Dict[str, list]: + """ + Overview: + Get the cluster info + """ + ret_dict = {} info = subprocess.getoutput('sinfo -Nh').split('\n') for line in info: @@ -61,6 +73,13 @@ def get_cls_info() -> Dict[str, list]: def node_to_partition(target_node: str) -> Tuple[str, str]: + """ + Overview: + Get the partition of the target node + Arguments: + - target_node (:obj:`str`): The target node + """ + info = subprocess.getoutput('sinfo -Nh').split('\n') for line in info: line = line.strip().split() @@ -73,10 +92,24 @@ def node_to_partition(target_node: str) -> Tuple[str, str]: def node_to_host(node: str) -> str: + """ + Overview: + Get the host of the node + Arguments: + - node (:obj:`str`): The node + """ + return '.'.join(node.split('-')[-4:]) def find_free_port_slurm(node: str) -> int: + """ + Overview: + Find a free port on the node + Arguments: + - node (:obj:`str`): The node + """ + partition = node_to_partition(node) if partition == 'spring_scheduler': comment = '--comment=spring-submit' diff --git a/ding/utils/system_helper.py b/ding/utils/system_helper.py index 118513d36b..915ef380e9 100644 --- a/ding/utils/system_helper.py +++ b/ding/utils/system_helper.py @@ -8,7 +8,7 @@ def get_ip() -> str: - r""" + """ Overview: Get the ``ip(host)`` of socket Returns: @@ -22,7 +22,7 @@ def get_ip() -> str: def get_pid() -> int: - r""" + """ Overview: ``os.getpid`` """ @@ -30,7 +30,7 @@ def get_pid() -> int: def get_task_uid() -> str: - r""" + """ Overview: Get the slurm ``job_id``, ``pid`` and ``uid`` """ @@ -41,7 +41,7 @@ class PropagatingThread(Thread): """ Overview: Subclass of Thread that propagates execution exception in the thread to the caller - Interface: + Interfaces: ``run``, ``join`` Examples: >>> def func(): @@ -52,6 +52,11 @@ class PropagatingThread(Thread): """ def run(self) -> None: + """ + Overview: + Run the thread + """ + self.exc = None try: self.ret = self._target(*self._args, **self._kwargs) @@ -59,6 +64,11 @@ def run(self) -> None: self.exc = e def join(self) -> Any: + """ + Overview: + Join the thread + """ + super(PropagatingThread, self).join() if self.exc: raise RuntimeError('Exception in thread({})'.format(id(self))) from self.exc @@ -66,9 +76,11 @@ def join(self) -> Any: def find_free_port(host: str) -> int: - r""" + """ Overview: Look up the free port list and return one + Arguments: + - host (:obj:`str`): The host """ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(('', 0)) diff --git a/ding/utils/time_helper.py b/ding/utils/time_helper.py index 76cfccb2e8..06498e4546 100644 --- a/ding/utils/time_helper.py +++ b/ding/utils/time_helper.py @@ -9,7 +9,7 @@ def build_time_helper(cfg: EasyDict = None, wrapper_type: str = None) -> Callable[[], 'TimeWrapper']: - r""" + """ Overview: Build the timehelper @@ -46,11 +46,11 @@ def build_time_helper(cfg: EasyDict = None, wrapper_type: str = None) -> Callabl class EasyTimer: - r""" + """ Overview: A decent timer wrapper that can be used easily. - Interface: + Interfaces: ``__init__``, ``__enter__``, ``__exit__`` Example: @@ -61,7 +61,7 @@ class EasyTimer: """ def __init__(self, cuda=True): - r""" + """ Overview: Init class EasyTimer @@ -76,7 +76,7 @@ def __init__(self, cuda=True): self.value = 0.0 def __enter__(self): - r""" + """ Overview: Enter timer, start timing """ @@ -84,7 +84,7 @@ def __enter__(self): self._timer.start_time() def __exit__(self, *args): - r""" + """ Overview: Exit timer, stop timing """ @@ -92,18 +92,18 @@ def __exit__(self, *args): class TimeWrapperTime(TimeWrapper): - r""" + """ Overview: A class method that inherit from ``TimeWrapper`` class - Interface: + Interfaces: ``start_time``, ``end_time`` """ # overwrite @classmethod def start_time(cls): - r""" + """ Overview: Implement and override the ``start_time`` method in ``TimeWrapper`` class """ @@ -112,7 +112,7 @@ def start_time(cls): # overwrite @classmethod def end_time(cls): - r""" + """ Overview: Implement and override the end_time method in ``TimeWrapper`` class @@ -134,7 +134,7 @@ class WatchDog(object): .. note:: If it is not reset before exceeding this value, ``TimeourError`` raised. - Interface: + Interfaces: ``start``, ``stop`` Examples: @@ -146,11 +146,18 @@ class WatchDog(object): """ def __init__(self, timeout: int = 1): + """ + Overview: + Initialize watchdog with ``timeout`` value. + Arguments: + - timeout (:obj:`int`): Timeout value of the ``watchdog [seconds]``. + """ + self._timeout = timeout + 1 self._failed = False def start(self): - r""" + """ Overview: Start watchdog. """ @@ -159,10 +166,18 @@ def start(self): @staticmethod def _event(signum: Any, frame: Any): + """ + Overview: + Event handler for watchdog. + Arguments: + - signum (:obj:`Any`): Signal number. + - frame (:obj:`Any`): Current stack frame. + """ + raise TimeoutError() def stop(self): - r""" + """ Overview: Stop watchdog with ``alarm(0)``, ``SIGALRM``, and ``SIG_DFL`` signals. """ diff --git a/ding/utils/time_helper_base.py b/ding/utils/time_helper_base.py index c77aa1c8d4..86f58d0fe8 100644 --- a/ding/utils/time_helper_base.py +++ b/ding/utils/time_helper_base.py @@ -1,19 +1,19 @@ class TimeWrapper(object): - r""" + """ Overview: Abstract class method that defines ``TimeWrapper`` class - Interface: + Interfaces: ``wrapper``, ``start_time``, ``end_time`` """ @classmethod def wrapper(cls, fn): - r""" + """ Overview: Classmethod wrapper, wrap a function and automatically return its running time - - - fn (:obj:`function`): The function to be wrap and timed + Arguments: + - fn (:obj:`function`): The function to be wrap and timed """ def time_func(*args, **kwargs): @@ -26,7 +26,7 @@ def time_func(*args, **kwargs): @classmethod def start_time(cls): - r""" + """ Overview: Abstract classmethod, start timing """ @@ -34,7 +34,7 @@ def start_time(cls): @classmethod def end_time(cls): - r""" + """ Overview: Abstract classmethod, stop timing """ diff --git a/ding/utils/time_helper_cuda.py b/ding/utils/time_helper_cuda.py index da691bb6b3..51ea5e925a 100644 --- a/ding/utils/time_helper_cuda.py +++ b/ding/utils/time_helper_cuda.py @@ -4,7 +4,7 @@ def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']: - r""" + """ Overview: Return the ``TimeWrapperCuda`` class, this wrapper aims to ensure compatibility in no cuda device @@ -18,7 +18,7 @@ def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']: # TODO find a way to autodoc the class within method class TimeWrapperCuda(TimeWrapper): - r""" + """ Overview: A class method that inherit from ``TimeWrapper`` class @@ -26,7 +26,7 @@ class TimeWrapperCuda(TimeWrapper): Must use torch.cuda.synchronize(), reference: \ - Interface: + Interfaces: ``start_time``, ``end_time`` """ # cls variable is initialized on loading this class @@ -36,7 +36,7 @@ class TimeWrapperCuda(TimeWrapper): # overwrite @classmethod def start_time(cls): - r""" + """ Overview: Implement and overide the ``start_time`` method in ``TimeWrapper`` class """ @@ -46,7 +46,7 @@ def start_time(cls): # overwrite @classmethod def end_time(cls): - r""" + """ Overview: Implement and overide the end_time method in ``TimeWrapper`` class Returns: