Add optimization algorithms and stabilize the backward pass #12

Merged · 53 commits · Apr 19, 2024
Commits
edb67ec
`MPE` and `GeneralLL` for forward pass (block sparse kernels)
liuanji Mar 9, 2024
851bd64
`MPE` and `GeneralLL` for forward pass (sparse kernels)
liuanji Mar 9, 2024
5f48d12
`MPE` and `GeneralLL` for forward pass (pytorch kernels)
liuanji Mar 9, 2024
4d6674d
`MPE` and `GeneralLL` for backward pass (block sparse kernels)
liuanji Mar 9, 2024
5da7471
`MPE` and `GeneralLL` backward pass for sparse kernels + triton bug f…
liuanji Mar 10, 2024
ab27e0e
fix arg passing
liuanji Mar 10, 2024
c2bd055
fix visualization functions
liuanji Mar 10, 2024
8b49e74
improve numerical stability of hclt correctness tests
liuanji Mar 10, 2024
b925bdc
setup pytest
liuanji Mar 10, 2024
05b9160
stabilize runtests
liuanji Mar 10, 2024
8005977
fix typo in HMM
liuanji Mar 10, 2024
ee0596b
limit tile size for `MPE` propagation method to avoid kernel stall
liuanji Mar 10, 2024
a419fdb
typo
liuanji Mar 10, 2024
3fd1c74
runtests for the viterbi algorithm and the generalized EM algorithm
liuanji Mar 10, 2024
e783187
fast runtest for `GeneralLL` with HMMs
liuanji Mar 10, 2024
d78cbf1
add fast runtest for `MPE` propagation method + fix tile size allocat…
liuanji Mar 10, 2024
c7caf3f
speedup runtests
liuanji Mar 10, 2024
9bb3bc9
update optim tests
liuanji Mar 11, 2024
b3ac66a
use `bfloat16` in forward pass instead of `float16`
liuanji Mar 12, 2024
3c763b1
use `bfloat16` instead of `float16` in forward pass
liuanji Mar 12, 2024
16c41e0
fix general EM parameter update kernels
liuanji Mar 12, 2024
ae66bc6
temp commit
liuanji Mar 12, 2024
6299bdc
fix runtests except `hclt_correctness_test`
liuanji Mar 12, 2024
2787361
stabilize hclt correctness test
liuanji Mar 12, 2024
80ef1a2
improve sum layer ele backward pass (compute on arithmetic space)
liuanji Mar 12, 2024
42c80d5
ensure nodes in a layer are different
liuanji Mar 13, 2024
4741c3f
fix layering when adding new nodes
liuanji Mar 13, 2024
4de4052
more runtests
liuanji Mar 13, 2024
d92f342
fix runtests
liuanji Mar 13, 2024
5b34c30
improve numerical stability of forward pass
liuanji Mar 13, 2024
9f13826
temporarily fix conditional query
liuanji Mar 14, 2024
5b58875
change eps to 1e-24
liuanji Mar 14, 2024
ab66287
allows deciding propagation alg per layer
liuanji Mar 14, 2024
e298620
add `num_ch_nodes` for sum nodes
liuanji Mar 14, 2024
c61463d
add log-space backward option for product layers
liuanji Mar 15, 2024
c5d1ff8
log-space backward for block-sparse layers
liuanji Mar 15, 2024
917843d
logspace backward for sparse layers + fixes
liuanji Mar 15, 2024
59b89bc
logspace backward with pytorch kernels
liuanji Mar 15, 2024
e0e6f9b
update cudagraph signature
liuanji Mar 17, 2024
0db6431
support logspace backward for input layers
liuanji Mar 17, 2024
6de9045
receive `logspace_flows` in `ProdLayer`
liuanji Mar 17, 2024
d1e7c02
receive input `logspace_flows` for `SumLayer`
liuanji Mar 17, 2024
7bd6f74
add runtests for logspace flows
liuanji Mar 17, 2024
83694a6
homogeneous PD
liuanji Mar 17, 2024
fd41521
avoid nans in backward pass for zero-flow inner nodes
liuanji Mar 17, 2024
e483917
remove num_vars assertion in forward pass
liuanji Mar 22, 2024
d6ab521
add `len(ns)` function
liuanji Mar 22, 2024
77fa53f
fix sum node construction when input `edge_ids` is a list
liuanji Mar 22, 2024
ff7d05e
provide `max_block_size` option for `deepcopy`
liuanji Mar 22, 2024
4e7180e
SGD update function
liuanji Mar 22, 2024
7c03d18
add option to accumulate negative parameter flows in the backward pass
liuanji Mar 22, 2024
ca8b641
fix typo in SGD update kernel
liuanji Mar 30, 2024
5373ddc
change kernel allocation for sum layers
liuanji Mar 30, 2024
14 changes: 13 additions & 1 deletion pyproject.toml
@@ -21,14 +21,26 @@ authors = [
{name="StarAI", email="[email protected]"},
]

[project.optional-dependencies]
dev = [
"pytest",
"pytest-xdist",
"pytest-skip-slow",
"torchvision",
"torchtext",
"matplotlib"
]

[options.packages.find]
where = "src"

[tool.setuptools.dynamic]
readme = {file = "README.md"}


[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
]
testpaths = [
"tests/"
]
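With the new `dev` extra and pytest configuration above, a standard workflow (assumed here, not spelled out in the PR itself) would be `pip install -e ".[dev]"` to pull in the test dependencies, then `pytest` from the repository root; the `testpaths` setting points the collector at `tests/`, and `pytest-xdist` additionally allows parallel runs via `pytest -n auto`.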
13 changes: 10 additions & 3 deletions src/pyjuice/layer/input_layer.py
@@ -30,6 +30,8 @@ def __init__(self, nodes: Sequence[InputNodes], cum_nodes: int = 0, pc_num_vars:
in gradient accumulation.
"""

assert len(nodes) == len(set(nodes)), "Input node list contains duplicates."

nn.Module.__init__(self)
Layer.__init__(self, nodes, disable_block_size_check = True)

@@ -213,7 +215,7 @@ def init_param_flows(self, flows_memory: float = 1.0):

def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[Dict] = None,
missing_mask: Optional[torch.Tensor] = None, _batch_first: bool = True,
_apply_missing_mask_only: bool = False):
_apply_missing_mask_only: bool = False, **kwargs):
self._used_external_params = (params is not None)

if params is None:
@@ -300,7 +302,8 @@ def forward(self, data: torch.Tensor, node_mars: torch.Tensor, params: Optional[
raise NotImplementedError("CPU forward fn for input nodes is not implemented.")

def backward(self, data: torch.Tensor, node_flows: torch.Tensor,
node_mars: torch.Tensor, params: Optional[Dict] = None):
node_mars: torch.Tensor, params: Optional[Dict] = None,
logspace_flows: bool = False, **kwargs):
"""
data: [num_vars, B]
node_flows: [num_nodes, B]
@@ -355,6 +358,7 @@ def backward(self, data: torch.Tensor, node_flows: torch.Tensor,
node_offset = node_offset,
BLOCK_SIZE = BLOCK_SIZE,
partial_eval = 1 if bk_local_ids is not None else 0,
logspace_flows = logspace_flows,
num_warps = 8
)

@@ -681,7 +685,7 @@ def _fw_missing_mask_kernel(missing_mask_ptr, node_mars_ptr, vids_ptr, fw_local_

@staticmethod
def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr, node_mars_ptr, data_ptr, vids_ptr, s_pids_ptr, s_pfids_ptr,
metadata_ptr, s_mids_ptr, bk_local_ids_ptr, partial_eval: tl.constexpr, layer_num_nodes: tl.constexpr,
metadata_ptr, s_mids_ptr, bk_local_ids_ptr, partial_eval: tl.constexpr, logspace_flows: tl.constexpr, layer_num_nodes: tl.constexpr,
batch_size: tl.constexpr, num_vars_per_node: tl.constexpr, nv_block_size: tl.constexpr, node_offset: tl.constexpr,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis = 0)
Expand Down Expand Up @@ -720,6 +724,9 @@ def _flows_kernel_template(flow_fn, params_ptr, param_flows_ptr, node_flows_ptr,
ns_offsets = (local_offsets + node_offset) * batch_size + batch_offsets
flows = tl.load(node_flows_ptr + ns_offsets, mask = mask, other = 0)

if logspace_flows:
flows = tl.exp(flows)

flow_fn(local_offsets, ns_offsets, data, flows, node_mars_ptr, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_ptr,
s_mids_ptr, mask, num_vars_per_node, BLOCK_SIZE)

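For context on the new `logspace_flows` flag: when flows are carried in log space for numerical stability, the input-layer kernel must exponentiate them back to linear space before accumulating parameter flows, which is exactly what the `flows = tl.exp(flows)` branch above does. A minimal PyTorch sketch of the idea (tensor shapes and function names here are hypothetical, not pyjuice's actual API):

```python
import torch

def accumulate_param_flows(node_flows: torch.Tensor,
                           weights: torch.Tensor,
                           logspace_flows: bool = False) -> torch.Tensor:
    # node_flows, weights: [num_nodes, batch_size] (hypothetical shapes)
    if logspace_flows:
        # Flows were propagated in log space to avoid underflow; convert
        # back to linear space before the weighted accumulation, mirroring
        # the `flows = tl.exp(flows)` branch in the kernel above.
        node_flows = torch.exp(node_flows)
    return (node_flows * weights).sum(dim=1)

# Flows around exp(-60) underflow to zero if carried linearly in low
# precision, but survive as -60.0 in log space.
log_flows = torch.full((4, 8), -60.0)
param_flows = accumulate_param_flows(log_flows, torch.rand(4, 8), logspace_flows=True)
```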
17 changes: 17 additions & 0 deletions src/pyjuice/layer/layer.py
@@ -7,6 +7,13 @@


class Layer():

propagation_alg_mapping = {
"LL": 0,
"MPE": 1,
"GeneralLL": 2
}

def __init__(self, nodes: Sequence[CircuitNodes], disable_block_size_check: bool = False) -> None:

if disable_block_size_check:
@@ -60,3 +67,13 @@ def disable_partial_evaluation(self, forward: bool = True, backward: bool = True

def provided(self, var_name):
return hasattr(self, var_name) and getattr(self, var_name) is not None

def _get_propagation_alg_kwargs(self, propagation_alg: str, **kwargs):
if propagation_alg == "LL":
return {"alpha": 0.0}
elif propagation_alg == "MPE":
return {"alpha": 0.0}
elif propagation_alg == "GeneralLL":
return {"alpha": kwargs["alpha"]}
else:
raise ValueError(f"Unknown propagation algorithm {propagation_alg}.")
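A hedged sketch of how a concrete layer might consume this dispatch (the subclass and kernel-launch call below are illustrative stand-ins, not pyjuice's actual internals): `propagation_alg_mapping` turns the algorithm name into an integer flag for the kernels, and `_get_propagation_alg_kwargs` supplies `alpha`, which is only meaningful for `"GeneralLL"` since `"LL"` and `"MPE"` pin it to 0.0.

```python
class ExampleSumLayer(Layer):
    # Illustrative subclass: `_launch_forward_kernel` is a hypothetical
    # stand-in for the Triton kernel launches in pyjuice's real layers.
    def forward(self, node_mars, propagation_alg: str = "LL", **kwargs):
        # Map the algorithm name to the integer flag the kernels receive.
        alg_id = self.propagation_alg_mapping[propagation_alg]  # 0=LL, 1=MPE, 2=GeneralLL
        # Validate and extract algorithm-specific kwargs (just `alpha` here).
        alg_kwargs = self._get_propagation_alg_kwargs(propagation_alg, **kwargs)
        self._launch_forward_kernel(node_mars, alg_id=alg_id, alpha=alg_kwargs["alpha"])
```

This matches the per-layer flexibility added in commit ab66287 ("allows deciding propagation alg per layer"): each layer's forward pass can receive its own `propagation_alg` string rather than a single global setting.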