我是靠谱客的博主 沉默凉面,这篇文章主要介绍nn.Linear和nn.BatchNorm1的维度问题,现在分享给大家,希望可以做个参考。

import torch
import torch.nn as nn
input=torch.randn([32,49,768])

l=nn.Linear(768,512)
out=l(input)
print(out.shape)
# torch.Size([32, 49, 512])

# l=nn.Linear(49,512)
# mat1 and mat2 shapes cannot be multiplied (1568x768 and 49x512)
# 说明了执行linear时,输入的channel只能位于最后一维
b=nn.BatchNorm1d(49)
out=b(out)
print(out.shape)
# torch.Size([32, 49, 512])
# b=nn.BatchNorm1d(512)
# RuntimeError: running_mean should contain 49 elements not 512
# 说明了执行linear时,输入的channel只能位于最后中间

最后

以上就是沉默凉面最近收集整理的关于nn.Linear和nn.BatchNorm1的维度问题的全部内容,更多相关nn.Linear和nn.BatchNorm1内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(116)

评论列表共有 0 条评论

立即
投稿
返回
顶部