Pytorch grid_sample解析
grid_sample函數
這篇博客只對bilinear mode進行解釋說明,并且會對align_corners為True或False兩種情況進行分情況討論。
torch.nn.functional.grid_sample(input, grid, mode=‘bilinear’, padding_mode=‘zero’, align_corners=None)
nn.functional下的grid_sample函數會根據提供的坐標(grid)對input pixels進行采樣(sampling),這篇文章只以bilinear interpolation sampling為例。 根據官方文檔介紹,input shape必須是4D或5D的,分別用于二維和三維圖像的采樣(前兩個維度為batch size和channel)。
input的shape(4D case)是(N,C,Hin,Win)(N, C, H_{in}, W_{in})(N,C,Hin?,Win?), 這個很好理解。
gird的shape(4D case)是(N,Hout,Wout,2)(N, H_{out}, W_{out}, 2)(N,Hout?,Wout?,2), 這里的H和W是output的長和寬,有一點需要注意的是,grid_sample的output shape是(N,C,Hout,Wout)(N, C, H_{out}, W_{out})(N,C,Hout?,Wout?), 所以output的shape和grid的shape是一樣的, 而不是和input的shape一樣。grid的最后一個維度2表示的是x,y坐標, 如果是5D的情況,也就是處理三維圖像的時候,gird的最后一個維度就是3,因為需要引入z坐標。
grid表示的是的sampling pixel的坐標,這個坐標是被normalized過的,grid坐標取值范圍為[-1, 1]。 點(-1,-1)為左上角的pixel,(1,1)為右下的pixel。中間的坐標值為某個浮點數。
grid_sample函數做的就是根據grid坐標,從input的pixels里采樣。 如果此坐標下沒有對應的input pixel,就要用bilinear interpolation從周圍的pixels采樣。
下面是Piotr給出的一個例子
https://discuss.pytorch.org/t/solved-torch-grid-sample/51662/2
meshy是x坐標
meshx是y坐標
align_corners=True
當align_corners=True時,以坐標(-0.7143, -0.7143)為例,請看下圖。
因為align_corners=True,所以(-1, -1)點的值為0, (1, 1)點的值為15,可以認為grid的-1和1在是在corner pixel的中心位置。由此可以推出值為1和2的坐標為(-0.3333, 0)和(0.3333, 0)。我們要采樣的點(-0.7143, -0.7143)在0, 1, 4, 5中間,所以要從這四點進行采樣。根據坐標算出長度比例,然后用bilinear interpolation算出坐標(-0.7143, -0.7143)的值就okay了。
下圖是align_corners=True的output
align_corners=False
當align_corners=False時,以坐標(0.7143, -0.7143)為例,請看下圖。
注意:這個例子的坐標和上個例子的坐標不一樣。
因為align_corners=False, 所以(-1, -1)點的值不為0,(1, 1)點的值也不是15,grid的-1和1不在corner pixel的中心位置,而是在正方形像素的角。所以(-0.25, -0.25)的值才是0, (0.75, 0.75)的值才是15。由此可以推出值為1和2的坐標分別為(-0.25, -0.75)和(0.25, -0.75)。我們要采樣的點(0.7143, -0.7143)在2, 3, 6, 7中間,所以要從這四點進行采樣。根據坐標算出長度比例,然后用bilinear interpolation算出坐標(0.7143, -0.7143)的值就okay了。
下圖是align_corners=False的output:
總結
以上是生活随笔為你收集整理的Pytorch grid_sample解析的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: (转)TTime, TDateTime
- 下一篇: 只是为了好玩——Linux之父林纳斯自传