Skip to content

SpatialMaxPooling supports padding and ceil mode #309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions SpatialMaxPooling.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
local SpatialMaxPooling, parent = torch.class('nn.SpatialMaxPooling', 'nn.Module')

function SpatialMaxPooling:__init(kW, kH, dW, dH)
function SpatialMaxPooling:__init(kW, kH, dW, dH, padW, padH)
parent.__init(self)

dW = dW or kW
Expand All @@ -11,10 +11,28 @@ function SpatialMaxPooling:__init(kW, kH, dW, dH)
self.dW = dW
self.dH = dH

self.padW = padW or 0
self.padH = padH or 0

self.ceil_mode = false
self.indices = torch.Tensor()
end

-- Enables ceil mode by setting the `ceil_mode` flag to true.
-- Returns the module itself, so the call can be chained onto the constructor.
function SpatialMaxPooling:ceil()
self.ceil_mode = true
return self
end

-- Disables ceil mode by setting the `ceil_mode` flag to false (the value
-- the constructor also assigns).
-- Returns the module itself, so the call can be chained onto the constructor.
function SpatialMaxPooling:floor()
self.ceil_mode = false
return self
end

-- Forward pass: delegates to the type-specific backend routine found under
-- `input.nn` and returns `self.output`.
function SpatialMaxPooling:updateOutput(input)
   -- Backward compatibility: an instance that lacks the padding / ceil-mode
   -- fields gets the defaults (no padding, floor rounding) before the
   -- backend is invoked.
   if not self.ceil_mode then
      self.ceil_mode = false
   end
   self.padW, self.padH = self.padW or 0, self.padH or 0

   input.nn.SpatialMaxPooling_updateOutput(self, input)
   return self.output
end
Expand All @@ -34,6 +52,12 @@ function SpatialMaxPooling:empty()
end

function SpatialMaxPooling:__tostring__()
return string.format('%s(%d,%d,%d,%d)', torch.type(self),
self.kW, self.kH, self.dW, self.dH)
local s = string.format('%s(%d,%d,%d,%d', torch.type(self),
self.kW, self.kH, self.dW, self.dH)
if (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then
s = s .. ',' .. self.padW .. ','.. self.padH
end
s = s .. ')'

return s
end
13 changes: 12 additions & 1 deletion doc/convolution.md
Original file line number Diff line number Diff line change
Expand Up @@ -361,13 +361,24 @@ Computes the `p` norm in a convolutional manner on a set of 2D input planes.
### SpatialMaxPooling ###

```lua
module = nn.SpatialMaxPooling(kW, kH [, dW, dH])
module = nn.SpatialMaxPooling(kW, kH [, dW, dH, padW, padH])
```

Applies 2D max-pooling operation in `kWxkH` regions, moving the pooling
window in steps of `dWxdH`. The number of output features is equal to the
number of input planes.

If the input image is a 3D tensor `nInputPlane x height x width`, the output
image size will be `nOutputPlane x oheight x owidth` where

```lua
owidth = op((width + 2*padW - kW) / dW + 1)
oheight = op((height + 2*padH - kH) / dH + 1)
```

`op` is the rounding operator; by default it is `floor`. It can be changed
by calling the `:ceil()` or `:floor()` methods.

<a name="nn.SpatialAveragePooling"/>
### SpatialAveragePooling ###

Expand Down
93 changes: 60 additions & 33 deletions generic/SpatialMaxPooling.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,59 @@
#else

static void nn_(SpatialMaxPooling_updateOutput_frame)(real *input_p, real *output_p,
real *indx_p, real *indy_p,
real *ind_p,
long nslices,
long iwidth, long iheight,
long owidth, long oheight,
int kW, int kH, int dW, int dH)
int kW, int kH, int dW, int dH,
int padW, int padH)
{
long k;
#pragma omp parallel for private(k)
for (k = 0; k < nslices; k++)
{
/* loop over output */
long i, j;
real *ip = input_p + k*iwidth*iheight;
for(i = 0; i < oheight; i++)
{
for(j = 0; j < owidth; j++)
{
long hstart = i * dH - padH;
long wstart = j * dW - padW;
long hend = fminf(hstart + kH, iheight);
long wend = fminf(wstart + kW, iwidth);
hstart = fmaxf(hstart, 0);
wstart = fmaxf(wstart, 0);

/* local pointers */
real *ip = input_p + k*iwidth*iheight + i*iwidth*dH + j*dW;
real *op = output_p + k*owidth*oheight + i*owidth + j;
real *indyp = indy_p + k*owidth*oheight + i*owidth + j;
real *indxp = indx_p + k*owidth*oheight + i*owidth + j;
real *indp = ind_p + k*owidth*oheight + i*owidth + j;

/* compute local max: */
long maxindex = -1;
real maxval = -THInf;
long tcntr = 0;
int x,y;
for(y = 0; y < kH; y++)
long x,y;
for(y = hstart; y < hend; y++)
{
for(x = 0; x < kW; x++)
for(x = wstart; x < wend; x++)
{
real val = *(ip + y*iwidth + x);
tcntr = y*iwidth + x;
real val = *(ip + tcntr);
if (val > maxval)
{
maxval = val;
maxindex = tcntr;
}
tcntr++;
}
}

/* set output to local max */
*op = maxval;

/* store location of max (x,y) */
*indyp = (int)(maxindex / kW)+1;
*indxp = (maxindex % kW) +1;
/* store location of max */
*indp = maxindex + 1;
}
}
}
Expand All @@ -62,6 +68,9 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
int kH = luaT_getfieldcheckint(L, 1, "kH");
int dW = luaT_getfieldcheckint(L, 1, "dW");
int dH = luaT_getfieldcheckint(L, 1, "dH");
int padW = luaT_getfieldcheckint(L, 1, "padW");
int padH = luaT_getfieldcheckint(L, 1, "padH");
int ceil_mode = luaT_getfieldcheckboolean(L,1,"ceil_mode");
THTensor *indices = luaT_getfieldcheckudata(L, 1, "indices", torch_Tensor);
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
int dimw = 2;
Expand All @@ -85,14 +94,33 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
dimw++;
dimh++;
}
luaL_argcheck(L, input->size[dimw] >= kW && input->size[dimh] >= kH, 2, "input image smaller than kernel size");
luaL_argcheck(L, input->size[dimw] >= kW - padW && input->size[dimh] >= kH - padH, 2, "input image smaller than kernel size");

