From ba027f0d133f7cfeaaca07bcfbf196dfde153672 Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Wed, 14 Feb 2024 14:19:05 +0800 Subject: [PATCH] docstrings for input distributions --- src/pyjuice/nodes/distributions/bernoulli.py | 9 ++ .../nodes/distributions/categorical.py | 12 ++ .../nodes/distributions/discrete_logistic.py | 19 +++ .../nodes/distributions/distributions.py | 112 ++++++++++-------- src/pyjuice/nodes/distributions/gaussian.py | 27 +++-- .../nodes/distributions/masked_categorical.py | 16 +++ 6 files changed, 137 insertions(+), 58 deletions(-) diff --git a/src/pyjuice/nodes/distributions/bernoulli.py b/src/pyjuice/nodes/distributions/bernoulli.py index 2344e3d6..d32f38bc 100644 --- a/src/pyjuice/nodes/distributions/bernoulli.py +++ b/src/pyjuice/nodes/distributions/bernoulli.py @@ -10,13 +10,22 @@ class Bernoulli(Distribution): + """ + A class representing Bernoulli distributions. + """ def __init__(self): super(Bernoulli, self).__init__() def get_signature(self): + """ + Get the signature of the current distribution. + """ return "Bernoulli" def get_metadata(self): + """ + Get the metadata of the current distribution. + """ return [] def num_parameters(self): diff --git a/src/pyjuice/nodes/distributions/categorical.py b/src/pyjuice/nodes/distributions/categorical.py index a72fbda1..069ee33f 100644 --- a/src/pyjuice/nodes/distributions/categorical.py +++ b/src/pyjuice/nodes/distributions/categorical.py @@ -10,15 +10,27 @@ class Categorical(Distribution): + """ + A class representing Categorical distributions. + + :param num_cats: number of categories + :type num_cats: int + """ def __init__(self, num_cats: int): super(Categorical, self).__init__() self.num_cats = num_cats def get_signature(self): + """ + Get the signature of the current distribution. + """ return "Categorical" def get_metadata(self): + """ + Get the metadata of the current distribution. 
+ """ return [self.num_cats] def normalize_parameters(self, params: torch.Tensor): diff --git a/src/pyjuice/nodes/distributions/discrete_logistic.py b/src/pyjuice/nodes/distributions/discrete_logistic.py index 3b2df028..e74488df 100644 --- a/src/pyjuice/nodes/distributions/discrete_logistic.py +++ b/src/pyjuice/nodes/distributions/discrete_logistic.py @@ -10,6 +10,19 @@ class DiscreteLogistic(Distribution): + """ + A class representing Discrete Logistic distributions. + + :param val_range: range of the values represented by the distribution + :type val_range: Tuple[float,float] + + :param num_cats: number of categories + :type num_cats: int + + :param min_std: minimum standard deviation + :type min_std: float + """ + def __init__(self, val_range: Tuple[float,float], num_cats: int, min_std: float = 0.01): super(DiscreteLogistic, self).__init__() @@ -18,9 +31,15 @@ def __init__(self, val_range: Tuple[float,float], num_cats: int, min_std: float self.min_std = min_std def get_signature(self): + """ + Get the signature of the current distribution. + """ return "DiscreteLogistic" def get_metadata(self): + """ + Get the metadata of the current distribution. + """ return [self.val_range[0], self.val_range[1], self.num_cats, self.min_std] def num_parameters(self): diff --git a/src/pyjuice/nodes/distributions/distributions.py b/src/pyjuice/nodes/distributions/distributions.py index 2eb4caa9..9b7c3217 100644 --- a/src/pyjuice/nodes/distributions/distributions.py +++ b/src/pyjuice/nodes/distributions/distributions.py @@ -10,18 +10,28 @@ def __init__(self): pass def get_signature(self): + """ + Get the signature of the current distribution. + """ raise NotImplementedError() def get_metadata(self): + """ + Get the metadata of the current distribution. + """ return [] # no metadata def normalize_parameters(self, params: torch.Tensor, **kwargs): + """ + Normalize node parameters. 
+ """ return params def set_meta_parameters(self, **kwargs): """ Assign meta-parameters to `self._params`. - Note: the actual parameters are not initialized after this function call. + + :note: the actual parameters are not initialized after this function call. """ raise NotImplementedError() @@ -63,16 +73,16 @@ def need_meta_parameters(self): def fw_mar_fn(*args, **kwargs): """ Forward evaluation for log-probabilities. - Args: - `local_offsets`: [BLOCK_SIZE] the local indices of the to-be-processed input nodes - `data`: [BLOCK_SIZE, num_vars_per_node] data of the corresponding nodes - `params_ptr`: pointer to the parameter vector - `s_pids`: [BLOCK_SIZE] start parameter index (offset) for all input nodes - `metadata_ptr`: pointer to metadata - `s_mids_ptr`: pointer to the start metadata index (offset) - `mask`: [BLOCK_SIZE] indicate whether each node should be processed - `num_vars_per_node`: numbers of variables per input node/distribution - `BLOCK_SIZE`: CUDA block size + + :param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes + :param data: [BLOCK_SIZE, num_vars_per_node] data of the corresponding nodes + :param params_ptr: pointer to the parameter vector + :param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes + :param metadata_ptr: pointer to metadata + :param s_mids_ptr: pointer to the start metadata index (offset) + :param mask: [BLOCK_SIZE] indicate whether each node should be processed + :param num_vars_per_node: numbers of variables per input node/distribution + :param BLOCK_SIZE: CUDA block size """ raise NotImplementedError() @@ -80,21 +90,21 @@ def fw_mar_fn(*args, **kwargs): def bk_flow_fn(*args, **kwargs): """ Accumulate statistics and compute input parameter flows. 
- Args: - `local_offsets`: [BLOCK_SIZE] the local indices of the to-be-processed input nodes - `ns_offsets`: [BLOCK_SIZE] the global offsets used to load from `node_mars_ptr` - `data`: [BLOCK_SIZE, num_vars_per_node] data of the corresponding nodes - `flows`: [BLOCK_SIZE] node flows - `node_mars_ptr`: pointer to the forward values - `params_ptr`: pointer to the parameter vector - `param_flows_ptr`: pointer to the parameter flow vector - `s_pids`: [BLOCK_SIZE] start parameter index (offset) for all input nodes - `s_pfids`: [BLOCK_SIZE] start parameter flow index (offset) for all input nodes - `metadata_ptr`: pointer to metadata - `s_mids_ptr`: pointer to the start metadata index (offset) - `mask`: [BLOCK_SIZE] indicate whether each node should be processed - `num_vars_per_node`: numbers of variables per input node/distribution - `BLOCK_SIZE`: CUDA block size + + :param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes + :param ns_offsets: [BLOCK_SIZE] the global offsets used to load from `node_mars_ptr` + :param data: [BLOCK_SIZE, num_vars_per_node] data of the corresponding nodes + :param flows: [BLOCK_SIZE] node flows + :param node_mars_ptr: pointer to the forward values + :param params_ptr: pointer to the parameter vector + :param param_flows_ptr: pointer to the parameter flow vector + :param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes + :param s_pfids: [BLOCK_SIZE] start parameter flow index (offset) for all input nodes + :param metadata_ptr: pointer to metadata + :param s_mids_ptr: pointer to the start metadata index (offset) + :param mask: [BLOCK_SIZE] indicate whether each node should be processed + :param num_vars_per_node: numbers of variables per input node/distribution + :param BLOCK_SIZE: CUDA block size """ raise NotImplementedError() @@ -102,19 +112,19 @@ def bk_flow_fn(*args, **kwargs): def sample_fn(*args, **kwargs): """ Sample from the distribution. 
- Args: - `samples_ptr`: pointer to store the resultant samples - `local_offsets`: [BLOCK_SIZE] the local indices of the to-be-processed input nodes - `batch_offsets`: [BLOCK_SIZE] batch id corresponding to every node - `vids`: [BLOCK_SIZE] variable ids (only univariate distributions are supported) - `s_pids`: [BLOCK_SIZE] start parameter index (offset) for all input nodes - `params_ptr`: pointer to the parameter vector - `metadata_ptr`: pointer to metadata - `s_mids_ptr`: pointer to the start metadata index (offset) - `mask`: [BLOCK_SIZE] indicate whether each node should be processed - `batch_size`: batch size - `BLOCK_SIZE`: CUDA block size - `seed`: random seed + + :param samples_ptr: pointer to store the resultant samples + :param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes + :param batch_offsets: [BLOCK_SIZE] batch id corresponding to every node + :param vids: [BLOCK_SIZE] variable ids (only univariate distributions are supported) + :param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes + :param params_ptr: pointer to the parameter vector + :param metadata_ptr: pointer to metadata + :param s_mids_ptr: pointer to the start metadata index (offset) + :param mask: [BLOCK_SIZE] indicate whether each node should be processed + :param batch_size: batch size + :param BLOCK_SIZE: CUDA block size + :param seed: random seed """ raise NotImplementedError() @@ -122,18 +132,18 @@ def sample_fn(*args, **kwargs): def em_fn(*args, **kwargs): """ Parameter update with EM - Args: - `local_offsets`: [BLOCK_SIZE] the local indices of the to-be-processed input nodes - `params_ptr`: pointer to the parameter vector - `param_flows_ptr`: pointer to the parameter flow vector - `s_pids`: [BLOCK_SIZE] start parameter index (offset) for all input nodes - `s_pfids`: [BLOCK_SIZE] start parameter flow index (offset) for all input nodes - `metadata_ptr`: pointer to metadata - `s_mids_ptr`: pointer to the start metadata index 
(offset) - `mask`: [BLOCK_SIZE] indicate whether each node should be processed - `step_size`: EM step size (0, 1] - `pseudocount`: pseudocount - `BLOCK_SIZE`: CUDA block size + + :param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes + :param params_ptr: pointer to the parameter vector + :param param_flows_ptr: pointer to the parameter flow vector + :param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes + :param s_pfids: [BLOCK_SIZE] start parameter flow index (offset) for all input nodes + :param metadata_ptr: pointer to metadata + :param s_mids_ptr: pointer to the start metadata index (offset) + :param mask: [BLOCK_SIZE] indicate whether each node should be processed + :param step_size: EM step size (0, 1] + :param pseudocount: pseudocount + :param BLOCK_SIZE: CUDA block size """ raise NotImplementedError() diff --git a/src/pyjuice/nodes/distributions/gaussian.py b/src/pyjuice/nodes/distributions/gaussian.py index b2cfaf7a..deb88ea9 100644 --- a/src/pyjuice/nodes/distributions/gaussian.py +++ b/src/pyjuice/nodes/distributions/gaussian.py @@ -10,14 +10,21 @@ class Gaussian(Distribution): + """ + A class representing Gaussian distributions. + + :note: `mu` and `sigma` are used to specify (approximately) the mean and std of the data. This is used for parameter initialization. + + :note: The parameters will NOT be initialized directly using the values of `mu` and `sigma`, perturbations will be added. You can specify the initialization behavior by passing `perturbation`, `mu`, and `sigma` to the `init_parameters` function. + + :param mu: mean of the Gaussian + :type mu: float + + :param sigma: standard deviation of the Gaussian + :type sigma: float + """ + def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, min_sigma: float = 0.01): - """ - `mu` and `sigma` are used to specify (approximately) the mean and std of the data. - This is used for parameter initialization. 
- Note: the parameters will NOT be initialized directly using the values of `mu` and `sigma`, - perturbations will be added. You can specify the initialization behavior by passing - `perturbation`, `mu`, and `sigma` to the `init_parameters` function. - """ super(Gaussian, self).__init__() self.mu = mu @@ -25,9 +32,15 @@ def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, mi self.min_sigma = min_sigma def get_signature(self): + """ + Get the signature of the current distribution. + """ return "Gaussian" def get_metadata(self): + """ + Get the metadata of the current distribution. + """ return [self.min_sigma] def num_parameters(self): diff --git a/src/pyjuice/nodes/distributions/masked_categorical.py b/src/pyjuice/nodes/distributions/masked_categorical.py index b70ac437..d5aef348 100644 --- a/src/pyjuice/nodes/distributions/masked_categorical.py +++ b/src/pyjuice/nodes/distributions/masked_categorical.py @@ -10,6 +10,16 @@ class MaskedCategorical(Distribution): + """ + A class representing Categorical distributions with masks. + + :param num_cats: number of categories + :type num_cats: int + + :param mask_mode: type of mask; should be in ["range", "full_mask", "rev_range"] + :type mask_mode: str + """ + def __init__(self, num_cats: int, mask_mode: str): super(MaskedCategorical, self).__init__() @@ -39,9 +49,15 @@ def __init__(self, num_cats: int, mask_mode: str): self.em_fn = self.em_fn_rev_range def get_signature(self): + """ + Get the signature of the current distribution. + """ return f"MaskedCategorical-{self.mask_mode}" def get_metadata(self): + """ + Get the metadata of the current distribution. + """ return [self.num_cats] def num_parameters(self):