Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Nov 25, 2023
1 parent 985b443 commit bce2cc6
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 33 deletions.
5 changes: 0 additions & 5 deletions nngeometry/generator/jacobian/grads_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ def conv_backward(
return weight_bgrad


def conv1d_backward(*args, **kwargs):
"""Computes per-example gradients for nn.Conv1d layers."""
return conv_backward(*args, nd=1, **kwargs)


def conv2d_backward_using_conv(mod, x, gy):
"""Computes per-example gradients for nn.Conv2d layers."""
return conv_backward(
Expand Down
5 changes: 4 additions & 1 deletion nngeometry/generator/jacobian/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ def __init__(
self.centering = centering

if function is None:
function = lambda *x: model(x[0])

def function(*x):
return model(x[0])

Check warning on line 46 in nngeometry/generator/jacobian/jacobian.py

View check run for this annotation

Codecov / codecov/patch

nngeometry/generator/jacobian/jacobian.py#L45-L46

Added lines #L45 - L46 were not covered by tests

self.function = function

if layer_collection is None:
Expand Down
12 changes: 10 additions & 2 deletions nngeometry/object/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from .fspace import FMatDense
from .map import PullBackDense, PushForwardDense, PushForwardImplicit
from .pspace import (PMatBlockDiag, PMatDense, PMatDiag, PMatEKFAC,
PMatImplicit, PMatKFAC, PMatLowRank, PMatQuasiDiag)
from .pspace import (
PMatBlockDiag,
PMatDense,
PMatDiag,
PMatEKFAC,
PMatImplicit,
PMatKFAC,
PMatLowRank,
PMatQuasiDiag,
)
from .vector import FVector, PVector

__all__ = [
Expand Down
15 changes: 0 additions & 15 deletions tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,21 +370,6 @@ def get_fullyconnect_affine_task():
return get_fullyconnect_task(normalization="affine")


def get_conv_task(normalization="none"):
train_set = get_mnist()
train_set = Subset(train_set, range(70))
train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
net = ConvNet(normalization=normalization)
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)


def get_conv_bn_task():
return get_conv_task(normalization="batch_norm")

Expand Down
1 change: 1 addition & 0 deletions tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ def test_bn_eval_mode():
model.train()
with pytest.raises(RuntimeError):
FMat_dense = FMatDense(generator=generator, examples=loader)
FMat_dense.get_dense_tensor()

Check warning on line 836 in tests/test_jacobian.py

View check run for this annotation

Codecov / codecov/patch

tests/test_jacobian.py#L836

Added line #L836 was not covered by tests


def test_example_passing():
Expand Down
3 changes: 0 additions & 3 deletions tests/test_jacobian_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from tasks import (
get_conv_task,
get_fullyconnect_task,
Expand Down Expand Up @@ -196,8 +195,6 @@ def test_jacobian_kfac():
# Test mv
mv_direct = torch.mv(G_kfac_split, random_v.get_flat_representation())
mv_kfac = M_kfac.mv(random_v)
print(mv_direct.size(), lc.layers)
pvec = PVector(layer_collection=lc, vector_repr=mv_direct)
check_tensors(mv_direct, mv_kfac.get_flat_representation())

# Test vTMv
Expand Down
11 changes: 8 additions & 3 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import pytest
import torch
import torch.nn.functional as tF
from tasks import (device, get_conv_gn_task, get_conv_task,
get_fullyconnect_segm_task, get_fullyconnect_task,
to_device)
from tasks import (
device,
get_conv_gn_task,
get_conv_task,
get_fullyconnect_segm_task,
get_fullyconnect_task,
to_device,
)
from test_jacobian import get_output_vector, update_model

from nngeometry.metrics import FIM, FIM_MonteCarlo
Expand Down
9 changes: 7 additions & 2 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
from utils import check_tensors

from nngeometry.generator import Jacobian
from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag,
PMatLowRank, PMatQuasiDiag)
from nngeometry.object.pspace import (
PMatBlockDiag,
PMatDense,
PMatDiag,
PMatLowRank,
PMatQuasiDiag,
)
from nngeometry.object.vector import PVector


Expand Down
3 changes: 1 addition & 2 deletions tests/test_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from utils import check_ratio, check_tensors

from nngeometry.layercollection import LayerCollection
from nngeometry.object.vector import (PVector, random_pvector,
random_pvector_dict)
from nngeometry.object.vector import PVector, random_pvector, random_pvector_dict


class ConvNet(nn.Module):
Expand Down

0 comments on commit bce2cc6

Please sign in to comment.