Skip to content

Commit

Permalink
docstrings for input distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Feb 14, 2024
1 parent b012853 commit ba027f0
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 58 deletions.
9 changes: 9 additions & 0 deletions src/pyjuice/nodes/distributions/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@


class Bernoulli(Distribution):
"""
A class representing Bernoulli distributions.
"""
def __init__(self):
"""
Create a Bernoulli distribution.

Takes no arguments; it only runs the base-class initializer.
"""
super(Bernoulli, self).__init__()

def get_signature(self):
"""
Get the signature of the current distribution.

:returns: the constant string ``"Bernoulli"``
:rtype: str
"""
return "Bernoulli"

def get_metadata(self):
"""
Get the metadata of the current distribution.

:returns: an empty list (this distribution carries no metadata)
:rtype: list
"""
return []

def num_parameters(self):
Expand Down
12 changes: 12 additions & 0 deletions src/pyjuice/nodes/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,27 @@


class Categorical(Distribution):
"""
A class representing Categorical distributions.
:param num_cats: number of categories
:type num_cats: int
"""
def __init__(self, num_cats: int):
"""
Create a Categorical distribution and record its number of categories.

:param num_cats: number of categories
:type num_cats: int
"""
super(Categorical, self).__init__()

self.num_cats = num_cats

def get_signature(self):
"""
Get the signature of the current distribution.

:returns: the constant string ``"Categorical"``
:rtype: str
"""
return "Categorical"

def get_metadata(self):
"""
Get the metadata of the current distribution.

:returns: a one-element list containing ``self.num_cats``
:rtype: list
"""
return [self.num_cats]

def normalize_parameters(self, params: torch.Tensor):
Expand Down
19 changes: 19 additions & 0 deletions src/pyjuice/nodes/distributions/discrete_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@


class DiscreteLogistic(Distribution):
"""
A class representing Discrete Logistic distributions.
:param val_range: range of the values represented by the distribution
:type val_range: Tuple[float,float]
:param num_cats: number of categories
:type num_cats: int
:param min_std: minimum standard deviation
:type min_std: float
"""

def __init__(self, val_range: Tuple[float,float], num_cats: int, min_std: float = 0.01):
super(DiscreteLogistic, self).__init__()

Expand All @@ -18,9 +31,15 @@ def __init__(self, val_range: Tuple[float,float], num_cats: int, min_std: float
self.min_std = min_std

def get_signature(self):
"""
Get the signature of the current distribution.

:returns: the constant string ``"DiscreteLogistic"``
:rtype: str
"""
return "DiscreteLogistic"

def get_metadata(self):
"""
Get the metadata of the current distribution.

:returns: the list ``[val_range[0], val_range[1], num_cats, min_std]``, in that order
:rtype: list
"""
return [self.val_range[0], self.val_range[1], self.num_cats, self.min_std]

def num_parameters(self):
Expand Down
112 changes: 61 additions & 51 deletions src/pyjuice/nodes/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,28 @@ def __init__(self):
pass

def get_signature(self):
"""
Get the signature of the current distribution.

The base class provides no signature; this implementation only raises.

:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError()

def get_metadata(self):
"""
Get the metadata of the current distribution.

:returns: an empty list (the base implementation has no metadata)
:rtype: list
"""
return [] # no metadata

def normalize_parameters(self, params: torch.Tensor, **kwargs):
"""
Normalize node parameters.

The base implementation is a no-op: it returns ``params`` unchanged and
ignores any extra keyword arguments.

:param params: the parameters to normalize
:type params: torch.Tensor
:returns: ``params``, unmodified
:rtype: torch.Tensor
"""
return params

def set_meta_parameters(self, **kwargs):
"""
Assign meta-parameters to `self._params`.

:note: the actual parameters are not initialized after this function call.

:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError()

Expand Down Expand Up @@ -63,77 +73,77 @@ def need_meta_parameters(self):
def fw_mar_fn(*args, **kwargs):
"""
Forward evaluation for log-probabilities.

This base implementation only raises ``NotImplementedError``. The arguments
listed below are the ones an implementation is expected to receive through
``*args`` / ``**kwargs``.

:param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes
:param data: [BLOCK_SIZE, num_vars_per_node] data of the corresponding nodes
:param params_ptr: pointer to the parameter vector
:param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes
:param metadata_ptr: pointer to metadata
:param s_mids_ptr: pointer to the start metadata index (offset)
:param mask: [BLOCK_SIZE] indicate whether each node should be processed
:param num_vars_per_node: number of variables per input node/distribution
:param BLOCK_SIZE: CUDA block size

:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError()

@staticmethod
def bk_flow_fn(*args, **kwargs):
"""
Accumulate statistics and compute input parameter flows.

This base implementation only raises ``NotImplementedError``. The arguments
listed below are the ones an implementation is expected to receive through
``*args`` / ``**kwargs``.

:param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes
:param ns_offsets: [BLOCK_SIZE] the global offsets used to load from `node_mars_ptr`
:param data: [BLOCK_SIZE, num_vars_per_node] data of the corresponding nodes
:param flows: [BLOCK_SIZE] node flows
:param node_mars_ptr: pointer to the forward values
:param params_ptr: pointer to the parameter vector
:param param_flows_ptr: pointer to the parameter flow vector
:param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes
:param s_pfids: [BLOCK_SIZE] start parameter flow index (offset) for all input nodes
:param metadata_ptr: pointer to metadata
:param s_mids_ptr: pointer to the start metadata index (offset)
:param mask: [BLOCK_SIZE] indicate whether each node should be processed
:param num_vars_per_node: number of variables per input node/distribution
:param BLOCK_SIZE: CUDA block size

:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError()

@staticmethod
def sample_fn(*args, **kwargs):
"""
Sample from the distribution.

This base implementation only raises ``NotImplementedError``. The arguments
listed below are the ones an implementation is expected to receive through
``*args`` / ``**kwargs``.

:param samples_ptr: pointer to store the resultant samples
:param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes
:param batch_offsets: [BLOCK_SIZE] batch id corresponding to every node
:param vids: [BLOCK_SIZE] variable ids (only univariate distributions are supported)
:param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes
:param params_ptr: pointer to the parameter vector
:param metadata_ptr: pointer to metadata
:param s_mids_ptr: pointer to the start metadata index (offset)
:param mask: [BLOCK_SIZE] indicate whether each node should be processed
:param batch_size: batch size
:param BLOCK_SIZE: CUDA block size
:param seed: random seed

:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError()

@staticmethod
def em_fn(*args, **kwargs):
"""
Parameter update with EM.

This base implementation only raises ``NotImplementedError``. The arguments
listed below are the ones an implementation is expected to receive through
``*args`` / ``**kwargs``.

:param local_offsets: [BLOCK_SIZE] the local indices of the to-be-processed input nodes
:param params_ptr: pointer to the parameter vector
:param param_flows_ptr: pointer to the parameter flow vector
:param s_pids: [BLOCK_SIZE] start parameter index (offset) for all input nodes
:param s_pfids: [BLOCK_SIZE] start parameter flow index (offset) for all input nodes
:param metadata_ptr: pointer to metadata
:param s_mids_ptr: pointer to the start metadata index (offset)
:param mask: [BLOCK_SIZE] indicate whether each node should be processed
:param step_size: EM step size, in (0, 1]
:param pseudocount: pseudocount
:param BLOCK_SIZE: CUDA block size

:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError()

Expand Down
27 changes: 20 additions & 7 deletions src/pyjuice/nodes/distributions/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,37 @@


class Gaussian(Distribution):
"""
A class representing Gaussian distributions.
:note: `mu` and `sigma` are used to specify (approximately) the mean and std of the data. This is used for parameter initialization.
:note: The parameters will NOT be initialized directly using the values of `mu` and `sigma`, perturbations will be added. You can specify the initialization behavior by passing `perturbation`, `mu`, and `sigma` to the `init_parameters` function.
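:param mu: mean of the Gaussian
:type mu: Optional[float]
:param sigma: standard deviation of the Gaussian
:type sigma: Optional[float]
:param min_sigma: lower bound on the standard deviation (defaults to 0.01)
:type min_sigma: float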
"""

def __init__(self, mu: Optional[float] = None, sigma: Optional[float] = None, min_sigma: float = 0.01):
"""
Create a Gaussian distribution and store its initialization hints.

`mu` and `sigma` are used to specify (approximately) the mean and std of the data.
This is used for parameter initialization.

:note: the parameters will NOT be initialized directly using the values of `mu` and `sigma`;
    perturbations will be added. You can specify the initialization behavior by passing
    `perturbation`, `mu`, and `sigma` to the `init_parameters` function.

:param mu: approximate mean of the data; defaults to ``None``
:type mu: Optional[float]
:param sigma: approximate standard deviation of the data; defaults to ``None``
:type sigma: Optional[float]
:param min_sigma: stored on the instance and exposed via ``get_metadata``;
    presumably a lower bound on sigma -- TODO confirm against the kernels that read it
:type min_sigma: float
"""
super(Gaussian, self).__init__()

self.mu = mu
self.sigma = sigma
self.min_sigma = min_sigma

def get_signature(self):
"""
Get the signature of the current distribution.

:returns: the constant string ``"Gaussian"``
:rtype: str
"""
return "Gaussian"

def get_metadata(self):
"""
Get the metadata of the current distribution.

:returns: a one-element list containing ``self.min_sigma``
:rtype: list
"""
return [self.min_sigma]

def num_parameters(self):
Expand Down
16 changes: 16 additions & 0 deletions src/pyjuice/nodes/distributions/masked_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@


class MaskedCategorical(Distribution):
"""
A class representing Categorical distributions with masks.
:param num_cats: number of categories
:type num_cats: int
:param mask_mode: type of mask; should be in ["range", "full_mask", "rev_range"]
:type mask_mode: str
"""

def __init__(self, num_cats: int, mask_mode: str):
super(MaskedCategorical, self).__init__()

Expand Down Expand Up @@ -39,9 +49,15 @@ def __init__(self, num_cats: int, mask_mode: str):
self.em_fn = self.em_fn_rev_range

def get_signature(self):
"""
Get the signature of the current distribution.

The signature embeds the mask mode, so instances with different
``mask_mode`` values have different signatures.

:returns: the string ``"MaskedCategorical-<mask_mode>"``
:rtype: str
"""
return f"MaskedCategorical-{self.mask_mode}"

def get_metadata(self):
"""
Get the metadata of the current distribution.

:returns: a one-element list containing ``self.num_cats``
:rtype: list
"""
return [self.num_cats]

def num_parameters(self):
Expand Down

0 comments on commit ba027f0

Please sign in to comment.