Loading...
墨滴

2021/12/21  阅读:47  主题:默认主题

pytorch中tril函数介绍

用法介绍

pytorchtril函数主要用于返回一个矩阵主对角线以下的下三角矩阵,其它元素全部为 。当输入是一个多维张量时,返回的是同等维度的张量并且最后两个维度的下三角矩阵的。

torch.tril(input, diagonal=0, *, out=None) Tensor

  • input(tensor):表示输入的张量
  • diagonal (int, optional):表示对角线的位置

参数 主要控制矩阵主对角线元素的位置。给定一个矩阵 ,则这个矩阵的主对角线元素组成的集合为

当参数 ,且 时,则此时矩阵主对角线元素的集合为 当参数 ,且 时,则此时矩阵主对角线元素的集合为

程序代码

torch.tril函数具体的程序代码示例如下所示

>>> import torch
>>> a = torch.randn(34)
>>> import torch
>>> a = torch.randn(33)
>>> a
tensor([[ 0.4925,  1.0023-0.5190],
        [ 0.0464-1.3224-0.0238],
        [-0.1801-0.6056,  1.0795]])
>>> torch.tril(a)
tensor([[ 0.4925,  0.0000,  0.0000],
        [ 0.0464-1.3224,  0.0000],
        [-0.1801-0.6056,  1.0795]])
>>> b = torch.randn(46)
>>> b
tensor([[-0.7886-0.2559-0.9161,  0.2353,  0.4033-0.0633],
        [-1.1292-0.3209-0.3307,  2.0719,  0.9238-1.8576],
        [-1.1988-1.0355-1.2745-1.7479,  0.3736-0.7210],
        [-0.3380,  1.7570-1.6608-0.4785,  0.2950-1.2821]])
>>> torch.tril(b)
tensor([[-0.7886,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292-0.3209,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988-1.0355-1.2745,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570-1.6608-0.4785,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=1)
tensor([[-0.7886-0.2559,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292-0.3209-0.3307,  0.0000,  0.0000,  0.0000],
        [-1.1988-1.0355-1.2745-1.7479,  0.0000,  0.0000],
        [-0.3380,  1.7570-1.6608-0.4785,  0.2950,  0.0000]])
>>> torch.tril(b, diagonal=-1)
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1292,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.1988-1.0355,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.3380,  1.7570-1.6608,  0.0000,  0.0000,  0.0000]])
>>> torch.tril(b, diagonal=2)
tensor([[-0.7886-0.2559-0.9161,  0.0000,  0.0000,  0.0000],
        [-1.1292-0.3209-0.3307,  2.0719,  0.0000,  0.0000],
        [-1.1988-1.0355-1.2745-1.7479,  0.3736,  0.0000],
        [-0.3380,  1.7570-1.6608-0.4785,  0.2950-1.2821]])

2021/12/21  阅读:47  主题:默认主题

作者介绍