torch

torch.nn.functional.pad

Padding starts from the last dimension of input and works forward:

the last \(\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor\) dimensions of input are padded.

  • If only the last dimension of input is padded, pad has the form (padding_left, padding_right).

  • If the last 2 dimensions of input are padded, pad has the form (padding_left, padding_right, padding_top, padding_bottom).

  • If the last 3 dimensions of input are padded, pad has the form (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back).

import torch
# used later in the input_ids padding example
padding_len = 94
pad_token_id = -100
# 4-D input tensor with shape (3, 3, 4, 2)
t4d = torch.empty(3, 3, 4, 2)
# pad only the last dimension:
# 1 element on the left and 1 on the right
p1d = (1, 1) # (padding_left, padding_right)
out = torch.nn.functional.pad(t4d, pad=p1d, mode="constant", value=0)
out.size()  # (3, 3, 4, 2) -> (3, 3, 4, 4)
torch.Size([3, 3, 4, 4])
# pad the last 2 dimensions:
# last dim: 1 on the left and 1 on the right;
# 2nd-to-last dim: 2 on the top and 2 on the bottom
p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd-to-last by (2, 2)
out = torch.nn.functional.pad(t4d, p2d, "constant", 0)
out.size()  # (3, 3, 4, 2) -> (3, 3, 8, 4)
torch.Size([3, 3, 8, 4])
# pad the last 3 dimensions:
# last dim: 0 on the left, 1 on the right;
# 2nd-to-last dim: 2 on the top, 1 on the bottom;
# 3rd-to-last dim: 3 on the front, 3 on the back
p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
out = torch.nn.functional.pad(t4d, p3d, "constant", 0)
out.size() # (3, 3, 4, 2) -> (3, 9, 7, 3)
torch.Size([3, 9, 7, 3])
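
The shape arithmetic above can be checked with a small helper. The sketch below is illustrative only (padded_shape is a hypothetical name, not part of the PyTorch API): it applies the \(\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor\) rule by adding each (left, right) pair to the corresponding trailing dimension, last dimension first.

import torch
import torch.nn.functional as F

def padded_shape(shape, pad):
    # hypothetical helper: predict the output shape of F.pad with constant padding
    shape = list(shape)
    # pad pairs apply to the trailing dimensions, last dimension first
    for i, (left, right) in enumerate(zip(pad[0::2], pad[1::2])):
        shape[-(i + 1)] += left + right
    return torch.Size(shape)

t4d = torch.empty(3, 3, 4, 2)
for pad in [(1, 1), (1, 1, 2, 2), (0, 1, 2, 1, 3, 3)]:
    assert F.pad(t4d, pad).size() == padded_shape(t4d.size(), pad)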

# dummy batch of token ids: batch size 1, sequence length 4002 (values are uninitialized)
input_ids = torch.empty((1, 4002))
print(input_ids.shape)
input_ids
torch.Size([1, 4002])
tensor([[1.4822e+11, 4.8617e+30, 1.9069e-19,  ..., 1.8037e+28, 1.9368e+31,
         3.1387e-11]])
# right-pad the last dimension by padding_len positions, filling with pad_token_id
input_ids = torch.nn.functional.pad(input_ids, (0, padding_len), value=pad_token_id)
print(input_ids.shape)
input_ids
torch.Size([1, 4096])
tensor([[ 1.4822e+11,  4.8617e+30,  1.9069e-19,  ..., -1.0000e+02,
         -1.0000e+02, -1.0000e+02]])
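
In practice the pad amount is usually computed from a target length rather than hard-coded. A minimal sketch, assuming a hypothetical helper pad_to_length and a target length of 4096:

import torch
import torch.nn.functional as F

def pad_to_length(input_ids, target_len, pad_token_id=-100):
    # hypothetical helper: right-pad the last dimension up to target_len
    padding_len = target_len - input_ids.shape[-1]
    if padding_len <= 0:
        return input_ids
    return F.pad(input_ids, (0, padding_len), mode="constant", value=pad_token_id)

input_ids = torch.zeros((1, 4002), dtype=torch.long)
print(pad_to_length(input_ids, 4096).shape)  # torch.Size([1, 4096])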

paddle

paddle.nn.functional.pad

If mode is 'constant' and the length of pad is twice the number of dimensions of x, then x is padded dimension by dimension from the first dimension to the last, according to pad and value;

otherwise:

  • when the input is 3-D, pad has the form [pad_left, pad_right];

  • when the input is 4-D, pad has the form [pad_left, pad_right, pad_top, pad_bottom];

  • when the input is 5-D, pad has the form [pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back].

⚠️ Note the difference from torch: with a full-length pad, paddle pads from the first dimension toward the last, so you usually need leading zeros, e.g. pad=(0, 0, …), to leave the front dimensions unpadded.

import paddle
import numpy as np
# 3-D example tensor with shape (1, 1, 3)
x_shape = (1, 1, 3)
x3d = paddle.arange(np.prod(x_shape), dtype="float32").reshape(x_shape) + 1
x3d
Tensor(shape=[1, 1, 3], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[[1., 2., 3.]]])
# full-length pad (len(pad) == 2 * ndim): dims 0 and 1 padded by (0, 0), last dim by (2, 3)
out_paddle = paddle.nn.functional.pad(x3d, [0, 0, 0, 0, 2, 3], value=1, mode='constant')
out_paddle
Tensor(shape=[1, 1, 8], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[[1., 1., 1., 2., 3., 1., 1., 1.]]])
# a short pad on a 3-D tensor needs a matching data_format; the default 'NCHW' expects 4-D input and raises
out_paddle = paddle.nn.functional.pad(x3d, [2, 3], value=1, mode='constant')
out_paddle
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_22733/2557938510.py in <module>
----> 1 out_paddle = paddle.nn.functional.pad(x3d, [2, 3], value=1, mode='constant')
      2 out_paddle

/opt/conda/envs/blog/lib/python3.8/site-packages/paddle/nn/functional/common.py in pad(x, pad, mode, value, data_format, name)
   1288         5: ["NCDHW", "NDHWC"],
   1289     }
-> 1290     assert data_format in supported_format_map[x_dim], \
   1291     "input tensor dimension is {}, it's data format should be in {} but got {}".format(
   1292         x_dim, supported_format_map[x_dim], data_format)

AssertionError: input tensor dimension is 3, it's data format should be in ['NCL', 'NLC'] but got NCHW
# with data_format='NCL', the 2-element pad applies to the last (L) dimension
out_paddle = paddle.nn.functional.pad(x3d, [2, 3], value=1, mode='constant', data_format='NCL')
out_paddle
Tensor(shape=[1, 1, 8], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[[1., 1., 1., 2., 3., 1., 1., 1.]]])
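
The two calls above are equivalent ways of padding only the last dimension of a 3-D tensor. A quick sanity check (a sketch that rebuilds the same x3d as above):

import numpy as np
import paddle

x3d = paddle.arange(np.prod((1, 1, 3)), dtype="float32").reshape((1, 1, 3)) + 1
# full-length pad vs. short pad with an explicit data_format: both pad only the last dim
full_pad = paddle.nn.functional.pad(x3d, [0, 0, 0, 0, 2, 3], value=1, mode='constant')
short_pad = paddle.nn.functional.pad(x3d, [2, 3], value=1, mode='constant', data_format='NCL')
assert bool((full_pad == short_pad).all())  # both have shape [1, 1, 8]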

# 2-D example matching the torch input_ids case above
x_shape = (1, 4002)
x2d = paddle.arange(np.prod(x_shape), dtype="float32").reshape(x_shape) + 1
x2d
Tensor(shape=[1, 4002], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[1.   , 2.   , 3.   , ..., 4000., 4001., 4002.]])
# full-length pad for the 2-D input: dim 0 unchanged, last dim right-padded by padding_len with pad_token_id
out_paddle = paddle.nn.functional.pad(x2d, [0, 0, 0, padding_len], value=pad_token_id, mode='constant')
out_paddle
Tensor(shape=[1, 4096], dtype=float32, place=CPUPlace, stop_gradient=True,
       [[ 1.  ,  2.  ,  3.  , ..., -100., -100., -100.]])
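
As a cross-check, padding only the trailing dimension gives identical results in both frameworks. A minimal sketch, assuming torch and paddle are installed side by side:

import numpy as np
import torch
import paddle

data = np.arange(1, 4003, dtype="float32").reshape(1, 4002)
padding_len, pad_token_id = 94, -100

# torch reads pad pairs from the last dimension forward
out_torch = torch.nn.functional.pad(torch.from_numpy(data), (0, padding_len), value=pad_token_id)
# paddle reads a full-length pad from the first dimension onward
out_paddle = paddle.nn.functional.pad(paddle.to_tensor(data), [0, 0, 0, padding_len],
                                      value=pad_token_id, mode='constant')

assert out_torch.shape == torch.Size([1, 4096])
assert np.allclose(out_torch.numpy(), out_paddle.numpy())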
