torch笔记
torch笔记
参考视频
www.bilibili.com/video/BV1hE411t7RN
基本类的使用
Dataset
作用:数据导入
模块位置
1 |
|
代码
1 |
|
tensorboard.SummaryWriter
日志记录,可以记录标量或者图片的变化
代码
1 |
|
查看
1 |
|
transforms
图片处理类
导入
1 |
|
代码
1 |
|
transform = transforms.ToTensor()
使所有数据转换为Tensor
,如果不进行转换则返回的是PIL图片。transforms.ToTensor()
将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8
的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32
且位于[0.0, 1.0]的Tensor
。
dataloader
数据分组
1 |
|
一些注意
需要多看看官方文档,一方面要看看输入输出的数据类型是什么,另一方面不清楚返回值的话要试试print(type(x))
或者关注一下print(x.shape)
的值
使用torch自带的数据集
1 |
|
网络
nn.module
torch.nn.Module is Base class for all neural network modules.
1 |
|
卷积
F.conv2d
导入
1 |
|
代码
1 |
|
卷积层
Conv2d
用于选取特征
1 |
|
代码
1 |
|
池化层
用于减少数据量
1 |
|
代码
1 |
|
非线性激活
拟合特征
1 |
|
线性层
图中每个箭头都是一个线性计算 $$y=ax+b$$ 具体是矩阵乘法
代码
1 |
|
seq
用于简化表达
1 |
|
net[0]
这样根据下标访问子模块的写法只有当net
是个ModuleList
或者Sequential
实例时才可以
损失
loss
1 |
|
backward
1 |
|
优化
1 |
|
打印参数及初始化参数
1 |
|
使用预下载好的模型
1 |
|
测试网络参数是否正确
使用torch.zeros
1 |
|
保存和使用已有模型的参数
保存参数 两种方法
1 |
|
读取参数 与之对应的两种方法,这两种方法就相当于加载好了实例了
1 |
|
训练过程
1 |
|
GPU加速
需要在三个地方进行设置
- 网络
- 数据集
- 损失函数
方法一
用.cuda()的方法
1 |
|
方法二
用to(device)的方法
1 |
|
torch笔记
http://iamlihua.github.io/2023/07/12/torch-learn/