新智元推荐
来源:专知
【新智元导读】在研究深度学习的过程中,当你脑中突然迸发出一个灵感,你是否发现没有趁手的工具可以快速实现你的想法?看完本文之后,你可能会多出一个选择。本文简要分析了研究深度学习问题时常见的工作流, 并介绍了怎么使用 PyTorch 来快速构建你的实验。
某一天, 你坐在实验室的椅子上, 突然:
你脑子里迸发出一个idea
你看了关于某一theory 的文章, 想试试: 要是把 xx 也加进去会怎么样
你老板突然给你一张纸, 然后说: 那个谁, 来把这个东西实现一下
于是, 你设计了实验流程, 并为这一 idea 挑选了合适的数据集和运行环境, 然后你废寝忘食的实现了模型, 经过长时间的训练和测试, 你发现:
这 idea 不 work --> 那算了 or 再调调
这 idea 很 work --> 可以写 paper 了
我们可以把上述流程用下图表示:
实际上, 常见的流程由下面几项组成起来:
一旦选定了数据集, 你就要写一些函数去 load 数据集, 然后 pre-process 数据集, normalize 数据集, 可以说这是一个实验中占比重最多的部分, 因为:
每个数据集的格式都不太一样
预处理和正则化的方式也不尽相同
需要一个快速的 dataloader 来 feed data, 越快越好
然后, 你就要实现自己的模型, 如果你是 CV 方向的你可能想实现一个 ResNet, 如果你是 NLP 相关的你可能想实现一个 Seq2Seq
接下来, 你需要实现训练步骤, 分 batch, 循环 epoch
在若干轮的训练后, 总要 checkpoint 一下, 才是最安全的
你还需要构建一些 baseline, 以验证自己 idea 的有效性
如果你实现的是神经网络模型, 当然离不开 GPU 的支持
很多深度学习框架提供了常见的损失函数, 但大部分时间, 损失函数都要和具体任务结合起来, 然后重新实现
使用优化方法, 优化构建的模型, 动态调整学习率
对于加载数据, Pytorch 提出了多种解决办法
Pytorch 是一个 Python 包, 而不是某些大型 C++ 库的 Python 接口, 所以, 对于数据集本身提供 Python API 的, Pytorch 可以直接调用, 不必特殊处理.
Pytorch 集成了常用数据集的 data loader
虽然以上措施已经能涵盖大部分数据集了, 但 Pytorch 还开展了两个项目: vision, 和 text, 见下图灰色背景部分. 这两个项目, 采用众包机制, 收集了大量的 dataloader, pre-process 以及 normalize, 分别对应于图像和文本信息.
如果你要自定义数据集,也只需要继承 torch.utils.data.dataset
对于构建模型, Pytorch 也提供了三种方案
众包的模型: torch.utils.model_zoo , 你可以使用这个工具, 加载大家共享出来的模型
使用 torch.nn.Sequential 模块快速构建
集成 torch.nn.Module 深度定制
对于训练过程的 Pytorch 实现
你当然可以自己实现数据的 batch, shuffer 等, 但 Pytorch 建议用类 torch.utils.data.DataLoader 加载数据,并对数据进行采样,生成batch
迭代器。
对于保存和加载模型 Pytorch 提供两种方案
保存和加载整个网络
保存和加载网络中的参数
对于 GPU 支持
你可以直接调用 Tensor 的. cuda() 直接将 Tensor 的数据迁移到 GPU 的显存上, 当然, 你也可以用. cpu() 随时将数据移回内存
对于 Loss 函数, 以及自定义 Loss
在 Pytorch 的包 torch.nn 里, 不仅包含常用且经典的 Loss 函数, 还会实时跟进新的 Loss 包括: CosineEmbeddingLoss, TripletMarginLoss 等.
如果你的 idea 非常新颖, Pytorch 提供了三种自定义 Loss 的方式
继承 torch.nn.module
然后
这样做, 你能够用 torch.nn.functional 里优化过的各种函数来组成你的 Loss
继承 torch.autograd.Function
这样做,你能够用常用的 numpy 和 scipy 函数来组成你的 Loss
写一个 Pytorch 的 C 扩展
这里就不细讲了,未来会有内容专门介绍这一部分。
对于优化算法以及调节学习率
Pytorch 集成了常见的优化算法, 包括 SGD, Adam, SparseAdam, AdagradRMSprop, Rprop 等等.
torch.optim.lr_scheduler 提供了多种方式来基于 epoch 迭代次数调节学习率 torch.optim.lr_scheduler.ReduceLROnPlateau 还能够基于实时的学习结果, 动态调整学习率.
希望第一篇《深度学习实验流程及 PyTorch 提供的解决方案》,大家会喜欢,后续会推出系列实战教程,敬请期待。
联系客服