Skip to content

Commit

Permalink
Patch torch.nn.Paeameter
Browse files Browse the repository at this point in the history
Closes #6
  • Loading branch information
xl0 committed Jul 8, 2024
1 parent bbc80f5 commit 5d15bee
Show file tree
Hide file tree
Showing 4 changed files with 5,928 additions and 19 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ named_numbers = numbers.rename("C", "H","W")
named_numbers
```

/home/xl0/mambaforge/envs/lovely-py31-torch25/lib/python3.10/site-packages/torch/_tensor.py:1420: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1925.)
return super().rename(names)

tensor[C=3, H=196, W=196] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073

## Going `.deeper`
Expand Down Expand Up @@ -217,7 +220,7 @@ numbers.rgb
*Maaaaybe?* Looks like someone normalized him.

``` python
in_stats = ( (0.485, 0.456, 0.406), # mean
in_stats = ( (0.485, 0.456, 0.406), # mean
(0.229, 0.224, 0.225) ) # std

# numbers.rgb(in_stats, cl=True) # For channel-last input format
Expand Down Expand Up @@ -321,7 +324,6 @@ eight_images.rgb
features[3].weight
```

Parameter containing:
Parameter[128, 64, 3, 3] n=73728 (0.3Mb) x∈[-0.783, 0.776] μ=-0.004 σ=0.065 grad

I want +/- 2σ to fall in the range \[-1..1\]
Expand Down
10 changes: 7 additions & 3 deletions lovely_tensors/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# %% ../nbs/10_patch.ipynb 4
def monkey_patch(cls=torch.Tensor):
"Monkey-patch lovely features into `cls`"
"Monkey-patch lovely features into `cls`"

if not hasattr(cls, '_plain_repr'):
if cls is torch.Tensor:
Expand All @@ -29,7 +29,7 @@ def monkey_patch(cls=torch.Tensor):
cls._plain_str = cls.__str__

@patch_to(cls)
def __repr__(self: torch.Tensor, *, tensor_contents=None):
def __repr__(self: torch.Tensor, *, tensor_contents=None):
return str(StrProxy(self))

# Plain - the old behavior
Expand All @@ -51,7 +51,7 @@ def deeper(self: torch.Tensor):
@patch_to(cls, as_prop=True)
def rgb(t: torch.Tensor):
return RGBProxy(t)

# .chans and .chans(...)
@patch_to(cls, as_prop=True)
def chans(t: torch.Tensor):
Expand All @@ -61,3 +61,7 @@ def chans(t: torch.Tensor):
@patch_to(cls, as_prop=True)
def plt(t: torch.Tensor):
return PlotProxy(t)

# The base class repr handler nn.Parameter better.
if "__repr__" in torch.nn.Parameter.__dict__:
del torch.nn.Parameter.__repr__
5,916 changes: 5,906 additions & 10 deletions nbs/10_patch.ipynb

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions nbs/index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
}
],
"source": [
"numbers "
"numbers"
]
},
{
Expand Down Expand Up @@ -363,6 +363,14 @@
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/xl0/mambaforge/envs/lovely-py31-torch25/lib/python3.10/site-packages/torch/_tensor.py:1420: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1925.)\n",
" return super().rename(names)\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -525,7 +533,7 @@
}
],
"source": [
"in_stats = ( (0.485, 0.456, 0.406), # mean \n",
"in_stats = ( (0.485, 0.456, 0.406), # mean\n",
" (0.229, 0.224, 0.225) ) # std\n",
"\n",
"# numbers.rgb(in_stats, cl=True) # For channel-last input format\n",
Expand Down Expand Up @@ -5846,7 +5854,6 @@
{
"data": {
"text/plain": [
"Parameter containing:\n",
"Parameter[128, 64, 3, 3] n=73728 (0.3Mb) x∈[-0.783, 0.776] μ=-0.004 σ=0.065 grad"
]
},
Expand Down Expand Up @@ -10408,7 +10415,7 @@
"rcParams[\"svg.hashsalt\"] = \"1\"\n",
"\n",
"# No date, don't include matplotlib version\n",
"kwargs = {'metadata': {\"Date\": None, \"Creator\": \"Matplotlib, https://matplotlib.org/\" }} \n",
"kwargs = {'metadata': {\"Date\": None, \"Creator\": \"Matplotlib, https://matplotlib.org/\" }}\n",
"%config InlineBackend.print_figure_kwargs = kwargs"
]
},
Expand Down

0 comments on commit 5d15bee

Please sign in to comment.