学术, 机器学习

神经网络中常见的激活函数

激活函数在神经网络中扮演着至关重要的角色,它们用于引入非线性特性,使神经网络能够学习和适应更复杂的数据模式。本篇列出几个常见的激活函数,简单形式比较常用的是 ReLU 函数和Leaky ReLU 函数。

1. Sigmoid 函数

PyTorch实现(类):torch.nn.Sigmoid()

PyTorch实现(函数):torch.nn.functional.sigmoid()

PyTorch文档:

2. Tanh 函数

PyTorch实现(类):torch.nn.Tanh()

PyTorch实现(函数):torch.nn.functional.tanh()

PyTorch文档:

3. ReLU 函数

全称:Rectified Linear Unit

PyTorch实现(类):torch.nn.ReLU()

PyTorch实现(函数):torch.nn.functional.relu()

PyTorch文档:

4. Leaky ReLU 函数

PyTorch实现(类):torch.nn.LeakyReLU()

PyTorch实现(函数):torch.nn.functional.leaky_relu()

PyTorch文档:

5. GELU 函数

PyTorch实现(类):torch.nn.GELU()

PyTorch实现(函数):torch.nn.functional.gelu()

PyTorch文档:

6. Swish/SiLU 函数

PyTorch实现(类):torch.nn.SiLU()

PyTorch实现(函数):torch.nn.functional.silu()

PyTorch文档:

还有其他的一些激活函数,这里暂时不列出。每种激活函数都有其适用的场景和局限性,通常需要根据具体的问题和数据特点来选择合适的激活函数。

7. 画图代码

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/39029
"""

import torch
import numpy as np
import matplotlib.pyplot as plt

x_array = np.linspace(-6, 6, 100)
x_array_torch_tensor = torch.from_numpy(x_array)

y_array_torch_tensor = torch.nn.functional.sigmoid(x_array_torch_tensor)
plt.plot(x_array_torch_tensor, y_array_torch_tensor)
plt.title('sigmoid')
plt.show()

y_array_torch_tensor = torch.nn.functional.tanh(x_array_torch_tensor)
plt.plot(x_array_torch_tensor, y_array_torch_tensor)
plt.title('tanh')
plt.show()

y_array_torch_tensor = torch.nn.functional.relu(x_array_torch_tensor)
plt.plot(x_array_torch_tensor, y_array_torch_tensor)
plt.title('relu')
plt.show()

y_array_torch_tensor = torch.nn.functional.leaky_relu(x_array_torch_tensor)
plt.plot(x_array_torch_tensor, y_array_torch_tensor)
plt.title('leaky_relu')
plt.show()

y_array_torch_tensor = torch.nn.functional.gelu(x_array_torch_tensor)
plt.plot(x_array_torch_tensor, y_array_torch_tensor)
plt.title('gelu')
plt.show()

y_array_torch_tensor = torch.nn.functional.silu(x_array_torch_tensor)
plt.plot(x_array_torch_tensor, y_array_torch_tensor)
plt.title('silu')
plt.show()
42 次浏览

【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com

发表评论

您的电子邮箱地址不会被公开。 必填项已用 * 标注

Captcha Code