simple pool. swimming is very easy now

This commit is contained in:
George Hotz 2020-12-29 13:48:50 -05:00
parent 8f9232d59b
commit f18801c7db
1 changed files with 6 additions and 9 deletions

View File

@ -246,19 +246,16 @@ class Tensor:
def abs(self):
return self.relu() + (-1.0*self).relu()
def _pool2d(self, py, px):
xup = self.unpad2d(padding=(0, self.shape[3]%px, 0, self.shape[2]%py))
return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
def avg_pool2d(self, kernel_size=(2,2)):
chan = self.shape[1]
ww = np.zeros((chan, 1, kernel_size[0], kernel_size[1]), dtype=np.float32)
ww[range(chan), 0, :, :] = 1/(kernel_size[0]*kernel_size[1])
return self.conv2d(Tensor(ww, device=self.device, requires_grad=False), stride=kernel_size, groups=chan)
return self._pool2d(*kernel_size).mean(axis=(3,5))
def max_pool2d(self, kernel_size=(2,2)):
py, px = kernel_size
xup = self.unpad2d(padding=(0, self.shape[3]%px, 0, self.shape[2]%py))
xup = xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
# TODO: support tuples in max
xup = xup.max(axis=5).max(axis=3)
return xup
return self._pool2d(*kernel_size).max(axis=5).max(axis=3)
# An instantiation of the Function is the Context
class Function: