Lion 8 bit #188

Merged Apr 11, 2023 · 21 commits (the diff below shows changes from 15 of the 21 commits)

Commits
7247cb4
initial commit, slowly work from interface into the kernel
lucidrains Mar 9, 2023
d43ea97
make sure interface is correct
lucidrains Mar 9, 2023
cb4c3c8
do a bunch of typical bookkeeping before getting to main lion logic
lucidrains Mar 9, 2023
8de29fc
forget about tests for now, will test live on local enwik8 training
lucidrains Mar 9, 2023
64bb1ae
add a sign function, for lion
lucidrains Mar 9, 2023
c83888a
use epsilon as beta2 for lion, complete most of the logic in kernel.c…
lucidrains Mar 9, 2023
ead570a
remove something rmsprop specific
lucidrains Mar 9, 2023
af03430
fix weight decay for lion to be decoupled, using a switch
lucidrains Mar 9, 2023
c558272
missed adagrad
lucidrains Mar 9, 2023
8618bed
swap the order in which momentum and parameters are updated in ops.cu
lucidrains Mar 10, 2023
c99b44f
do the epsilon beta2 switcharoo within the cuda code, and not within …
lucidrains Mar 10, 2023
19b9ef3
whoops
lucidrains Mar 10, 2023
abbe65a
beta2 is actually accessible in kOptimizerStatic8bit1StateBlockwise
lucidrains Mar 10, 2023
6c377b3
always pass beta2 into all the 1state functions
lucidrains Mar 10, 2023
369a51c
switch all eps to beta2
lucidrains Mar 10, 2023
9b656f4
follow advice of Tim to fix update of momentum vs parameters in block…
lucidrains Mar 22, 2023
a43cd20
add some code in test_optim.py, although it seems to be failing
lucidrains Mar 22, 2023
aa9b939
add some comments, and fix use of g_val
lucidrains Mar 22, 2023
916000c
fix consistent tabs / spaces
lucidrains Mar 22, 2023
978ba2d
another tab/spaces fix
lucidrains Mar 22, 2023
2a6828e
fix comment
lucidrains Mar 22, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -40,7 +40,7 @@ out = linear(x.to(torch.float16))
## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- - 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
+ - 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization, and normalization
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
- Fast quantile estimation: Up to 100x faster than other algorithms
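Usage note (not part of the diff): once this PR is merged, Lion is selected like any other 8-bit optimizer in the feature list above. A minimal sketch — the hyperparameter values are illustrative and a CUDA build of bitsandbytes is assumed:

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()

# Drop-in replacement for a 32-bit optimizer; Lion keeps a single 8-bit state tensor.
optimizer = bnb.optim.Lion8bit(
    model.parameters(),
    lr=1e-4,            # the Lion paper recommends a smaller lr than AdamW
    betas=(0.9, 0.99),
    weight_decay=1e-2,  # illustrative; the default in this PR is 0
)

loss = model(torch.randn(8, 512, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```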
12 changes: 12 additions & 0 deletions bitsandbytes/functional.py
@@ -35,6 +35,10 @@ def prod(iterable):
    lib.crmsprop32bit_g32,
    lib.crmsprop32bit_g16,
)
+ str2optimizer32bit["lion"] = (
+     lib.clion32bit_g32,
+     lib.clion32bit_g16,
+ )
str2optimizer32bit["adagrad"] = (
    lib.cadagrad32bit_g32,
    lib.cadagrad32bit_g16,
@@ -58,6 +62,10 @@ def prod(iterable):
    lib.crmsprop_static_8bit_g32,
    lib.crmsprop_static_8bit_g16,
)
+ str2optimizer8bit["lion"] = (
+     lib.clion_static_8bit_g32,
+     lib.clion_static_8bit_g16,
+ )
str2optimizer8bit["lamb"] = (
    lib.cadam_static_8bit_g32,
    lib.cadam_static_8bit_g16,
@@ -80,6 +88,10 @@ def prod(iterable):
    lib.crmsprop_8bit_blockwise_fp32,
    lib.crmsprop_8bit_blockwise_fp16,
)
+ str2optimizer8bit_blockwise["lion"] = (
+     lib.clion_8bit_blockwise_fp32,
+     lib.clion_8bit_blockwise_fp16,
+ )
str2optimizer8bit_blockwise["adagrad"] = (
    lib.cadagrad_8bit_blockwise_fp32,
    lib.cadagrad_8bit_blockwise_fp16,
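How these tables are consumed (an editorial sketch, not code from this PR): each entry pairs a compiled CUDA entry point for fp32 gradients with one for fp16, and the update path in `bitsandbytes.functional` picks between them by gradient dtype, roughly as below. The `select_kernel` helper is hypothetical; the real calls pass many more pointers (state, quantiles, absmax, and so on):

```python
import torch

def select_kernel(table, name, grad):
    # Each table entry is (fp32_kernel, fp16_kernel), e.g.
    # (lib.clion_8bit_blockwise_fp32, lib.clion_8bit_blockwise_fp16).
    fn_fp32, fn_fp16 = table[name]
    if grad.dtype == torch.float32:
        return fn_fp32
    if grad.dtype == torch.float16:
        return fn_fp16
    raise ValueError(f"no kernel registered for {name} with dtype {grad.dtype}")
```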
1 change: 1 addition & 0 deletions bitsandbytes/optim/__init__.py
@@ -12,4 +12,5 @@
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+ from .lion import Lion, Lion8bit, Lion32bit
from .sgd import SGD, SGD8bit, SGD32bit
87 changes: 87 additions & 0 deletions bitsandbytes/optim/lion.py
@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Lion(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "lion",
            params,
            lr,
            betas,
            0.,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Lion8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "lion",
            params,
            lr,
            betas,
            0.,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Lion32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-4,
        betas=(0.9, 0.99),
        weight_decay=0,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "lion",
            params,
            lr,
            betas,
            0.,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
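All three classes forward to `Optimizer1State`, passing `0.` in the position the other single-state optimizers use for `eps` (Lion has no epsilon); `Lion` with `optim_bits=32` matches `Lion32bit`, and `Lion8bit` hard-codes 8 state bits. For reference, a plain-PyTorch sketch of the step the new kernels implement — an editorial illustration, not code from this PR:

```python
import torch

@torch.no_grad()
def lion_step(p, grad, m, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    # Decoupled weight decay, as in the kernels: shrink the parameter directly.
    if weight_decay > 0.0:
        p.mul_(1.0 - lr * weight_decay)
    # Step direction: sign of the old momentum blended with the current gradient.
    update = m.mul(beta1).add(grad, alpha=1.0 - beta1).sign_()
    p.add_(update, alpha=-lr)
    # The momentum EMA uses beta2 and is advanced after the parameter update.
    m.mul_(beta2).add_(grad, alpha=1.0 - beta2)
```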
81 changes: 71 additions & 10 deletions csrc/kernels.cu
@@ -43,6 +43,14 @@ __device__ float atomicMin(float* address, float val) {
return __int_as_float(old);
}

+ // sign function for lion
+ // taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
+
+ template <typename T>
+ __device__ int sgn(T val) {
+   return (T(0) < val) - (val < T(0));
+ }

template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
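The helper above returns -1, 0, or +1 (so `sgn(0) == 0`, matching `torch.sign`). The same comparison trick in Python, for reference:

```python
def sgn(x):
    # (x > 0) - (x < 0) evaluates to 1, 0, or -1 without branching
    return (x > 0) - (x < 0)
```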
@@ -743,7 +751,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
- const float beta1, const float eps, const float weight_decay,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{

@@ -790,6 +798,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
+ case LION:
+   s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
+   break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
@@ -821,7 +832,7 @@ template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
- const float beta1, const float eps, const float weight_decay,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{

@@ -890,6 +901,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,

p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
+ case LION:
+   p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
+   s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
+   break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
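Note the ordering in the LION branch above, which the commits "swap the order in which momentum and parameters are updated in ops.cu" and "follow advice of Tim…" revolve around: the parameter step uses the stale momentum blended with the current gradient, and only afterwards is the momentum EMA advanced with beta2. A scalar, per-element Python sketch of those two statements:

```python
def lion_fused_element(p, m_old, g, lr, beta1, beta2):
    blend = beta1 * m_old + (1.0 - beta1) * g
    update = (blend > 0) - (blend < 0)           # sign(), computed from m_old
    p_new = p - lr * update
    m_new = beta2 * m_old + (1.0 - beta2) * g    # EMA advanced only afterwards
    return p_new, m_new
```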
@@ -1158,7 +1173,7 @@ __global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
- const float beta1,
+ const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@@ -1219,6 +1234,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j];
break;
+ case LION:
+   s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+   break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
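Context for `unorm` (editorial note): the precondition kernels accumulate the squared norm of the pending update so the main kernel can clip oversized steps against `max_unorm * param_norm` through `update_scale`; in the hunk above only the MOMENTUM branch feeds `unorm`, while the new LION branch just advances its state. A rough sketch of the clipping idea — the exact formula in the CUDA source may differ:

```python
def update_scale(unorm_sq, param_norm, max_unorm):
    # Shrink the whole step if its norm exceeds max_unorm * param_norm.
    if max_unorm <= 0.0:
        return 1.0
    current = unorm_sq ** 0.5
    ceiling = max_unorm * param_norm
    return ceiling / current if current > ceiling else 1.0
```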
@@ -1244,7 +1262,7 @@ template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
- const float beta1,
+ const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@@ -1307,8 +1325,19 @@
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- if(weight_decay > 0.0f)
-   g_val += ((float)p_vals[j])*weight_decay;
+
+ if(weight_decay > 0.0f) {
+   switch(OPTIMIZER) {
+     case MOMENTUM:
+     case RMSPROP:
+       g_val += ((float)p_vals[j])*weight_decay;
+       break;
+     case LION:
+       p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+       break;
+   }
+ }

s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];

switch(OPTIMIZER)
@@ -1321,6 +1350,10 @@

p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
+ case LION:
+   p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
+   s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+   break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
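The switch above is the decoupled weight decay fix from the commit history: MOMENTUM and RMSPROP fold the decay into the gradient (classic L2 regularization), while LION shrinks the parameter directly, AdamW-style, before the sign update. Side by side, as an editorial sketch:

```python
def apply_weight_decay(optimizer, p, g, lr, weight_decay):
    if weight_decay <= 0.0:
        return p, g
    if optimizer in ("momentum", "rmsprop", "adagrad"):
        g = g + p * weight_decay           # coupled: decay rides on the gradient
    elif optimizer == "lion":
        p = p * (1.0 - lr * weight_decay)  # decoupled: shrink the weight itself
    return p, g
```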
@@ -1651,8 +1684,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
- if(weight_decay > 0.0f)
-   g_val += ((float)p_vals[j])*weight_decay;
+ if(weight_decay > 0.0f) {
+   switch(OPTIMIZER) {
+     case MOMENTUM:
+     case ADAGRAD:
+     case RMSPROP:
+       g_val += ((float)p_vals[j])*weight_decay;
+       break;
+     case LION:
+       p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
+       break;
+   }
+ }

s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];

@@ -1664,6 +1707,9 @@
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
+ case LION:
+   s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
+   break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
@@ -1701,6 +1747,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
+ case LION:
+   p_vals[j] = ((float)p_vals[j]) - lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])));
+   break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
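In the blockwise kernel the Lion momentum lives in 8 bits: each element is an index into a 256-entry quantile table, rescaled by its block's absmax — the `smem_quantiles1[lane_id][c1s[j]] * absmax1[i/BLOCK_SIZE]` expression above is the dequantize step. A round-trip sketch with hypothetical helpers; the real table is bitsandbytes' dynamic quantization map, not the uniform grid used here:

```python
import numpy as np

def dequantize_block(codes, quantiles, absmax):
    # codes: uint8 indices, quantiles: 256 values in [-1, 1], absmax: block scale
    return quantiles[codes] * absmax

def quantize_block(values, quantiles):
    absmax = np.abs(values).max() + 1e-12
    normed = values / absmax                  # map the block into [-1, 1]
    # nearest-quantile search; the CUDA code does this with a shared-memory table
    codes = np.abs(normed[:, None] - quantiles[None, :]).argmin(axis=1)
    return codes.astype(np.uint8), absmax

quantiles = np.linspace(-1.0, 1.0, 256)       # placeholder for the dynamic map
```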
@@ -2692,24 +2741,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
- const float beta1, const float eps, const float weight_decay, \
+ const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \

MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+ MAKE_PreconditionOptimizer32bit1State(LION, half)
+ MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
- const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
+ const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
+ MAKE_Optimizer32bit1State(LION, half)
+ MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)

@@ -2731,6 +2784,7 @@ template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p,
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \
const float beta1, \
+ const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
@@ -2742,11 +2796,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
+ MAKE_PreconditionStatic8bit1State(LION, half)
+ MAKE_PreconditionStatic8bit1State(LION, float)

#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
+ const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
@@ -2758,6 +2815,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
+ MAKE_optimizerStatic8bit1State(LION, half)
+ MAKE_optimizerStatic8bit1State(LION, float)

#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
@@ -2849,5 +2908,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
+ MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
+ MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
8 changes: 4 additions & 4 deletions csrc/kernels.cuh
@@ -32,20 +32,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
- const float beta1, const float eps, const float weight_decay,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);

template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
- const float beta1, const float eps, const float weight_decay,
+ const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
- const float beta1,
+ const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
@@ -57,7 +57,7 @@ template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
- const float beta1,
+ const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,