import torch
import torchvision.transforms.functional as F
from torchvision import transforms
class SquarePad:
    """Pad an image to a square canvas, keeping the original content centred.

    The shorter side is padded on both ends so that the result is
    ``max(width, height)`` pixels on each side. When the size difference is
    odd, the extra pixel goes on the right / bottom edge.

    Intended as a step in ``transforms.Compose``. Accepts PIL images and
    tensors laid out as ``[..., H, W]``, which are the two input kinds
    ``F.pad`` supports.

    Args:
        fill: Fill value forwarded to ``F.pad`` (used with ``'constant'`` mode).
        padding_mode: Padding mode forwarded to ``F.pad``.
    """

    def __init__(self, fill=0, padding_mode="constant"):
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def _image_size(image):
        """Return ``(width, height)`` for a PIL image or a ``[..., H, W]`` tensor."""
        if isinstance(image, torch.Tensor):
            # Tensor.size is a method, not a tuple; read the trailing H, W dims.
            height, width = image.shape[-2:]
            return int(width), int(height)
        # PIL images expose ``size`` as a ``(width, height)`` tuple.
        width, height = image.size
        return width, height

    @staticmethod
    def _square_padding(width, height):
        """Return ``(left, top, right, bottom)`` padding that squares a ``width x height`` image."""
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        # The remainder goes right/bottom so odd differences still total ``side``.
        return left, top, side - width - left, side - height - top

    def __call__(self, image):
        """Return ``image`` padded to a square; the input type is preserved by ``F.pad``."""
        width, height = self._image_size(image)
        padding = self._square_padding(width, height)
        return F.pad(image, padding, fill=self.fill, padding_mode=self.padding_mode)

    def __repr__(self):
        return f"{type(self).__name__}(fill={self.fill!r}, padding_mode={self.padding_mode!r})"
# Output size handed to transforms.Resize, as (height, width); example value.
target_image_size = (224, 224)

# Preprocessing pipeline. SquarePad takes the place of a fixed-size
# transforms.Pad: it pads each image to a square first, so the subsequent
# resize to a square target does not stretch the content.
transform = transforms.Compose(
    [
        SquarePad(),
        transforms.Resize(target_image_size),
        transforms.ToTensor(),
        # Per-channel mean/std of 0.5: three values, so this expects
        # 3-channel input. Assuming ToTensor yields values in [0, 1],
        # this maps them to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)