keepdim avoids reshapes

George Hotz 2022-06-05 15:56:42 -07:00
parent 60b8689ea2
commit 81c9438ea1
3 changed files with 17 additions and 17 deletions

View File

@@ -38,23 +38,31 @@ C.shape = max(A.shape, B.shape)
Movement Ops
Movement Ops (2 or 1)
===
Reshape, Transpose, Slice
Depending on your Tensor implementation, these are free.
Reshape is almost always free.
Slice can be made free.
Slice can be made free, but probably shouldn't be.
Transpose is hard to make free except in trivial cases.
Regardless, these are "reindexings" of existing arrays
Transpose and Slice are similar enough I think they can be merged.
They should use a DMA engine
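
The "reindexing" point above is easy to see in plain numpy (an illustration, not tinygrad code): reshape, transpose, and basic slicing all hand back strided views over the same buffer, so no data moves.

import numpy as np

a = np.arange(12, dtype=np.float32)
r = a.reshape(3, 4)   # reshape: same buffer, new shape/strides
t = r.T               # transpose: same buffer, strides swapped
s = r[:, 1:3]         # slice: same buffer, new offset/strides

for view in (r, t, s):
  assert np.shares_memory(view, a)       # nothing was copied
print(r.strides, t.strides, s.strides)   # (16, 4) (4, 16) (16, 4)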
Processing Ops
Processing Ops (4)
===
Matmul is 1 matmul for forward, 2 for backward.
Conv2D is very complex.
* It's actually three matmuls transposed
* cublasSgemm()
Conv2D is very complex. It seems to need three.
* cudnnConvolutionForward()
* cudnnConvolutionBackwardData()
* cudnnConvolutionBackwardFilter()
NOTE: Tensor Cores require that the tensors be in the NHWC data layout
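
The matmul claim above checks out with a small numpy sketch (shapes and names here are illustrative, not from the repo): the forward pass is one matmul, and the backward pass of C = A @ B is one matmul per input gradient.

import numpy as np

A = np.random.randn(3, 4).astype(np.float32)
B = np.random.randn(4, 5).astype(np.float32)

C = A @ B              # forward: 1 matmul
dC = np.ones_like(C)   # upstream gradient, e.g. from loss = C.sum()

dA = dC @ B.T          # backward matmul #1, shape matches A
dB = A.T @ dC          # backward matmul #2, shape matches B

# sanity check for loss = (A @ B).sum()
assert np.allclose(dA, np.tile(B.sum(axis=1), (3, 1)))
assert np.allclose(dB, np.tile(A.sum(axis=0)[:, None], (1, 5)))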

View File

@@ -1,5 +1,4 @@
# llops don't know about derivatives
import functools
import numpy as np
import pyopencl as cl
@@ -45,11 +44,6 @@ def buffer_new(shape, zero=False):
def buffer_np(x):
  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=x)
def clbuffer(hostbuf, shape):
  return cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE | (cl.mem_flags.COPY_HOST_PTR if hostbuf is not None else 0),
                   4*np.prod(shape),
                   hostbuf=hostbuf.astype(np.float32).ravel() if hostbuf is not None else None)
@functools.lru_cache
def clbuild(name, prg):
  clprg = cl.Program(cl_ctx, prg).build().__getattr__(name)
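
A standalone sketch (stand-in names, no GPU required) of what the functools.lru_cache on clbuild buys: each distinct (name, prg) pair is built once, and later calls reuse the compiled kernel object.

import functools

compile_count = 0

@functools.lru_cache
def fake_clbuild(name, prg):
  # stand-in for cl.Program(cl_ctx, prg).build().__getattr__(name)
  global compile_count
  compile_count += 1
  return (name, compile_count)

src = "__kernel void add(__global float *a) { a[get_global_id(0)] += 1.0f; }"
k1 = fake_clbuild("add", src)
k2 = fake_clbuild("add", src)   # cache hit: no second build
assert k1 is k2 and compile_count == 1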

View File

@@ -287,16 +287,14 @@ class Tensor:
    return self.relu() - (-neg_slope*self).relu()
  def softmax(self):
    ns = list(self.shape)[:-1]+[1]
    m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
    m = self.max(axis=len(self.shape)-1, keepdim=True)
    e = (self - m).exp()
    ss = e.sum(axis=len(self.shape)-1).reshape(shape=ns)
    ss = e.sum(axis=len(self.shape)-1, keepdim=True)
    return e.div(ss)
  def logsoftmax(self):
    ns = list(self.shape)[:-1]+[1]
    m = self.max(axis=len(self.shape)-1).reshape(shape=ns)
    ss = m + (self-m).exp().sum(axis=len(self.shape)-1).reshape(shape=ns).log()
    m = self.max(axis=len(self.shape)-1, keepdim=True)
    ss = m + (self-m).exp().sum(axis=len(self.shape)-1, keepdim=True).log()
    return self - ss
  def dropout(self, p=0.5):
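
What the tensor.py change buys, shown with numpy (numpy spells the flag keepdims; tinygrad's Tensor uses keepdim): a reduction with the flag set keeps a size-1 axis in place, so the result broadcasts against the input and the reshape(shape=ns) step disappears.

import numpy as np

x = np.random.randn(2, 3, 5).astype(np.float32)

# old pattern from the diff: reduce, then reshape the lost axis back to size 1
m_reshaped = x.max(axis=-1).reshape(list(x.shape)[:-1] + [1])

# new pattern: the reduction keeps the axis itself
m_keepdim = x.max(axis=-1, keepdims=True)

assert m_keepdim.shape == (2, 3, 1)
assert np.array_equal(m_reshaped, m_keepdim)

# either way, softmax's (x - m) broadcasts over the last axis
e = np.exp(x - m_keepdim)
softmax = e / e.sum(axis=-1, keepdims=True)
assert np.allclose(softmax.sum(axis=-1), 1.0)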