Contents
- Understanding dim in torch.max
- PyTorch documentation
Understanding dim in torch.max
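Understanding torch.max really comes down to one thing: how dim is specified.
In fact, dim in torch is the same concept as axis in numpy. The original post linked to a detailed explanation of axis, and once axis makes sense, dim follows immediately. Passing dim=k tells torch.max to compare elements along dimension k, so that dimension is reduced away. For a 2-D tensor, dim=0 gives one maximum per column and dim=1 gives one maximum per row.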
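To make the "reduced dimension" idea concrete without any library, here is a minimal pure-Python sketch. It mimics the (values, indices) pair that torch.max(x, dim) returns, but only for 2-D nested lists. The helper name max_along_dim is hypothetical and is not a PyTorch API. Ties resolve to the first occurrence here, which torch.max itself does not guarantee.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def max_along_dim(x: Matrix, dim: int) -> Tuple[List[float], List[int]]:
    """Hypothetical helper mimicking torch.max(x, dim) for a 2-D nested list.

    Returns (values, indices). `dim` is the dimension that gets reduced away:
    dim=0 scans down each column, dim=1 scans across each row.
    """
    if dim < 0:  # accept dim=-1 / dim=-2 the way PyTorch does
        dim += 2
    if dim not in (0, 1):
        raise ValueError("this sketch only supports 2-D input (dim in {0, 1, -1, -2})")

    n_rows, n_cols = len(x), len(x[0])
    values: List[float] = []
    indices: List[int] = []
    if dim == 1:
        # Reduce dim 1: compare the entries of each row, keep one result per row.
        for row in x:
            best = max(range(n_cols), key=lambda j: row[j])
            values.append(row[best])
            indices.append(best)
    else:
        # Reduce dim 0: compare the entries of each column, keep one result per column.
        for j in range(n_cols):
            best = max(range(n_rows), key=lambda i: x[i][j])
            values.append(x[best][j])
            indices.append(best)
    return values, indices


if __name__ == "__main__":
    # Same numbers as the 4x4 example in the PyTorch documentation quoted in this post.
    a = [[-1.2360, -0.2942, -0.1222, 0.8475],
         [1.1949, -1.1127, -2.2379, -0.6702],
         [1.5717, -0.9207, 0.1297, -1.8768],
         [-0.6172, 1.0036, -0.6060, -0.2432]]
    print(max_along_dim(a, 1))  # ([0.8475, 1.1949, 1.5717, 1.0036], [3, 0, 0, 1])
    print(max_along_dim(a, 0))  # ([1.5717, 1.0036, 0.1297, 0.8475], [2, 3, 2, 0])
```

The dim=1 result matches the documented output of torch.max(a, 1): values 0.8475, 1.1949, 1.5717, 1.0036 and indices 3, 0, 0, 1. Both results have 4 entries because one dimension of the 4x4 input has been collapsed.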
PyTorch documentation
torch.max = max(...)
    max(input) -> Tensor

    Returns the maximum value of all elements in the ``input`` tensor.

    .. warning::
        This function produces deterministic (sub)gradients unlike ``max(dim=0)``

    Args:
        input (Tensor): the input tensor.

    Example::

        >>> a = torch.randn(1, 3)
        >>> a
        tensor([[ 0.6763,  0.7445, -2.2369]])
        >>> torch.max(a)
        tensor(0.7445)

    .. function:: max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)

    Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum
    value of each row of the :attr:`input` tensor in the given dimension
    :attr:`dim`. And ``indices`` is the index location of each maximum value found
    (argmax).

    .. warning::
        ``indices`` does not necessarily contain the first occurrence of each
        maximal value found, unless it is unique.
        The exact implementation details are device-specific.
        Do not expect the same result when run on CPU and GPU in general.
        For the same reason do not expect the gradients to be deterministic.

    If ``keepdim`` is ``True``, the output tensors are of the same size
    as ``input`` except in the dimension ``dim`` where they are of size 1.
    Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting
    in the output tensors having 1 fewer dimension than ``input``.

    Args:
        input (Tensor): the input tensor.
        dim (int): the dimension to reduce.
        keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``.
        out (tuple, optional): the result tuple of two output tensors (max, max_indices)

    Example::

        >>> a = torch.randn(4, 4)
        >>> a
        tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
                [ 1.1949, -1.1127, -2.2379, -0.6702],
                [ 1.5717, -0.9207,  0.1297, -1.8768],
                [-0.6172,  1.0036, -0.6060, -0.2432]])
        >>> torch.max(a, 1)
        torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))

    .. function:: max(input, other, out=None) -> Tensor

    Each element of the tensor ``input`` is compared with the corresponding
    element of the tensor ``other`` and an element-wise maximum is taken.

    The shapes of ``input`` and ``other`` don't need to match,
    but they must be :ref:`broadcastable <broadcasting-semantics>`.

    .. math::
        \text{out}_i = \max(\text{tensor}_i, \text{other}_i)

    .. note:: When the shapes do not match, the shape of the returned output tensor
              follows the :ref:`broadcasting rules <broadcasting-semantics>`.

    Args:
        input (Tensor): the input tensor.
        other (Tensor): the second input tensor
        out (Tensor, optional): the output tensor.

    Example::

        >>> a = torch.randn(4)
        >>> a
        tensor([ 0.2942, -0.7416,  0.2653, -0.1584])
        >>> b = torch.randn(4)
        >>> b
        tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
        >>> torch.max(a, b)
        tensor([ 0.8722, -0.7416,  0.2653, -0.1584])
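The three call forms in the documentation can be checked side by side with the short script below. It is a sketch that assumes a recent PyTorch install. To keep the results reproducible, it uses the fixed values from the documentation's 4x4 example in place of torch.randn. The names row_max, shifted, clipped and the other variables are only illustrative.

```python
import torch

# Fixed values copied from the documentation's 4x4 example (instead of torch.randn)
a = torch.tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
                  [ 1.1949, -1.1127, -2.2379, -0.6702],
                  [ 1.5717, -0.9207,  0.1297, -1.8768],
                  [-0.6172,  1.0036, -0.6060, -0.2432]])

# Form 1: max(input) -> 0-dim tensor holding the global maximum
overall = torch.max(a)

# Form 2: max(input, dim) -> namedtuple (values, indices); `dim` is the dimension removed
row_max = torch.max(a, dim=1)        # compare within each row -> shape (4,)
col_max = torch.max(a, dim=0)        # compare within each column -> shape (4,)
values, indices = row_max            # the namedtuple also unpacks like a plain tuple

# keepdim=True keeps the reduced dimension with size 1, so the result broadcasts back
row_max_kd = torch.max(a, dim=1, keepdim=True)   # values shape (4, 1)
shifted = a - row_max_kd.values                  # each row's maximum becomes 0

# Form 3: max(input, other) -> element-wise maximum of two broadcastable tensors
zeros = torch.zeros(4)
clipped = torch.max(a, zeros)        # zeros broadcasts over the rows; negatives become 0

print(overall.item())
print(indices.tolist())                    # [3, 0, 0, 1]
print(col_max.indices.tolist())            # [2, 3, 2, 0]
print(tuple(row_max.values.shape), tuple(row_max_kd.values.shape))  # (4,) (4, 1)
```

The keepdim case shows why the flag exists. Without it, a - row_max.values would still run, but the (4,) vector would broadcast across the columns, so column j would lose row j's maximum instead of each row losing its own.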
Finally
That wraps up 魁梧金鱼's recently collected notes on using torch.max with a specified dimension (dim).
This content was contributed by users or collected from the web for learning reference; copyright belongs to the original author.