luaL_argcheck(L, kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");

/* sizes */
nslices = input->size[dimh-1];
iheight = input->size[dimh];
iwidth = input->size[dimw];
oheight = (iheight - kH) / dH + 1;
owidth = (iwidth - kW) / dW + 1;
if (ceil_mode)
{
oheight = (long)(ceil((float)(iheight - kH + 2*padH) / dH)) + 1;
owidth = (long)(ceil((float)(iwidth - kW + 2*padW) / dW)) + 1;
}
else
{
oheight = (long)(floor((float)(iheight - kH + 2*padH) / dH)) + 1;
owidth = (long)(floor((float)(iwidth - kW + 2*padW) / dW)) + 1;
}

if (padW || padH)
{
// ensure that the last pooling starts inside the image
if ((oheight - 1)*dH >= iheight + padH)
--oheight;
if ((owidth - 1)*dW >= iwidth + padW)
--owidth;
}

/* get contiguous input */
input = THTensor_(newContiguous)(input);
Expand All @@ -101,27 +129,28 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
if (input->nDimension == 3)
{
THTensor_(resize3d)(output, nslices, oheight, owidth);
/* indices will contain i,j locations for each output point */
THTensor_(resize4d)(indices, 2, nslices, oheight, owidth);
/* indices will contain the locations for each output point */
THTensor_(resize3d)(indices, nslices, oheight, owidth);

input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
indices_data = THTensor_(data)(indices);

nn_(SpatialMaxPooling_updateOutput_frame)(input_data, output_data,
indices_data+nslices*owidth*oheight, indices_data,
indices_data,
nslices,
iwidth, iheight,
owidth, oheight,
kW, kH, dW, dH);
kW, kH, dW, dH,
padW, padH);
}
else
{
long p;

THTensor_(resize4d)(output, nbatch, nslices, oheight, owidth);
/* indices will contain i,j locations for each output point */
THTensor_(resize5d)(indices, 2, nbatch, nslices, oheight, owidth);
/* indices will contain the locations for each output point */
THTensor_(resize4d)(indices, nbatch, nslices, oheight, owidth);

input_data = THTensor_(data)(input);
output_data = THTensor_(data)(output);
Expand All @@ -131,11 +160,12 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
for (p = 0; p < nbatch; p++)
{
nn_(SpatialMaxPooling_updateOutput_frame)(input_data+p*nslices*iwidth*iheight, output_data+p*nslices*owidth*oheight,
indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight,
indices_data+p*nslices*owidth*oheight,
nslices,
iwidth, iheight,
owidth, oheight,
kW, kH, dW, dH);
kW, kH, dW, dH,
padW, padH);
}
}

