Merge pull request #128 from JuliaGNI/volume_preserving_feedforward
Volume preserving feedforward neural network (vpffnn)
Showing 25 changed files with 1,654 additions and 28 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# `SymmetricMatrix` and `SkewSymMatrix` | ||
|
||
There are special implementations of symmetric and skew-symmetric matrices in `GeometricMachineLearning.jl`. They are implemented to work on GPU and for multiplication with tensors. The following image demonstrates how the data necessary for an instance of `SkewSymMatrix` are stored[^1]: | ||
|
||
[^1]: It works similarly for `SymmetricMatrix`. | ||
|
||
```@example | ||
import Images, Plots # hide | ||
if Main.output_type == :html # hide | ||
HTML("""<object type="image/svg+xml" class="display-light-only" data=$(joinpath(Main.buildpath, "../tikz/skew_sym_visualization.png"))></object>""") # hide | ||
else # hide | ||
Plots.plot(Images.load("../tikz/skew_sym_visualization.png"), axis=([], false)) # hide | ||
end # hide | ||
``` | ||
|
||
```@example | ||
if Main.output_type == :html # hide | ||
HTML("""<object type="image/svg+xml" class="display-dark-only" data=$(joinpath(Main.buildpath, "../tikz/skew_sym_visualization_dark.png"))></object>""") # hide | ||
end # hide | ||
``` | ||
|
||
Internally, a vector of size ``n(n-1)/2`` is stored for the skew-symmetric matrix and a vector of size ``n(n+1)/2`` for the symmetric matrix. We can sample a random skew-symmetric matrix: | ||
|
||
```@example skew_sym | ||
using GeometricMachineLearning # hide | ||
A = rand(SkewSymMatrix, 5) | ||
``` | ||
|
||
and then access the vector: | ||
|
||
```@example skew_sym | ||
A.S | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Volume-Preserving Feedforward Layer | ||
|
||
**Volume-preserving feedforward layers** are a special type of ResNet layer for which we restrict the weight matrices to be of a particular form, i.e. each layer computes: | ||
|
||
```math | ||
x \mapsto x + \sigma(Ax + b), | ||
``` | ||
where ``\sigma`` is a nonlinearity, ``A`` is the weight and ``b`` is the bias. The matrix ``A`` is either a lower-triangular matrix ``L`` or an upper-triangular matrix ``U``[^1]. The lower-triangular matrix has zeros on its diagonal and is of the form (the upper-triangular matrix is simply the transpose of the lower-triangular one): | ||
|
||
[^1]: Implemented as `LowerTriangular` and `UpperTriangular` in `GeometricMachineLearning`. | ||
|
||
```math | ||
L = \begin{pmatrix} | ||
0 & 0 & \cdots & 0 \\ | ||
a_{21} & \ddots & & \vdots \\ | ||
\vdots & \ddots & \ddots & \vdots \\ | ||
a_{n1} & \cdots & a_{n(n-1)} & 0 | ||
\end{pmatrix}. | ||
``` | ||
|
||
The Jacobian of such a layer, ``J = \mathbb{I} + \mathrm{diag}(\sigma'(Lx + b))L``, is then of the form | ||
|
||
```math | ||
J = \begin{pmatrix} | ||
1 & 0 & \cdots & 0 \\ | ||
b_{21} & \ddots & & \vdots \\ | ||
\vdots & \ddots & \ddots & \vdots \\ | ||
b_{n1} & \cdots & b_{n(n-1)} & 1 | ||
\end{pmatrix}, | ||
``` | ||
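where ``b_{ij} = \sigma'((Lx + b)_i)a_{ij}`` (not to be confused with the bias ``b``). As ``J`` is lower-triangular with ones on the diagonal, its determinant is 1, i.e. the map is volume-preserving. | ||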
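The unit-determinant property can be checked numerically. The following is a minimal sketch in plain Julia using only the `LinearAlgebra` standard library. The weight `L`, bias `b` and input `x` are arbitrary illustrative values, and `tanh` is assumed as the nonlinearity; none of this uses the `GeometricMachineLearning` API.

```julia
using LinearAlgebra

n = 4
# strictly lower-triangular weight: zeros on and above the diagonal
L = tril(reshape(collect(1.0:n^2), n, n) ./ n^2, -1)
b = fill(0.1, n)
x = [0.3, -0.2, 0.5, 0.1]

# derivative of the assumed nonlinearity tanh
dtanh(z) = 1 - tanh(z)^2

# Jacobian of x ↦ x + tanh.(L * x + b) obtained from the chain rule
J = I + Diagonal(dtanh.(L * x + b)) * L

det(J)  # equals 1 up to floating-point rounding
```

Multiplying `L` from the left by a diagonal matrix only rescales its rows, so the product stays strictly lower-triangular and the diagonal of `J` consists of ones.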
|
||
## Neural network architecture | ||
|
||
Volume-preserving feedforward neural networks should be used as `Architecture`s in `GeometricMachineLearning`. The constructor for them is: | ||
|
||
```@eval | ||
using GeometricMachineLearning, Markdown | ||
Markdown.parse(description(Val(:VPFconstructor))) | ||
``` | ||
|
||
The constructor produces the following architecture[^2]: | ||
|
||
[^2]: Based on the input arguments `n_linear` and `n_blocks`. In this example `init_upper` is set to false, which means that the first layer is of type *lower* followed by a layer of type *upper*. | ||
|
||
```@example | ||
import Images, Plots # hide | ||
if Main.output_type == :html # hide | ||
HTML("""<object type="image/svg+xml" class="display-light-only" data=$(joinpath(Main.buildpath, "../tikz/vp_feedforward.png"))></object>""") # hide | ||
else # hide | ||
Plots.plot(Images.load("../tikz/vp_feedforward.png"), axis=([], false)) # hide | ||
end # hide | ||
``` | ||
|
||
```@example | ||
if Main.output_type == :html # hide | ||
HTML("""<object type="image/svg+xml" class="display-dark-only" data=$(joinpath(Main.buildpath, "../tikz/vp_feedforward_dark.png"))></object>""") # hide | ||
end # hide | ||
``` | ||
|
||
Here *LinearLowerLayer* performs ``x \mapsto x + Lx`` and *NonLinearLowerLayer* performs ``x \mapsto x + \sigma(Lx + b)``. The activation function ``\sigma`` is the fourth input argument to the constructor and `tanh` by default. | ||
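The layer maps above can be sketched in plain Julia. The functions `linear_layer`, `nonlinear_layer`, `block` and `fd_jacobian` are hypothetical stand-ins introduced for illustration and are not the `GeometricMachineLearning` API. The sketch assumes that the *Bias* layer performs ``x \mapsto x + b`` and, for brevity, reuses the same weights and bias in every layer. It passes an input once through the layer types in the order shown in the figure and estimates the Jacobian determinant with central finite differences.

```julia
using LinearAlgebra

# Hypothetical stand-ins for the layers named in the figure
linear_layer(A, x) = x + A * x                          # x ↦ x + Ax
nonlinear_layer(A, b, x; σ = tanh) = x + σ.(A * x + b)  # x ↦ x + σ(Ax + b)

n = 3
L = tril(fill(0.5, n, n), -1)    # strictly lower-triangular weight
U = triu(fill(-0.25, n, n), 1)   # strictly upper-triangular weight
b = fill(0.1, n)

# one pass through the layer types, lower first (`init_upper = false`)
function block(x)
    x = linear_layer(L, x)         # LinearLowerLayer
    x = linear_layer(U, x)         # LinearUpperLayer
    x = x + b                      # Bias (assumed to be a plain translation)
    x = nonlinear_layer(L, b, x)   # NonLinearLowerLayer
    x = nonlinear_layer(U, b, x)   # NonLinearUpperLayer
    return x
end

# central finite differences to approximate the Jacobian of f at x
function fd_jacobian(f, x; h = 1e-6)
    m = length(x)
    J = zeros(m, m)
    for j in 1:m
        e = zeros(m)
        e[j] = h
        J[:, j] = (f(x + e) - f(x - e)) / (2h)
    end
    return J
end

x0 = [0.2, -0.4, 0.7]
J = fd_jacobian(block, x0)
det(J)  # close to 1, up to finite-difference error
```

Each individual layer has a Jacobian with determinant 1, so the determinant of the composition is a product of ones, which is what the finite-difference estimate reflects.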
|
||
## Note on SympNets | ||
|
||
As [SympNets](../architectures/sympnet.md) are symplectic maps, they also conserve phase space volume and therefore form a subcategory of volume-preserving feedforward layers. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
\documentclass[tikz]{standalone} | ||
|
||
\usepackage{xcolor} | ||
\definecolor{morange}{RGB}{255,127,14} | ||
\definecolor{mblue}{RGB}{31,119,180} | ||
\definecolor{mred}{RGB}{214,39,40} | ||
\definecolor{mpurple}{RGB}{148,103,189} | ||
\definecolor{mgreen}{RGB}{44,160,44} | ||
|
||
\usepackage{mathtools} | ||
\usepackage{amssymb} | ||
\usepackage{nicematrix} | ||
|
||
\usepackage{tikz} | ||
\usetikzlibrary{fit,shapes.geometric} | ||
\tikzset{highlight/.style={rectangle, draw=mblue, semithick, inner sep=1pt}} | ||
|
||
\begin{document} | ||
|
||
\begin{tikzpicture} | ||
|
||
\node (matrix) {$ | ||
\begin{pNiceMatrix}[name=A] | ||
0 & -a_{21} & \Cdots & -a_{n1} \\ | ||
a_{21} & \Ddots & & \Vdots \\ | ||
\Vdots & \Ddots & \Ddots & \Vdots \\ | ||
a_{n1} & \Cdots & a_{n(n-1)} & 0 | ||
\CodeAfter | ||
\tikz \node [highlight, fit=(2-1)(2-1), mblue] {}; | ||
%\tikz \node [highlight, fit=(2-1) (2-1), mpurple] {}; | ||
\tikz \node [highlight, fit=(4-1)(4-3), mpurple] {}; | ||
\end{pNiceMatrix} | ||
$}; | ||
|
||
\node[below of=matrix, xshift=5cm, yshift=-.5cm] (vector) {$ | ||
\begin{pNiceMatrix} | ||
a_{21} \\ | ||
a_{31} \\ | ||
a_{32} \\ | ||
a_{41} \\ | ||
a_{42} \\ | ||
a_{43} \\ | ||
a_{51} \\ | ||
\vdots \\ | ||
a_{54} \\ | ||
\vdots \\ | ||
a_{n1} \\ | ||
\vdots \\ | ||
a_{n(n-1)} | ||
\CodeAfter | ||
\tikz \node [highlight, fit=(1-1)(1-1), mblue] {}; | ||
\tikz \node [highlight, fit=(2-1)(3-1), morange] {}; | ||
\tikz \node [highlight, fit=(4-1)(6-1), mgreen] {}; | ||
\tikz \node [highlight, fit=(7-1)(9-1), mred] {}; | ||
\tikz \node [highlight, fit=(11-1)(13-1), mpurple] {}; | ||
\end{pNiceMatrix} | ||
$}; | ||
|
||
\draw[-stealth, thick, rounded corners, mred] (matrix) -- (vector); | ||
|
||
\end{tikzpicture} | ||
\end{document} | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
\documentclass[tikz]{standalone} | ||
|
||
\usepackage{xcolor} | ||
\definecolor{morange}{RGB}{255,127,14} | ||
\definecolor{mblue}{RGB}{31,119,180} | ||
\definecolor{mred}{RGB}{214,39,40} | ||
\definecolor{mpurple}{RGB}{148,103,189} | ||
\definecolor{mgreen}{RGB}{44,160,44} | ||
|
||
\usepackage{mathtools} | ||
\usepackage{amssymb} | ||
\usepackage{nicematrix} | ||
|
||
\usepackage{tikz} | ||
\usetikzlibrary{fit,shapes.geometric} | ||
\tikzset{highlight/.style={rectangle, draw=mblue, semithick, inner sep=1pt}} | ||
|
||
\begin{document} | ||
|
||
\begin{tikzpicture} | ||
|
||
\node[color=white] (matrix) {$ | ||
\begin{pNiceMatrix}[name=A] | ||
0 & -a_{21} & \Cdots & -a_{n1} \\ | ||
a_{21} & \Ddots & & \Vdots \\ | ||
\Vdots & \Ddots & \Ddots & \Vdots \\ | ||
a_{n1} & \Cdots & a_{n(n-1)} & 0 | ||
\CodeAfter | ||
\tikz \node [highlight, fit=(2-1)(2-1), mblue] {}; | ||
%\tikz \node [highlight, fit=(2-1) (2-1), mpurple] {}; | ||
\tikz \node [highlight, fit=(4-1)(4-3), mpurple] {}; | ||
\end{pNiceMatrix} | ||
$}; | ||
|
||
\node[below of=matrix, xshift=5cm, yshift=-.5cm, color=white] (vector) {$ | ||
\begin{pNiceMatrix} | ||
a_{21} \\ | ||
a_{31} \\ | ||
a_{32} \\ | ||
a_{41} \\ | ||
a_{42} \\ | ||
a_{43} \\ | ||
a_{51} \\ | ||
\vdots \\ | ||
a_{54} \\ | ||
\vdots \\ | ||
a_{n1} \\ | ||
\vdots \\ | ||
a_{n(n-1)} | ||
\CodeAfter | ||
\tikz \node [highlight, fit=(1-1)(1-1), mblue] {}; | ||
\tikz \node [highlight, fit=(2-1)(3-1), morange] {}; | ||
\tikz \node [highlight, fit=(4-1)(6-1), mgreen] {}; | ||
\tikz \node [highlight, fit=(7-1)(9-1), mred] {}; | ||
\tikz \node [highlight, fit=(11-1)(13-1), mpurple] {}; | ||
\end{pNiceMatrix} | ||
$}; | ||
|
||
\draw[-stealth, thick, rounded corners, mred] (matrix) -- (vector); | ||
|
||
\end{tikzpicture} | ||
\end{document} | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
\documentclass[crop, tikz]{standalone} | ||
|
||
\usepackage{tikz} | ||
\usepackage{amsmath} | ||
\usepackage{amssymb} | ||
\usepackage[mode=buildnew]{standalone} | ||
|
||
\usepackage{xcolor} | ||
|
||
|
||
\usetikzlibrary{positioning} | ||
\usetikzlibrary{calc} | ||
\usetikzlibrary{fit} | ||
%\usepackage{nicematrix} | ||
|
||
\tikzset{set/.style={draw,circle,inner sep=0pt,align=center}} | ||
|
||
\definecolor{morange}{RGB}{255,127,14} | ||
\definecolor{mblue}{RGB}{31,119,180} | ||
\definecolor{mred}{RGB}{214,39,40} | ||
\definecolor{mpurple}{RGB}{148,103,189} | ||
\definecolor{mgreen}{RGB}{44,160,44} | ||
|
||
|
||
\begin{document} | ||
\begin{tikzpicture}[module/.style={draw, very thick, rounded corners, minimum width=15ex}, | ||
linup/.style={module, fill=morange!30}, | ||
linlow/.style={module, fill=mblue!30}, | ||
bias/.style={module, fill=mpurple!30}, | ||
nonlinup/.style={module, fill=mgreen!30}, | ||
nonlinlow/.style={module, fill=mred!30}, | ||
arrow/.style={-stealth, thick, rounded corners}, | ||
] | ||
\node (input) {Input}; | ||
\node[above of=input, linlow, align=center, yshift=.6cm] (linlow1) {LinearLowerLayer}; | ||
\node[above of=linlow1, linup, align=center] (linup1) {LinearUpperLayer}; | ||
\node[above of=linup1, bias, align=center] (bias1) {Bias}; | ||
\node[above of=bias1, nonlinlow, align=center] (nonlinlow) {NonLinearLowerLayer}; | ||
\node[above of=nonlinlow, nonlinup, align=center] (nonlinup) {NonLinearUpperLayer}; | ||
\node[above of=nonlinup, linlow, align=center] (linlow2) {LinearLowerLayer}; | ||
\node[above of=linlow2, linup, align=center] (linup2) {LinearUpperLayer}; | ||
\node[above of=linup2, bias, align=center] (bias2) {Bias}; | ||
\node[above of=bias2] (output) {Output}; | ||
|
||
\draw[arrow] (input) -- (linlow1); | ||
\draw[arrow] (linlow1) -- (linup1); | ||
\draw[arrow] (linup1) -- (bias1); | ||
\draw[arrow] (bias1) -- (nonlinlow); | ||
\draw[arrow] (nonlinlow) -- (nonlinup); | ||
\draw[arrow] (nonlinup) -- (linlow2); | ||
\draw[arrow] (linlow2) -- (linup2); | ||
\draw[arrow] (linup2) -- (bias2); | ||
\draw[arrow] (bias2) -- (output); | ||
|
||
\coordinate (linear_layer_center) at ($(linlow1.north)!0.5!(linup1.south)$); | ||
\coordinate (before_linear_lower_layer) at ($(input.north)!0.7!(linlow1.south)$); | ||
\node[left of=linear_layer_center, xshift=-1cm] (first_label) {}; | ||
|
||
\node[fit=(before_linear_lower_layer)(linlow1)(linup1),draw, ultra thick, rounded corners, label={[rotate=90, yshift=.5em, xshift=3em]left:$\mathtt{n\_linear}\times$}] (linear) {}; | ||
\node[fit=(linear)(bias1)(nonlinlow)(nonlinup)(first_label), draw, ultra thick, rounded corners, label={[rotate=90, yshift=.5em, xshift=3em]left:$\mathtt{n\_blocks}\times$}] (block) {}; | ||
\end{tikzpicture} | ||
\end{document} |