Pytorch中使用torch.device()选取并返回抽象出的设备,然后在定义的网络模块或者Tensor后面加上.to(device变量)就可以将它们搬到设备上了。
以上一篇代码为例,使用GPU设备:
device = torch.device('cuda:0') # 使用第一张显卡
需要将如下部分搬移到GPU上:
1. 定义的网络
net = MLP().to(device)
2. 损失函数
criteon = nn.CrossEntropyLoss()
3.每次取出的训练集和验证集的batch
data, target = data.to(device), target.to(device)
最后
以上就是可爱小懒虫最近收集整理的关于Pytorch 使用GPU加速的全部内容,更多相关Pytorch内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复