SVRG and LSVRG approximate gradient methods #1625

Merged
merged 233 commits into from
Aug 19, 2024
Changes from all commits
233 commits
8389932
First attempt at sampling class
MargaretDuff Aug 2, 2023
7331c73
Changed how probabilities and samplers interact in SPDHG
MargaretDuff Aug 2, 2023
343509c
Initial playing
MargaretDuff Aug 7, 2023
40ac9c0
Ready to start some basic testing
MargaretDuff Aug 8, 2023
34fb1d5
Started to debug
MargaretDuff Aug 8, 2023
6ef169a
Testing SGD
MargaretDuff Aug 8, 2023
2bce666
Update sampling.py
MargaretDuff Aug 8, 2023
ea759c5
Changed to factory method style and added in permutations
MargaretDuff Aug 9, 2023
d1909a3
Debugging and fixing random generator in show epochs
MargaretDuff Aug 9, 2023
98b0694
Testing SPDHG
MargaretDuff Aug 9, 2023
05b67cb
Changed the show epochs
MargaretDuff Aug 10, 2023
001350b
Meeting with Vaggelis, Jakob, Gemma and Edo
MargaretDuff Aug 11, 2023
890dec0
Set up for installation
MargaretDuff Aug 14, 2023
25806fc
Added staggered and custom order and started with writing documentation
MargaretDuff Aug 14, 2023
75abbfe
Work on documentation
MargaretDuff Aug 14, 2023
ebdf329
Commenting and examples in the class
MargaretDuff Aug 15, 2023
ba35fb8
Debugging sampler
MargaretDuff Aug 15, 2023
ff5cdf1
sorted build and imports
MargaretDuff Aug 15, 2023
f62f064
Changes to todo
MargaretDuff Aug 16, 2023
beac6fa
Changes after dev meeting
MargaretDuff Aug 17, 2023
1202e53
Checking probabilities in init
MargaretDuff Aug 18, 2023
079935b
initial testing
MargaretDuff Aug 23, 2023
43e3dc4
Sped up PDHG and SPDHG testing
MargaretDuff Aug 24, 2023
004ab2f
Removed timing statements
MargaretDuff Aug 24, 2023
7b857e0
Got rid of epochs - still need to fix the shuffle
MargaretDuff Sep 13, 2023
1f7d546
Fixed random without replacement shuffle=False
MargaretDuff Sep 14, 2023
6993a95
Changes after meeting 12-09-2023. Remove epochs in sampler and deprec…
MargaretDuff Sep 14, 2023
bafc748
Sampler unit tests added
MargaretDuff Sep 19, 2023
d62aa2b
Some checks for setting step sizes
MargaretDuff Sep 19, 2023
c81b71c
Started looking at unit tests and debugging SPDHG setters and init
MargaretDuff Sep 21, 2023
b28f2f1
Notes after discussions with gemma
MargaretDuff Sep 22, 2023
4a87f48
Changes after discussion with gemma
MargaretDuff Sep 25, 2023
b35222f
Updated tests
MargaretDuff Sep 25, 2023
6e552af
Just a commenting change
MargaretDuff Sep 25, 2023
4ae9b3c
Tiny changes
MargaretDuff Sep 28, 2023
69c1e1a
Merge branch 'master' of github.com:TomographicImaging/CIL into SGD
MargaretDuff Sep 28, 2023
6575af6
Initial changes and tests- currently failing tests
MargaretDuff Sep 28, 2023
6b463bc
Sorted tests and checks on the set_norms function
MargaretDuff Oct 2, 2023
215bfa6
Changed a comment
MargaretDuff Oct 2, 2023
3898a03
Removed reference to dask
MargaretDuff Oct 4, 2023
b946d79
Bug fixes
MargaretDuff Oct 5, 2023
96e4730
Changes based on Gemma's review
MargaretDuff Oct 5, 2023
3c36f3f
Small changes
MargaretDuff Oct 9, 2023
1ca3a2b
Comments from Edo fixed
MargaretDuff Oct 9, 2023
4b541e7
Merge branch 'master' into blockoperator-norms
MargaretDuff Oct 9, 2023
9a04de4
Added stuff to gitignore
MargaretDuff Oct 9, 2023
5a302c8
Fixed tests
MargaretDuff Oct 9, 2023
0bffa24
Added a note to the documentation about which sampler to use
MargaretDuff Oct 11, 2023
8416837
Option for list or blockfunction
MargaretDuff Oct 11, 2023
37565fc
Fixed the bugs of the previous commit
MargaretDuff Oct 11, 2023
18647af
Merge branch 'master' of github.com:MargaretDuff/CIL-margaret into st…
MargaretDuff Oct 11, 2023
222c377
Moved the sampler to the algorithms folder
MargaretDuff Oct 12, 2023
1d70eb3
Updated tests
MargaretDuff Oct 12, 2023
5c9fa3a
Sampler inheritance
MargaretDuff Oct 12, 2023
48d355b
Notes from meeting
MargaretDuff Oct 12, 2023
8e84276
Moved sampler to a new folder algorithms.utilities- think there is st…
MargaretDuff Oct 12, 2023
a9cb92e
Some notes from the stochastic meeting
MargaretDuff Oct 12, 2023
c552257
changed cmake file for new folder
MargaretDuff Oct 12, 2023
c6e1458
Some changes from Edo
MargaretDuff Oct 16, 2023
2b35fad
Maths documentation
MargaretDuff Oct 16, 2023
43e6fee
Some more Edo comments on sampler
MargaretDuff Oct 16, 2023
f77b553
Tried to sort the tests
MargaretDuff Oct 17, 2023
cf1b7f1
Vaggelis comment on checks
MargaretDuff Oct 17, 2023
c2c4df9
Change to jinja version in doc_environment.yml
MargaretDuff Oct 17, 2023
544a215
Merge branch 'TomographicImaging:master' into blockoperator-norms
MargaretDuff Oct 17, 2023
d11296f
Revert changes to docs_environment.yml
lauramurgatroyd Oct 18, 2023
32e057b
Docstring change
MargaretDuff Oct 18, 2023
4e0ca6a
Docstring change
MargaretDuff Oct 18, 2023
87f1a00
Revert naming of docs environment file
lauramurgatroyd Oct 18, 2023
2ff165a
Updated changelog
MargaretDuff Oct 18, 2023
81fc7e2
Updated changelog
MargaretDuff Oct 18, 2023
8f100e0
Updated changelog
MargaretDuff Oct 18, 2023
920cf90
Started adding new unit tests
MargaretDuff Oct 20, 2023
3a02a47
More work on tests
MargaretDuff Oct 20, 2023
10748ef
SG tests
MargaretDuff Oct 20, 2023
381342c
Changes to docstring
MargaretDuff Oct 25, 2023
c67818b
Changes to tests
MargaretDuff Oct 25, 2023
6b5ff83
SGD tests including SumFunction
MargaretDuff Oct 25, 2023
c7e5b9f
Merge branch 'SGD' of github.com:MargaretDuff/CIL-margaret into SGD
MargaretDuff Oct 26, 2023
876d4c9
Added size to the BlockOperator
MargaretDuff Oct 26, 2023
5ae4aaf
Merged the blockoperator-norms branch
MargaretDuff Oct 30, 2023
b983e2f
Removed precalculated_norms and pull the prob_weights from the sampler
MargaretDuff Oct 31, 2023
71cbdf9
Changes to setting tau and new unit test
MargaretDuff Oct 31, 2023
8f24634
Just some comments
MargaretDuff Oct 31, 2023
f0f4de3
Changes after discussion with Edo and Gemma
MargaretDuff Nov 2, 2023
100a42d
Merge branch 'blockoperator-norms' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Nov 2, 2023
26584c9
Documentation changes
MargaretDuff Nov 2, 2023
ba8226b
Merge branch 'blockoperator-norms' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Nov 2, 2023
d182423
Changes to SPDHG with block_norms
MargaretDuff Nov 3, 2023
ad86a58
Started setting up factory methods
MargaretDuff Nov 6, 2023
40ba3f4
Added function sampler
MargaretDuff Nov 6, 2023
835ce83
Abstract base class
MargaretDuff Nov 6, 2023
3760458
prob_weights to sampler
MargaretDuff Nov 7, 2023
878675d
TODO:s
MargaretDuff Nov 7, 2023
5ce0a09
Changes after stochastic meeting
MargaretDuff Nov 8, 2023
2d99762
Updates to sampler
MargaretDuff Nov 8, 2023
7154834
Updates to SPDHG after stochastic meeting
MargaretDuff Nov 8, 2023
520b9fa
Updated unit tests
MargaretDuff Nov 8, 2023
13c27e3
Merge branch 'master' into SGD
MargaretDuff Nov 8, 2023
11a4624
Merge branch 'master' into stochastic_sampling
MargaretDuff Nov 8, 2023
4e7f2b6
Merge error fixed
MargaretDuff Nov 8, 2023
d861a13
SPDHG documentation changes
MargaretDuff Nov 15, 2023
c0f0703
Merge branch 'stochastic_sampling' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Nov 22, 2023
fce2a8e
Merged sampler into SGD
MargaretDuff Nov 22, 2023
ea4f114
Merge branch 'master' into SPDHG_unit_tests
MargaretDuff Nov 22, 2023
f95560f
Merged in SPDHG speed up
MargaretDuff Nov 22, 2023
0af2e61
Changes from meeting with Edo and Gemma
MargaretDuff Nov 22, 2023
8e14034
Remove changes to BlockOperator.py
MargaretDuff Nov 22, 2023
5c34e69
sigma and tau properties
MargaretDuff Nov 22, 2023
d1fffdf
Another attempt at speeding up unit tests
MargaretDuff Nov 23, 2023
b3dc8a1
Added random seeds to tests
MargaretDuff Nov 23, 2023
edbaa9f
Started on Gemma's suggestions
MargaretDuff Nov 24, 2023
dc1b67a
Some more of Gemma's changes
MargaretDuff Nov 27, 2023
3b41fc4
Last of Gemma's changes
MargaretDuff Nov 27, 2023
7e5759b
Merge branch 'master' into stochastic_sampling
MargaretDuff Nov 27, 2023
bab0b98
Edo's comments
MargaretDuff Nov 28, 2023
41ff3b5
New __str__ functions in sampler
MargaretDuff Nov 30, 2023
aaa7200
Documentation changes
MargaretDuff Nov 30, 2023
b9bb04d
Documentation changes x2
MargaretDuff Nov 30, 2023
ef25425
Moved custom order to an example of a function
MargaretDuff Dec 5, 2023
0948e39
Back to num_indices and more explanation for custom function examples
MargaretDuff Dec 5, 2023
5804f7d
Updates from chat with Gemma
MargaretDuff Dec 7, 2023
fca94f4
Updates from chat with Gemma
MargaretDuff Dec 7, 2023
7d4ffe6
Pulled prime factorisation code out of the Herman Meyer function
MargaretDuff Dec 8, 2023
2dba9d7
created herman_meyer sampling as a function of iteration number
gfardell Dec 8, 2023
c576a51
Merge pull request #1 from gfardell/stochastic_sampling_hm
MargaretDuff Dec 11, 2023
4c36fdf
Merge branch 'master' into stochastic_sampling
MargaretDuff Dec 11, 2023
f5c2d96
Update Wrappers/Python/cil/optimisation/algorithms/SPDHG.py
MargaretDuff Dec 11, 2023
188000f
Changes from Edo review
MargaretDuff Dec 11, 2023
0155e3d
Merge branch 'stochastic_sampling' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Dec 11, 2023
86c1e3e
Removed from_order to replace with functions
MargaretDuff Dec 11, 2023
47542a5
fix failing tests
MargaretDuff Dec 12, 2023
ddbdbb3
Test fix...again
MargaretDuff Dec 12, 2023
060e915
Merged conflict
MargaretDuff Dec 12, 2023
8e7a6ac
Merge branch 'master' into stochastic_sampling
MargaretDuff Dec 12, 2023
c8b9cc4
Merge branch 'stochastic_sampling' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Dec 12, 2023
5f03675
Fixed tests
MargaretDuff Dec 12, 2023
c49cd16
SVRG and initial unit tests
MargaretDuff Dec 13, 2023
e8a2950
LSVRG draft
MargaretDuff Dec 13, 2023
9e67473
LSVRG draft and data pass unit tests
MargaretDuff Dec 13, 2023
2fb3d5d
More unit tests for LSVRG
MargaretDuff Dec 14, 2023
f0b4f34
Add the important file!
MargaretDuff Dec 14, 2023
e5adafd
Merge branch 'master' into svrg
MargaretDuff Dec 20, 2023
44173f4
Merge branch 'master' into SGD
MargaretDuff Dec 20, 2023
d72b7c1
Merge branch 'master' into SGD
MargaretDuff Jan 25, 2024
3531b96
Tidy up PR
MargaretDuff Jan 25, 2024
d383733
Tidy up PR
MargaretDuff Jan 25, 2024
470ed86
Updated doc strings and requirements for sampler class - need to do d…
MargaretDuff Jan 25, 2024
54cf27c
optimisation.rst updated to add in the new documentation
MargaretDuff Jan 26, 2024
4084a19
Merge branch 'master' into svrg
MargaretDuff Feb 12, 2024
4c4a26c
Changes after discussion with Edo and Kris
MargaretDuff Feb 12, 2024
b6b2a9e
Merge in the base class
MargaretDuff Feb 12, 2024
0921847
Fixed merge error
MargaretDuff Feb 12, 2024
c064388
Fixed merge error
MargaretDuff Feb 12, 2024
4b97d9b
Fixed merge error
MargaretDuff Feb 13, 2024
38a4b3a
Skip if you don't have astra
MargaretDuff Feb 13, 2024
f41ae7b
Added skip astra
MargaretDuff Feb 13, 2024
4dc814f
Changes after discussion with Edo and Kris
MargaretDuff Feb 13, 2024
be75374
New data_passes function and getter
MargaretDuff Feb 13, 2024
778c7c1
New data_passes function and getter
MargaretDuff Feb 13, 2024
958fb6c
Merge with SGD
MargaretDuff Feb 13, 2024
2f3eb50
Docstrings
MargaretDuff Feb 13, 2024
655df78
Merge branch 'master' into SGD
MargaretDuff Feb 13, 2024
2f6f0f4
Merge branch 'SGD' into svrg
MargaretDuff Feb 13, 2024
3f92a8f
Updated documentation rst
MargaretDuff Feb 13, 2024
9990536
Rate to step_size
MargaretDuff Feb 13, 2024
d0f189b
Merge branch 'SGD' into svrg
MargaretDuff Feb 13, 2024
f125e32
Use backtracking in unit tests
MargaretDuff Feb 13, 2024
57b71e1
Use backtracking in unit tests
MargaretDuff Feb 13, 2024
fae9907
Discussed with Zeljko and Edo
MargaretDuff Feb 13, 2024
eac5397
Getter for num_functions
MargaretDuff Feb 13, 2024
9befcf8
Discussion with Zeljko, Edo and Vaggelis
MargaretDuff Feb 14, 2024
e1019a3
Documentation on data passes
MargaretDuff Feb 14, 2024
7d12eca
Merged in SGD
MargaretDuff Feb 15, 2024
c7a8e9e
Data passes indices
MargaretDuff Feb 15, 2024
c971dd6
Comments from discussion with Edo
MargaretDuff Feb 27, 2024
66c4853
Merge branch 'master' into SGD
MargaretDuff Feb 27, 2024
addfe47
Changes after Vaggelis review
MargaretDuff Mar 12, 2024
649e186
Merge branch 'master' into SGD
MargaretDuff Mar 12, 2024
a3bd92a
Merge branch 'master' into SGD
MargaretDuff Mar 13, 2024
f1a1c64
Merge branch 'master' into svrg
MargaretDuff Mar 13, 2024
d779c84
Merge branch 'SGD' into svrg
MargaretDuff Mar 13, 2024
df1ec24
Updated tests
MargaretDuff Mar 13, 2024
3a74f75
Tweak to unit tests after discussion with Edo
MargaretDuff Mar 14, 2024
a73654d
Changes after discussion with Edo
MargaretDuff Mar 14, 2024
8a364de
Fix unit test
MargaretDuff Mar 14, 2024
2395b03
Some of Jakob's comments
MargaretDuff Mar 20, 2024
a7b5016
Updated documentation from Vaggelis and Jakob comments
MargaretDuff Mar 21, 2024
c716d31
Merge branch 'master' into SGD
MargaretDuff Mar 21, 2024
b9d8ab4
Try to fix rst file example
MargaretDuff Mar 21, 2024
08e91a9
Merge branch 'SGD' of github.com:MargaretDuff/CIL-margaret into SGD
MargaretDuff Mar 21, 2024
9aebc50
Try to fix rst file example
MargaretDuff Mar 21, 2024
e034e03
Try to fix rst file bullet points
MargaretDuff Mar 21, 2024
cf0e60f
Try to fix rst file example
MargaretDuff Mar 21, 2024
fe21db0
Try to fix SGD docs
MargaretDuff Mar 21, 2024
53dd837
Updated example after Vaggelis comments
MargaretDuff Mar 21, 2024
843413f
Discussions with Edo and Gemma
MargaretDuff Mar 21, 2024
f3e416a
Documentation for the multiplication factor
MargaretDuff Mar 22, 2024
37e06d7
Merge
MargaretDuff Mar 22, 2024
c22868e
Documentation for the multiplication factor
MargaretDuff Mar 22, 2024
1cffc65
Updates from Edo and Gemma and some of Jakobs comments
MargaretDuff Mar 22, 2024
6294ccb
Merge branch 'master' into SGD
MargaretDuff Mar 25, 2024
561f3e7
Improved documentation after discussion with Edo
MargaretDuff Mar 25, 2024
5a8a4c4
merged
MargaretDuff Mar 25, 2024
002714d
Updates from Jakob's comments
MargaretDuff Mar 25, 2024
55753db
Fix failing test
MargaretDuff Mar 25, 2024
0f8f6b4
Merge branch 'master' into svrg
MargaretDuff Mar 26, 2024
8d1863c
Merge branch 'master' into svrg
MargaretDuff Apr 25, 2024
29787ef
Small things from merge
MargaretDuff Apr 25, 2024
6b2244f
Merge branch 'TomographicImaging:master' into svrg
MargaretDuff May 2, 2024
168715a
Merge branch 'master' into svrg
MargaretDuff May 14, 2024
8b7d8f9
Comments from Kris
MargaretDuff May 14, 2024
d282617
Merge
MargaretDuff May 14, 2024
bc643f2
Kris suggestion
MargaretDuff May 14, 2024
5b3d477
Changes from Jakob's review
MargaretDuff Jul 9, 2024
fab896e
Merge branch 'master' into svrg
MargaretDuff Jul 9, 2024
f0f4cc7
Small documentation changes to make it similar to sag-saga
MargaretDuff Jul 10, 2024
0476497
Vaggelis comments
MargaretDuff Jul 15, 2024
31b13e8
Merge branch 'master' into svrg
MargaretDuff Jul 17, 2024
325983c
Changes from Edo's review
MargaretDuff Jul 17, 2024
a6e7eee
Added info on memory requirements
MargaretDuff Jul 17, 2024
f1ceb79
Documentation spell checking
MargaretDuff Jul 17, 2024
5e5a5cb
Updates to unit tests
MargaretDuff Jul 19, 2024
bd18f2f
Fix to tests
MargaretDuff Jul 19, 2024
50c5c2b
Merge branch 'master' into svrg
paskino Jul 21, 2024
7f16acb
add missing import
paskino Jul 21, 2024
a9fb0e1
fix seed for LSVRG test
paskino Jul 21, 2024
d934c72
fix numba implementation of KullbackLeibler
paskino Jul 21, 2024
e6d7d80
Merge branch 'master' into svrg
paskino Jul 22, 2024
df4082b
added to changelog [ci skip]
paskino Jul 22, 2024
f15c313
checkout master KL
paskino Jul 26, 2024
4464173
Merge branch 'master' into svrg
MargaretDuff Aug 12, 2024
33c092e
Removed unnecessary changes
MargaretDuff Aug 12, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,5 +1,6 @@
* 24.x.x
- New Features:
- Added SVRG and LSVRG stochastic functions (#1625)
- Added SAG and SAGA stochastic functions (#1624)
- Allow `SumFunction` with 1 item (#1857)
- Enhancements:
283 changes: 283 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/SVRGFunction.py
@@ -0,0 +1,283 @@
# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# - CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt
# - Daniel Deidda (National Physical Laboratory, UK)
# - Claire Delplancke (Electricite de France, Research and Development)
# - Ashley Gillman (Australian e-Health Res. Ctr., CSIRO, Brisbane, Queensland, Australia)
# - Zeljko Kereta (Department of Computer Science, University College London, UK)
# - Evgueni Ovtchinnikov (STFC - UKRI)
# - Georg Schramm (Department of Imaging and Pathology, Division of Nuclear Medicine, KU Leuven, Leuven, Belgium)


from .ApproximateGradientSumFunction import ApproximateGradientSumFunction
import numpy as np
import numbers


class SVRGFunction(ApproximateGradientSumFunction):

r"""
The Stochastic Variance Reduced Gradient (SVRG) function calculates the approximate gradient of :math:`\sum_{i=0}^{n-1}f_i`. For this approximation, every `snapshot_update_interval` iterations a full gradient calculation is made at a "snapshot" point. In the intermediate iterations, an index :math:`i_k` is drawn and the gradient of :math:`f_{i_k}` is evaluated at both the current iterate and the snapshot, updating the approximate gradient to be:

.. math ::
n*\nabla f_{i_k}(x_k) - n*\nabla f_{i_k}(\tilde{x}) + \nabla \sum_{i=0}^{n-1}f_i(\tilde{x}),

where :math:`\tilde{x}` is the latest "snapshot" point and :math:`x_k` is the value at the current iteration.

Note
-----
Compared with the literature, we multiply by :math:`n`, the number of functions, so that we return an approximate gradient of the whole sum function and not an average gradient.

Note
----
In the case where `store_gradients=False` the memory requirement is 4 times the image size (one stored full gradient at the "snapshot", one stored "snapshot" point and two lots of intermediary calculations). Alternatively, if `store_gradients=True` the memory requirement is `n+4` times the image size (`n` gradients at the snapshot, one for each function in the sum, plus one stored full gradient at the "snapshot", one stored "snapshot" point and two lots of intermediary calculations).

Reference
---------
Johnson, R. and Zhang, T., 2013. Accelerating stochastic gradient descent using predictive variance reduction. Advances in Neural Information Processing Systems, 26. https://proceedings.neurips.cc/paper_files/paper/2013/file/ac1dd209cbcc5e5d1c6e28598e8cbbe8-Paper.pdf


Parameters
----------
functions : `list` of functions
A list of functions: :code:`[f_{0}, f_{1}, ..., f_{n-1}]`. Each function is assumed to be smooth with an implemented :func:`~Function.gradient` method. All functions must have the same domain. The number of functions must be strictly greater than 1.
sampler: An instance of a CIL Sampler class ( :meth:`~optimisation.utilities.sampler`) or of another class which has a `next` function implemented to output integers in {0, 1, ..., n-1}.
This sampler is called each time gradient is called and sets the internal `function_num` passed to the `approximate_gradient` function. Default is `Sampler.random_with_replacement(len(functions))`.
snapshot_update_interval : positive int or None, optional
The interval for updating the full gradient (taking a snapshot). The default is 2*len(functions) so a "snapshot" is taken every 2*len(functions) iterations. If the user passes `0` then no full gradient snapshots will be taken.
store_gradients : bool, default: `False`
Flag indicating whether to store and update a list of gradients for each function :math:`f_i` or just to store the snapshot point :math:`\tilde{x}` and its gradient :math:`\nabla \sum_{i=0}^{n-1}f_i(\tilde{x})`.


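Example
-------
An illustrative usage sketch (not part of the original PR). The names `operators`, `data_subsets` and `x` stand for user-supplied CIL operators, data and a DataContainer in the common domain; `LeastSquares` and `Sampler` are the existing CIL classes:

>>> from cil.optimisation.functions import LeastSquares, SVRGFunction
>>> from cil.optimisation.utilities import Sampler
>>> f_list = [LeastSquares(A_i, b_i) for A_i, b_i in zip(operators, data_subsets)]
>>> sampler = Sampler.random_with_replacement(len(f_list))
>>> F = SVRGFunction(f_list, sampler=sampler, snapshot_update_interval=2 * len(f_list))
>>> g = F.gradient(x)  # a full "snapshot" gradient on the first call, variance-reduced estimates thereafter
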
"""

def __init__(self, functions, sampler=None, snapshot_update_interval=None, store_gradients=False):
super(SVRGFunction, self).__init__(functions, sampler)

# snapshot_update_interval for SVRG
self.snapshot_update_interval = snapshot_update_interval

if snapshot_update_interval is None:
self.snapshot_update_interval = 2*self.num_functions
self.store_gradients = store_gradients

self._svrg_iter_number = 0

self._full_gradient_at_snapshot = None
self._list_stored_gradients = None

self.stoch_grad_at_iterate = None
self._stochastic_grad_difference = None

self.snapshot = None

def gradient(self, x, out=None):
""" Selects a random function using the `sampler` and then calls the approximate gradient at :code:`x` or calculates a full gradient depending on the update frequency

Parameters
----------
x : DataContainer (e.g. ImageData object)
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.

Returns
--------
DataContainer (e.g. ImageData object)
the value of the approximate gradient of the sum function at :code:`x`
"""

# For SVRG, every `snapshot_update_interval` a full gradient step is calculated, else an approximate gradient is taken.
if self.snapshot_update_interval != 0 and self._svrg_iter_number % self.snapshot_update_interval == 0:

return self._update_full_gradient_and_return(x, out=out)

else:

self.function_num = self.sampler.next()
if not isinstance(self.function_num, numbers.Number):
raise ValueError("Batch gradient is not yet implemented")
if self.function_num >= self.num_functions or self.function_num < 0:
raise IndexError(
f"The sampler has produced the index {self.function_num} which does not match the expected range of available functions to sample from. Please ensure your sampler only selects from [0,1,...,len(functions)-1] ")
return self.approximate_gradient(x, self.function_num, out=out)

def approximate_gradient(self, x, function_num, out=None):
""" Calculates the stochastic gradient at the point :math:`x` by using the gradient of the selected function, indexed by :math:`i_k`, the `function_number` in {0,...,len(functions)-1}, and the full gradient at the snapshot :math:`\tilde{x}`
.. math ::
n*\nabla f_{i_k}(x_k) - n*\nabla f_{i_k}(\tilde{x}) + \nabla \sum_{i=0}^{n-1}f_i(\tilde{x})

Note
-----
Compared with the literature, we multiply by :math:`n`, the number of functions, so that we return an approximate gradient of the whole sum function and not an average gradient.

Parameters
----------
x : DataContainer (e.g. ImageData object)
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.
function_num: `int`
Between 0 and n-1, where n is the number of functions in the list
Returns
--------
DataContainer (e.g. ImageData object)
the value of the approximate gradient of the sum function at :code:`x` given a `function_number` in {0,...,len(functions)-1}
"""

self._svrg_iter_number += 1

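# Gradient of the selected f_i evaluated at the current iterate, reusing the existing buffer where possible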
self.stoch_grad_at_iterate = self.functions[function_num].gradient(x, out=self.stoch_grad_at_iterate)

if self.store_gradients is True:
self._stochastic_grad_difference = self.stoch_grad_at_iterate.sapyb(
1., self._list_stored_gradients[function_num], -1., out=self._stochastic_grad_difference)
else:
self._stochastic_grad_difference = self.stoch_grad_at_iterate.sapyb(
1., self.functions[function_num].gradient(self.snapshot), -1., out=self._stochastic_grad_difference)

self._update_data_passes_indices([function_num])

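# Combine into the SVRG estimate: num_functions * (grad f_i(x) - grad f_i(snapshot)) + full gradient at the snapshot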
out = self._stochastic_grad_difference.sapyb(
self.num_functions, self._full_gradient_at_snapshot, 1., out=out)

return out

def _update_full_gradient_and_return(self, x, out=None):
"""
Takes a "snapshot" at the point :math:`x`, saving both the point :math:` \tilde{x}=x` and its gradient :math:`\sum_{i=0}^{n-1}f_i{\tilde{x}}`. The function returns :math:`\sum_{i=0}^{n-1}f_i{\tilde{x}}` as the gradient calculation. If :code:`store_gradients==True`, the gradient of all the :math:`f_i`s is computed and stored at the "snapshot"..

Parameters
----------
Takes a "snapshot" at the point :math:`x`. The function returns :math:`\sum_{i=0}^{n-1}f_i{\tilde{x}}` as the gradient calculation. If :code:`store_gradients==True`, the gradient of all the :math:`f_i`s is stored, otherwise only the sum of the gradients and the snapshot point :math:` \tilde{x}=x` are stored.
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.

Returns
--------
DataContainer (e.g. ImageData object)
the value of the full gradient of the sum function at the snapshot point :code:`x`
"""

self._svrg_iter_number += 1

if self.store_gradients is True:
if self._list_stored_gradients is None:
# Save the gradient of each individual f_i and the gradient of the full sum at the point x.
self._list_stored_gradients = [
fi.gradient(x) for fi in self.functions]
self._full_gradient_at_snapshot = sum(
self._list_stored_gradients, start=0*x)
else:
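# Refresh the stored per-function gradients at the new snapshot point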
for i, fi in enumerate(self.functions):
fi.gradient(x, out=self._list_stored_gradients[i])

self._full_gradient_at_snapshot *= 0
for el in self._list_stored_gradients:
self._full_gradient_at_snapshot += el

else:
# Save the snapshot point and the gradient of the full sum at the point x.
self._full_gradient_at_snapshot = self.full_gradient(
x, out=self._full_gradient_at_snapshot)

if self.snapshot is None:
self.snapshot = x.copy()

self.snapshot.fill(x)

# In this iteration all functions in the sum were used to update the gradient
self._update_data_passes_indices(list(range(self.num_functions)))

# Return the gradient of the full sum at the snapshot.
if out is None:
out = self._full_gradient_at_snapshot
else:
out.fill(self._full_gradient_at_snapshot)

return out


class LSVRGFunction(SVRGFunction):
"""""
A class representing a function for Loopless Stochastic Variance Reduced Gradient (SVRG) approximation. This is similar to SVRG, except the full gradient at a "snapshot" is calculated at random intervals rather than at fixed numbers of iterations.


Reference
----------

Kovalev, D., Horváth, S. and Richtárik, P. (2020). Don’t Jump Through Hoops and Remove Those Loops: SVRG and Katyusha are Better Without the Outer Loop. Proceedings of the 31st International Conference on Algorithmic Learning Theory, in Proceedings of Machine Learning Research 117:451-467. Available from https://proceedings.mlr.press/v117/kovalev20a.html.



Parameters
----------
functions : `list` of functions
A list of functions: :code:`[f_{0}, f_{1}, ..., f_{n-1}]`. Each function is assumed to be smooth with an implemented :func:`~Function.gradient` method. All functions must have the same domain. The number of functions `n` must be strictly greater than 1.
sampler: An instance of a CIL Sampler class ( :meth:`~optimisation.utilities.sampler`) or of another class which has a `next` function implemented to output integers in {0,...,n-1}.
This sampler is called each time gradient is called and sets the internal `function_num` passed to the `approximate_gradient` function. Default is `Sampler.random_with_replacement(len(functions))`.
snapshot_update_probability: positive float, default: 1/n
The probability of updating the full gradient (taking a snapshot) at each iteration. The default is :math:`1./n` so, in expectation, a snapshot will be taken every :math:`n` iterations.
store_gradients : bool, default: `False`
Flag indicating whether to store and update a list of gradients for each function :math:`f_i` or just to store the snapshot point :math:`\tilde{x}` and its gradient :math:`\nabla \sum_{i=0}^{n-1}f_i(\tilde{x})`.


Note
----
In the case where `store_gradients=False` the memory requirement is 4 times the image size (one stored full gradient at the "snapshot", one stored "snapshot" point and two lots of intermediary calculations). Alternatively, if `store_gradients=True` the memory requirement is `n+4` times the image size (`n` gradients at the snapshot, one for each function in the sum, plus one stored full gradient at the "snapshot", one stored "snapshot" point and two lots of intermediary calculations).

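Example
-------
An illustrative usage sketch (not part of the original PR), assuming `f_list` is a list of smooth CIL functions with a common domain and `x` is a DataContainer in that domain:

>>> from cil.optimisation.functions import LSVRGFunction
>>> F = LSVRGFunction(f_list, snapshot_update_probability=1 / len(f_list), seed=42)
>>> g = F.gradient(x)  # a "snapshot" is always taken on the first call, then with probability 1/n on each call
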
"""

def __init__(self, functions, sampler=None, snapshot_update_probability=None, store_gradients=False, seed=None):

super(LSVRGFunction, self).__init__(
functions, sampler=sampler, store_gradients=store_gradients)

# Update frequency based on probability.
self.snapshot_update_probability = snapshot_update_probability
# Default snapshot_update_probability for Loopless SVRG
if self.snapshot_update_probability is None:
self.snapshot_update_probability = 1./self.num_functions

# The random generator used to decide if the gradient calculation is a full gradient or an approximate gradient
self.generator = np.random.default_rng(seed=seed)

def gradient(self, x, out=None):
""" Selects a random function using the `sampler` and then calls the approximate gradient at :code:`x` or calculates a full gradient depending on the update probability.

Parameters
----------
x : DataContainer (e.g. ImageData objects)
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.

Returns
--------
DataContainer (e.g. ImageData object)
the value of the approximate gradient of the sum function at :code:`x`
"""

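# A "snapshot" is taken on the first call and thereafter with probability `snapshot_update_probability` on each call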
if self._svrg_iter_number == 0 or self.generator.uniform() < self.snapshot_update_probability:

return self._update_full_gradient_and_return(x, out=out)

else:

self.function_num = self.sampler.next()
if not isinstance(self.function_num, numbers.Number):
raise ValueError("Batch gradient is not yet implemented")
if self.function_num >= self.num_functions or self.function_num < 0:
raise IndexError(
f"The sampler has produced the index {self.function_num} which does not match the expected range of available functions to sample from. Please ensure your sampler only selects from [0,1,...,len(functions)-1] ")
return self.approximate_gradient(x, self.function_num, out=out)
1 change: 1 addition & 0 deletions Wrappers/Python/cil/optimisation/functions/__init__.py
@@ -37,5 +37,6 @@
from .L1Sparsity import L1Sparsity
from .ApproximateGradientSumFunction import ApproximateGradientSumFunction
from .SGFunction import SGFunction
from .SVRGFunction import SVRGFunction, LSVRGFunction
from .SAGFunction import SAGFunction, SAGAFunction
