In [107]: import torchvision
# sample input (10 RGB images containing just Gaussian Noise)
In [108]: batch_tensor = torch.randn(*(10, 3, 256, 256)) # (N, C, H, W)
# make grid (2 rows and 5 columns) to display our 10 images
In [109]: grid_img = torchvision.utils.make_grid(batch_tensor, nrow=5)
# check shape
In [110]: grid_img.shape
Out[110]: torch.Size([3, 518, 1292])
# reshape and plot (because MPL needs channel as the last dimension)
In [111]: plt.imshow(grid_img.permute(1, 2, 0))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Out[111]: <matplotlib.image.AxesImage at 0x7f62081ef080>