This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Feature request] temperature parameter in Softmax and SoftmaxOutput #11016

Closed
slitsey opened this issue May 21, 2018 · 8 comments

Comments

@slitsey

slitsey commented May 21, 2018

MXNet does not appear to have a native temperature parameter in its softmax functions. I would like this to be added, as it has many useful applications when learning a categorical probability distribution, especially in a reinforcement learning setting. It should default to 1 to reproduce the current behavior.

https://en.wikipedia.org/wiki/Softmax_function#Reinforcement_learning

@eric-haibin-lin

@srochel
Contributor

srochel commented Jun 7, 2018

@slitsey Would you be able to provide an example we can use to validate the requested feature?

@slitsey
Author

slitsey commented Jun 7, 2018

@srochel @eric-haibin-lin I'm expecting behavior similar to the short program below:

import numpy as np

def softmax(x, temperature=1):
    # Scale the scores by 1/temperature, then normalize the exponentials.
    e = np.exp(x / temperature)
    return e / e.sum()

x = np.array([1, 2, 3])

print(softmax(x))
print(softmax(x, temperature=10))
print(softmax(x, temperature=0.1))

which returns

[ 0.09003057  0.24472847  0.66524096]
[ 0.30060961  0.33222499  0.3671654 ]
[  2.06106005e-09   4.53978686e-05   9.99954600e-01]

This allows interpolation between a uniform distribution in the high-temperature limit and a greedy distribution with all probability mass on the most probable index in the zero-temperature limit. Starting with a high temperature and decreasing it throughout learning allows transitioning from exploration-heavy to exploitation-heavy policies in RL, if the softmax represents the policy's action distribution. There are potential numerical stability issues here due to the exponentials. But probably something like

import mxnet as mx

data = mx.sym.Variable('data')
net = mx.sym.softmax(data=data, temperature=10)

x = mx.nd.array([1, 2, 3])

# Bind the symbol to the input data and run a forward pass.
ex = net.bind(mx.cpu(), args={'data': x})
ex.forward()

should return

[
[ 0.30060961  0.33222499  0.3671654 ]
 <NDArray 3 @cpu(0)>]

for example. And of course this behavior should generalize as expected across axes, etc., and also be incorporated into mx.sym.SoftmaxOutput (and anywhere else softmax/Boltzmann distributions might arise in MXNet).
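For reference, a numerically stable variant of the NumPy sketch above, which subtracts the maximum score before exponentiating so that large scores or small temperatures do not overflow (just an illustration, not the proposed MXNet internals):

import numpy as np

def stable_softmax(x, temperature=1.0, axis=-1):
    # Dividing by the temperature first and then subtracting the max
    # keeps every exponent <= 0, so np.exp never overflows.
    z = np.asarray(x, dtype=np.float64) / temperature
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

x = np.array([1000.0, 2000.0, 3000.0])
print(stable_softmax(x, temperature=0.1))  # approximately [0. 0. 1.], with no overflow warnings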

@apeforest
Contributor

Hi @slitsey I will start working on this feature request. There are multiple operators related to softmax. I will first add the temperature parameter to mx.sym.softmax and mx.sym.SoftmaxOutput, as suggested in your example. Please let me know if you would prefer a different scope. Thanks!

@apeforest
Contributor

I have created a JIRA for this issue (https://issues.apache.org/jira/browse/MXNET-560) and I have started the implementation. Estimated development time: 3 days

@apeforest
Contributor

@slitsey This feature has been merged. Please pull the latest code and verify. Thanks!
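A quick way to verify, assuming the merged change also exposes the same temperature argument on the imperative mx.nd.softmax as requested above:

import mxnet as mx

x = mx.nd.array([1, 2, 3])

# A high temperature should flatten the distribution toward uniform,
# matching the NumPy example earlier in this thread.
print(mx.nd.softmax(x, temperature=10))
print(mx.nd.softmax(x, temperature=0.1))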

@szha closed this as completed Jul 20, 2018
@slitsey
Author

slitsey commented Jul 24, 2018

@apeforest Looks great; thanks for the work on this!

@apeforest
Contributor

apeforest commented Jul 26, 2018

@slitsey It would be great if you could share with me how you used this feature in reinforcement training. I am curious to learn and think it will be very helpful to other users who are not aware of this feature. A few lines of code or a pointer to your repo will be ideal. Thanks a lot in advance!

@slitsey
Author

slitsey commented Jul 30, 2018

@apeforest I can’t share any code with you because I haven’t written it yet. However, I can explain the idea. One approach to reinforcement learning is to have a function that generates a probability distribution over all possible actions for each state, and then to take each action with the associated probability. If this probability distribution is generated by a softmax at the end of a deep learning model, we can optimize it by taking a series of actions, measuring or estimating the reward of each, and adjusting the distribution to maximize the expected reward.

Temperature comes in because early in training we have little information about the value of each action, so we want to intentionally bias the distribution to be more uniform, encouraging exploration. This can be achieved with a high temperature. As training continues and we learn which actions are actually valuable, we can reduce the temperature, shifting probability mass to the most valuable actions. If there is a unique best action at each state, the zero-temperature limit puts all the probability mass on that best action, yielding an optimal trajectory.
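A minimal sketch of that loop (the scores are placeholders standing in for whatever the policy network outputs; this is an illustration, not code from an actual project):

import numpy as np

def softmax(scores, temperature):
    e = np.exp((scores - scores.max()) / temperature)
    return e / e.sum()

rng = np.random.default_rng(0)
scores = np.array([0.1, 0.5, 2.0])  # hypothetical per-action scores for one state

# Anneal from exploration-heavy (high temperature) to exploitation-heavy (low temperature).
for temperature in (10.0, 1.0, 0.1):
    probs = softmax(scores, temperature)
    action = rng.choice(len(probs), p=probs)  # sample an action from the policy
    print(temperature, probs.round(3), action)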

There are other applications outside of RL as well; for example, temperature can be used in model distillation, i.e. training a simple model to mimic a complex one.
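A small sketch of that use as well (the teacher logits are made up; the point is only that a higher temperature produces softer targets for the simple model to match):

import numpy as np

def softmax(z, temperature=1.0):
    e = np.exp((z - z.max()) / temperature)
    return e / e.sum()

teacher_logits = np.array([5.0, 2.0, 0.5])  # hypothetical outputs of the complex model

hard_targets = softmax(teacher_logits)                 # nearly one-hot; little information about class similarity
soft_targets = softmax(teacher_logits, temperature=4)  # softened targets for the simple model to mimic
print(hard_targets.round(3))
print(soft_targets.round(3))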
