背景
帮师妹改毕设代码,第一次接触pytorch,没看手册,直接开肛,遇到些坑,在这里记录一下,大佬勿喷。
坑1:进入网络训练的数据必须归一化
如果数据没有归一化,可能得到的loss会成为负数,在此参考了Crazy_Omais的一段归一化代码
复制代码
1
2
3
4
5
6def data_in_one(inputdata): min = np.nanmin(inputdata) max = np.nanmax(inputdata) outputdata = (inputdata-min)/(max-min) return outputdata
坑2:torchvision.transforms.ToTensor()
torchvision.transforms.ToTensor()不能用于处理一维数据,如果要处理的话,可以使用torch.from_numpy()
复制代码
1
2
3
4def __getitem__(self, index): data = self.datas[:][index] data = torch.from_numpy(data)
坑3:网络训练的数据需要是Dataloader类型
网络训练的数据需要是Dataloader类型,而输入必须是一个Dataset的子类,因此我们有必要定义一个类,以装载我们自己的数据,本代码数据手动分训练集和测试集,两个成员函数是必要的,getitem 函数是在torch.utils.data.DataLoader()分batch的时候循环调用的:
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30class MyDataset(torch.utils.data.Dataset): def __init__(self, train_data_flag=0): super(MyDataset, self).__init__() file_path = '/Users/sophia/Downloads/****.mat' fh = scio.loadmat(file_path) fh = fh['Y'] fh = data_in_one(fh) # 由uint16->float64 fh_array = np.array(fh, dtype='float') fh_array_t = fh_array.T if train_data_flag == 0: self.datas = fh_array_t[:][0:90000] else: self.datas = fh_array_t[:][90001:94001] def __getitem__(self, index): data = self.datas[:][index] data = torch.from_numpy(data) return data def __len__(self): return len(self.datas) #主函数调用 train_dataset = MyDataset(train_data_flag=0) test_dataset = MyDataset(train_data_flag=1) train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=shuffle) test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=shuffle)
坑4:报错RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 ‘mat1’ in call to _th_addmm
在网上看到大佬们写的demo可以了解到,要将tenor类型的数据用tenor.float()转换为浮点型即可,如果是有已有的数据,建议在出错行网上搜索输入的tenor类变量,然后对它进行操作。
复制代码
1
2
3
4
5for batch_index, train_data in enumerate(train_loader): if torch.cuda.is_available(): train_data = train_data.cuda() train_data = train_data.float()
总结
调试真的太方便了,不知道比tensorflow方便多少倍,爱了爱了,但是要运用熟练还得好好学习一下人家的框架hhhh
参考链接
- https://blog.csdn.net/weixin_42214565/article/details/102381380
- https://blog.csdn.net/Teeyohuang/article/details/79587125
最后
以上就是大方篮球最近收集整理的关于用pytorch踩过的坑的全部内容,更多相关用pytorch踩过内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复