Expand All @@ -145,7 +175,7 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
}

static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real *gradOutput_p,
real *indx_p, real *indy_p,
real *ind_p,
long nslices,
long iwidth, long iheight,
long owidth, long oheight,
Expand All @@ -157,8 +187,7 @@ static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real
{
real *gradInput_p_k = gradInput_p + k*iwidth*iheight;
real *gradOutput_p_k = gradOutput_p + k*owidth*oheight;
real *indx_p_k = indx_p + k*owidth*oheight;
real *indy_p_k = indy_p + k*owidth*oheight;
real *ind_p_k = ind_p + k*owidth*oheight;

/* calculate max points */
long i, j;
Expand All @@ -167,11 +196,9 @@ static void nn_(SpatialMaxPooling_updateGradInput_frame)(real *gradInput_p, real
for(j = 0; j < owidth; j++)
{
/* retrieve position of max */
long maxi = indy_p_k[i*owidth + j] - 1 + i*dH;
long maxj = indx_p_k[i*owidth + j] - 1 + j*dW;

long maxp = ind_p_k[i*owidth + j] - 1;
/* update gradient */
gradInput_p_k[maxi*iwidth + maxj] += gradOutput_p_k[i*owidth + j];
gradInput_p_k[maxp] += gradOutput_p_k[i*owidth + j];
}
}
}
Expand Down Expand Up @@ -226,7 +253,7 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L)
if (input->nDimension == 3)
{
nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data, gradOutput_data,
indices_data+nslices*owidth*oheight, indices_data,
indices_data,
nslices,
iwidth, iheight,
owidth, oheight,
Expand All @@ -239,7 +266,7 @@ static int nn_(SpatialMaxPooling_updateGradInput)(lua_State *L)
for (p = 0; p < nbatch; p++)
{
nn_(SpatialMaxPooling_updateGradInput_frame)(gradInput_data+p*nslices*iwidth*iheight, gradOutput_data+p*nslices*owidth*oheight,
indices_data+(p+nbatch)*nslices*owidth*oheight, indices_data+p*nslices*owidth*oheight,
indices_data+p*nslices*owidth*oheight,
nslices,
iwidth, iheight,
owidth, oheight,
Expand Down
60 changes: 33 additions & 27 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1926,38 +1926,44 @@ function nntest.SpatialSubSampling()
end

function nntest.SpatialMaxPooling()
local from = math.random(1,5)
local ki = math.random(1,4)
local kj = math.random(1,4)
local si = math.random(1,3)
local sj = math.random(1,3)
local outi = math.random(4,5)
local outj = math.random(4,5)
local ini = (outi-1)*si+ki
local inj = (outj-1)*sj+kj

local module = nn.SpatialMaxPooling(ki,kj,si,sj)
local input = torch.rand(from,ini,inj)
for _,ceil_mode in pairs({true,false}) do
local from = math.random(1,5)
local ki = math.random(1,4)
local kj = math.random(1,4)
local si = math.random(1,3)
local sj = math.random(1,3)
local outi = math.random(4,5)
local outj = math.random(4,5)
local padW = math.min(math.random(0,1),math.floor(ki/2))
local padH = math.min(math.random(0,1),math.floor(kj/2))
local ini = (outi-1)*si+ki-2*padW
local inj = (outj-1)*sj+kj-2*padH

local ceil_string = ceil_mode and 'ceil' or 'floor'
local module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH)
if ceil_mode then module:ceil() else module:floor() end
local input = torch.rand(from,inj,ini)

local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')

local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state ')

-- batch
local nbatch = math.random(2,5)
input = torch.rand(nbatch,from,ini,inj)
module = nn.SpatialMaxPooling(ki,kj,si,sj)
local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')

local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state (Batch) ')
-- batch
local nbatch = math.random(2,5)
input = torch.rand(nbatch,from,inj,ini)
module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH)
if ceil_mode then module:ceil() else module:floor() end

local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state (Batch)')

local ferr, berr = jac.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
end
end

function nntest.SpatialAveragePooling()
Expand Down