Study of Mish activation function in transfer learning with code and discussion
A pretrained ResNet50 model is tested with Mish in its fully connected layers. The paper is also discussed in detail, with a code implementation.
- PyTorch implementation
- Plot mish function
- Properties of mish
- Testing Mish against ReLU
- Visualization of output landscape
Link to jupyter notebook, paper, fastai discussion thread
The Mish activation function was proposed in the paper Mish: A Self Regularized Non-Monotonic Neural Activation Function. The experiments in the paper show that it achieves better accuracy than ReLU. The fastai community has also run many experiments and likewise achieved better results than with ReLU.
Mish is defined as x * tanh(softplus(x))
or by this equation $x*\tanh (\ln (1+e^x))$.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Mish(nn.Module):
    r"""
    Mish activation function is proposed in "Mish: A Self
    Regularized Non-Monotonic Neural Activation Function"
    paper, https://arxiv.org/abs/1908.08681.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # softplus(x) = ln(1 + e^x)
        return x * torch.tanh(F.softplus(x))
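As a quick sanity check, the module can be compared against the explicit formula. This is a minimal sketch rather than part of the original notebook: the `mish_reference` helper and the sample inputs are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x))."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


def mish_reference(x):
    # Explicit form of the same function: x * tanh(ln(1 + e^x)).
    return x * torch.tanh(torch.log1p(torch.exp(x)))


x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
out = Mish()(x)
ref = mish_reference(x)
print(out)
```

Both forms should agree up to floating-point error, and the output at zero is exactly zero.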
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-7, 7, 700)
y = x * np.tanh(np.log(1 + np.exp(x)))

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

# Move the axes so that they cross at the origin and hide the outer frame
ax.spines['left'].set_position('center')
ax.spines['bottom'].set_position('zero')
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')

plt.plot(x, y, 'b')
# Saved relative to the working directory; adjust the path as needed
plt.savefig(fname='mish_plot.png', dpi=1200)
Properties of mish
- Unbounded Above:- Being unbounded above is a desired property of an activation function because it avoids saturation, which slows training down as gradients approach zero.
- Bounded Below:- Being bounded below is desired because it results in strong regularization effects.
- Non-monotonic:- This is the important property of Mish. It preserves small negative values, which lets the network learn better and improves gradient flow in the negative region, unlike ReLU, where the gradient is zero for all negative inputs.
- Continuity:- Mish's first derivative is continuous over the entire domain, which helps in effective optimization and generalization. ReLU itself is continuous, but its first derivative is discontinuous at zero.
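The first three properties can be checked numerically. The following is a rough sketch using NumPy only; the sampling range and the helper name `mish` are assumptions made for this illustration, not part of the original notebook.

```python
import numpy as np


def mish(x):
    # np.logaddexp(0, x) computes ln(1 + e^x) without overflowing for large x.
    return x * np.tanh(np.logaddexp(0.0, x))


x = np.linspace(-10.0, 10.0, 20001)
y = mish(x)

# Bounded below: the minimum is a small negative value at a negative input.
y_min = y.min()
x_at_min = x[y.argmin()]

# Non-monotonic: the curve decreases left of the minimum and increases after it.
dy = np.diff(y)
has_decreasing_part = bool((dy < 0).any())
has_increasing_part = bool((dy > 0).any())

# Unbounded above: for large positive x, mish(x) is close to x itself.
large = mish(np.array([50.0, 500.0]))

print(f"min = {y_min:.4f} at x = {x_at_min:.3f}")
```

On this grid the minimum lands at roughly -0.31, near x = -1.2, which is the small negative dip visible in the plot.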
To compute the first derivative, expand the tanh(softplus(x)) term to obtain the expression below, then apply the product rule.
$$y=x*\frac{e^{2x}+2e^x}{e^{2x}+2e^x+2}$$
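One way to carry out the product rule on this expression is sketched below and checked against a central finite difference. The function names `mish` and `mish_grad`, the step size `h`, and the test range are assumptions for this illustration; the closed form in `mish_grad` is derived here from the expanded expression and is not quoted from the paper.

```python
import numpy as np


def mish(x):
    return x * np.tanh(np.logaddexp(0.0, x))


def mish_grad(x):
    # With w = e^x, tanh(softplus(x)) = n / d where n = w^2 + 2w and d = n + 2.
    w = np.exp(x)
    n = w * w + 2.0 * w
    d = n + 2.0
    # Product rule: d/dx [x * n/d] = n/d + x * d/dx [n/d],
    # and d/dx [n/d] = 4w(w + 1) / d^2.
    return n / d + x * 4.0 * w * (w + 1.0) / (d * d)


x = np.linspace(-5.0, 5.0, 101)
h = 1e-5
numeric = (mish(x + h) - mish(x - h)) / (2.0 * h)  # central difference
max_err = np.max(np.abs(mish_grad(x) - numeric))
print(max_err)
```

At x = 0 the second term vanishes, so the slope there is simply 3/5.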
When comparing Mish against ReLU, use a lower learning rate for Mish. A range of around 1e-5 to 1e-1 showed good results.
Testing Mish against ReLU
Rather than training from scratch, which is already done in the paper, I test transfer learning. When we use pretrained models on our own dataset, we keep the CNN filter weights the same (they are updated only during fine-tuning), but we initialize the last fully connected layers (the head of the model) randomly. So I compare ReLU and Mish in these fully connected layers.
I use the CIFAR10 and CIFAR100 datasets to test a pretrained ResNet50 model. Each model is trained for 10 epochs, and the results are compared at the fifth and tenth epochs. The results are averaged across 3 runs using different learning rates (1e-2, 5e-3, 1e-3). The weights of the CNN filters are not updated; only the fully connected layers are trained.
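The idea of training only the head can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not the notebook's training code: `body` is a tiny hypothetical stand-in for the pretrained ResNet50 body (so nothing needs to be downloaded), the head is a single linear layer, and SGD is an arbitrary optimizer choice.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained ResNet50 body, kept tiny so the
# sketch runs without downloading weights.
body = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(8, 10)  # randomly initialized head (10 classes, as in CIFAR10)

# Freeze the body so that only the head receives gradient updates.
for param in body.parameters():
    param.requires_grad = False

model = nn.Sequential(body, head)
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3)  # optimizer is an assumption

# One training step on random data to confirm the body stays fixed.
before = body[0].weight.clone()
x = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()
```

Passing only the trainable parameters to the optimizer keeps the frozen filters out of the update step entirely.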
For the fully connected layers, I use the following architecture. In the case of Mish, replace the ReLU with Mish.
# AdaptiveConcatPool2d (from fastai, as is Flatten) just combines
# AdaptiveAvgPool and AdaptiveMaxPool, doubling ResNet50's 2048 channels to 4096.
head = nn.Sequential(
    AdaptiveConcatPool2d(),
    Flatten(),
    nn.BatchNorm1d(4096),
    nn.Dropout(p=0.25),
    nn.Linear(in_features=4096, out_features=512),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(512),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=512, out_features=10)
)
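To switch between the two activations without editing the head by hand, the same architecture can be wrapped in a small factory. This is a sketch under stated assumptions: `ConcatPool2d` is a plain-PyTorch stand-in for fastai's `AdaptiveConcatPool2d`, `make_head` and its parameters are hypothetical names, and the random `features` tensor only imitates the shape of a ResNet50 body output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class ConcatPool2d(nn.Module):
    """Plain-PyTorch stand-in for fastai's AdaptiveConcatPool2d."""

    def __init__(self):
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # Concatenate along channels: C channels in, 2 * C channels out.
        return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1)


def make_head(act_cls=nn.ReLU, n_classes=10):
    # ResNet50 emits 2048 feature maps, so concat pooling yields 4096 features.
    return nn.Sequential(
        ConcatPool2d(),
        nn.Flatten(),
        nn.BatchNorm1d(4096),
        nn.Dropout(p=0.25),
        nn.Linear(4096, 512),
        act_cls(),
        nn.BatchNorm1d(512),
        nn.Dropout(p=0.5),
        nn.Linear(512, n_classes),
    )


relu_head = make_head(nn.ReLU)
mish_head = make_head(Mish, n_classes=100)  # e.g. for CIFAR100

features = torch.randn(2, 2048, 4, 4)  # fake body output, for shape checking only
out = mish_head.eval()(features)
print(out.shape)
```

Passing the activation as a class keeps the two heads identical apart from the single layer under test.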
The final results are shown below. It was observed that Mish required training with a smaller learning rate, otherwise it overfits quickly, which suggests that it requires stronger regularization than ReLU. This was consistent across multiple runs. Generally, you can get away with a higher learning rate with ReLU, but with Mish a higher learning rate always led to overfitting.
Although the results are quite similar, Mish shows some marginal improvements. This is a very limited test, as only one Mish activation is used in the entire network and the network has been trained for only 10 epochs.
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

# The following code has been taken from
# https://github.com/digantamisra98/Mish/blob/master/output_landscape.py
def get_model(act_fn='relu'):
    # Compare strings with ==; `is` checks object identity, not equality
    if act_fn == 'relu':
        fn = nn.ReLU(inplace=True)
    elif act_fn == 'mish':
        fn = Mish()  # the Mish module defined earlier
    else:
        raise ValueError(f"Unknown activation: {act_fn}")

    # Small randomly initialized network mapping a 2D point to a scalar
    model = nn.Sequential(
        nn.Linear(2, 64),
        fn,
        nn.Linear(64, 32),
        fn,
        nn.Linear(32, 16),
        fn,
        nn.Linear(16, 1),
        fn
    )
    return model

# Main code
relu_model = get_model('relu')
mish_model = get_model('mish')

x = np.linspace(0., 10., 100)
y = np.linspace(0., 10., 100)

# float32 matches the dtype of the nn.Linear weights
grid = [torch.tensor([xi, yi], dtype=torch.float32) for xi in x for yi in y]

np_img_relu = np.array([relu_model(point).detach().numpy() for point in grid]).reshape(100, 100)
np_img_mish = np.array([mish_model(point).detach().numpy() for point in grid]).reshape(100, 100)

# MinMaxScaler rescales each column independently to [0, 255]
scaler = MinMaxScaler(feature_range=(0, 255))
np_img_relu = scaler.fit_transform(np_img_relu)
np_img_mish = scaler.fit_transform(np_img_mish)

plt.imsave('relu_land.png', np_img_relu)
plt.imsave('mish_land.png', np_img_mish)
From the above output landscapes, we can observe that Mish produces a smoother output landscape. This suggests smoother loss functions, which are easier to optimize and may help the network generalize better.