import torch
import torch.nn as nn
input=torch.randn([32,49,768])
l=nn.Linear(768,512)
out=l(input)
print(out.shape)
# torch.Size([32, 49, 512])
# l=nn.Linear(49,512)
# mat1 and mat2 shapes cannot be multiplied (1568x768 and 49x512)
# 说明了执行linear时,输入的channel只能位于最后一维
b=nn.BatchNorm1d(49)
out=b(out)
print(out.shape)
# torch.Size([32, 49, 512])
# b=nn.BatchNorm1d(512)
# RuntimeError: running_mean should contain 49 elements not 512
# 说明了执行linear时,输入的channel只能位于最后中间
最后
以上就是沉默凉面最近收集整理的关于nn.Linear和nn.BatchNorm1的维度问题的全部内容,更多相关nn.Linear和nn.BatchNorm1内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复