mirror of https://github.com/commaai/tinygrad.git
Add GPU Support! (do not merge yet) (#41)
* copy tensors to and from gpu * add on GPU * adding works * we stick shapes in * works on cpu and gpu * test changes, not passing yet * something else * op tests pass * add, mean, and sum have working forward/backward * mul ops test * no gpu support, no problem * test pass, clean up later * gpu cleanup * cleanup test ops, don't let div fail * revert more * aimpler dispatcher * clean up grad * GPU and * grad is a Tensor now * gate test on GPU * cleanups * late loading gpu * GPU as input option * last cleanups
This commit is contained in:
parent
c06a4fcc80
commit
9ac1ad40d6
|
@ -3,43 +3,55 @@ import numpy as np
|
|||
import unittest
|
||||
import timeit
|
||||
import functools
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.tensor import Tensor, GPU
|
||||
|
||||
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7):
|
||||
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7, gpu=False, forward_only=False):
|
||||
ts = [torch.rand(x, requires_grad=True) for x in shps]
|
||||
tst = [Tensor(x.detach().numpy()) for x in ts]
|
||||
if gpu:
|
||||
tst = [x.cuda() for x in tst]
|
||||
|
||||
out = torch_fxn(*ts)
|
||||
ret = tinygrad_fxn(*tst)
|
||||
|
||||
# TODO: why so inaccurate?
|
||||
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=atol)
|
||||
np.testing.assert_allclose(ret.cpu().data, out.detach().numpy(), atol=atol)
|
||||
|
||||
out.mean().backward()
|
||||
ret.mean().backward()
|
||||
if not forward_only:
|
||||
out.mean().backward()
|
||||
ret.mean().backward()
|
||||
|
||||
for t, tt in zip(ts, tst):
|
||||
np.testing.assert_allclose(t.grad, tt.grad, atol=grad_atol)
|
||||
for t, tt in zip(ts, tst):
|
||||
np.testing.assert_allclose(t.grad, tt.grad.cpu().data, atol=grad_atol)
|
||||
|
||||
# speed
|
||||
torch_fp = timeit.Timer(functools.partial(torch_fxn, *ts)).timeit(5) * 1000/5
|
||||
tinygrad_fp = timeit.Timer(functools.partial(tinygrad_fxn, *tst)).timeit(5) * 1000/5
|
||||
|
||||
torch_fbp = timeit.Timer(functools.partial(lambda f,x: f(*x).mean().backward(), torch_fxn, ts)).timeit(5) * 1000/5
|
||||
tinygrad_fbp = timeit.Timer(functools.partial(lambda f,x: f(*x).mean().backward(), tinygrad_fxn, tst)).timeit(5) * 1000/5
|
||||
if not forward_only:
|
||||
torch_fbp = timeit.Timer(functools.partial(lambda f,x: f(*x).mean().backward(), torch_fxn, ts)).timeit(5) * 1000/5
|
||||
tinygrad_fbp = timeit.Timer(functools.partial(lambda f,x: f(*x).mean().backward(), tinygrad_fxn, tst)).timeit(5) * 1000/5
|
||||
else:
|
||||
torch_fbp, tinygrad_fbp = np.nan, np.nan
|
||||
|
||||
print("testing %30r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms" % (shps, torch_fp, tinygrad_fp, torch_fbp-torch_fp, tinygrad_fbp-tinygrad_fp))
|
||||
|
||||
class TestOps(unittest.TestCase):
|
||||
def test_add(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
def test_add_gpu(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add, gpu=True)
|
||||
def test_sub(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub)
|
||||
def test_mul(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul)
|
||||
@unittest.skipUnless(GPU, "Requires GPU")
|
||||
def test_mul_gpu(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, gpu=True)
|
||||
def test_div(self):
|
||||
# TODO: why does this need more tolerance?
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, atol=5e-5, grad_atol=1e-5)
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, atol=1e-3, grad_atol=1e-3)
|
||||
def test_pow(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow)
|
||||
def test_sqrt(self):
|
||||
|
|
|
@ -18,7 +18,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
out = out.logsoftmax()
|
||||
out = out.mul(m).add(m).sum()
|
||||
out.backward()
|
||||
return out.data, x.grad, W.grad
|
||||
return out.data, x.grad.data, W.grad.data
|
||||
|
||||
def test_pytorch():
|
||||
x = torch.tensor(x_init, requires_grad=True)
|
||||
|
|
|
@ -16,7 +16,7 @@ def jacobian(func, input):
|
|||
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
for i, grad in enumerate(input.grad.reshape(-1)):
|
||||
for i, grad in enumerate(input.grad.data.reshape(-1)):
|
||||
J[o,i] = grad
|
||||
return J
|
||||
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
import numpy as np
|
||||
from .tensor import Function, register, Tensor
|
||||
import pyopencl as cl
|
||||
|
||||
def buffer_new(ctx, shape):
|
||||
res_g = cl.Buffer(ctx.cl_ctx, cl.mem_flags.WRITE_ONLY, 4*np.prod(shape))
|
||||
res_g.shape = shape
|
||||
res_g.dtype = np.float32
|
||||
return res_g
|
||||
|
||||
def buffer_like(ctx, x):
|
||||
return buffer_new(ctx, x.shape)
|
||||
|
||||
class Add(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ret = buffer_like(ctx, x)
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void add(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = a_g[gid] + b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.add(ctx.cl_queue, [ret.size//4], None, x, y, ret)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, grad_output
|
||||
register('add', Add, gpu=True)
|
||||
|
||||
class Mul(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
ret = buffer_like(ctx, x)
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void mul(
|
||||
__global const float *a_g, __global const float *b_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[gid] = a_g[gid] * b_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.mul(ctx.cl_queue, [ret.size//4], None, x, y, ret)
|
||||
ctx.save_for_backward(x, y, prg)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x,y,prg = ctx.saved_tensors
|
||||
gx = buffer_like(ctx, x)
|
||||
gy = buffer_like(ctx, y)
|
||||
prg.mul(ctx.cl_queue, [gx.size//4], None, y, grad_output, gx)
|
||||
prg.mul(ctx.cl_queue, [gy.size//4], None, x, grad_output, gy)
|
||||
return gx, gy
|
||||
register('mul', Mul, gpu=True)
|
||||
|
||||
class Sum(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
ret = buffer_new(ctx, (1,))
|
||||
prg = cl.Program(ctx.cl_ctx, """
|
||||
__kernel void sum(
|
||||
__global const float *a_g, __global float *res_g)
|
||||
{
|
||||
int gid = get_global_id(0);
|
||||
res_g[0] += a_g[gid];
|
||||
}
|
||||
""").build()
|
||||
prg.sum(ctx.cl_queue, [input.size//4], None, input, ret)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
ret = Tensor(grad_output).cpu().data * np.ones(input.shape, dtype=input.dtype)
|
||||
return Tensor(ret).cuda().data
|
||||
register('sum', Sum, gpu=True)
|
||||
|
||||
class Dot(Function):
|
||||
# TODO: write me!
|
||||
@staticmethod
|
||||
def forward(ctx, x, y):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
pass
|
||||
|
||||
|
|
@ -13,7 +13,7 @@ class SGD(Optimizer):
|
|||
|
||||
def step(self):
|
||||
for t in self.params:
|
||||
t.data -= self.lr * t.grad
|
||||
t.data -= self.lr * t.grad.data
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
def __init__(self, params, lr=0.001, decay=0.9, eps=1e-8):
|
||||
|
@ -26,8 +26,8 @@ class RMSprop(Optimizer):
|
|||
|
||||
def step(self):
|
||||
for i, t in enumerate(self.params):
|
||||
self.v[i] = self.decay * self.v[i] + (1 - self.decay) * np.square(t.grad)
|
||||
t.data -= self.lr / (np.sqrt(self.v[i]) + self.eps) * t.grad
|
||||
self.v[i] = self.decay * self.v[i] + (1 - self.decay) * np.square(t.grad.data)
|
||||
t.data -= self.lr / (np.sqrt(self.v[i]) + self.eps) * t.grad.data
|
||||
|
||||
class Adam(Optimizer):
|
||||
def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
|
||||
|
@ -47,7 +47,7 @@ class Adam(Optimizer):
|
|||
np.sqrt(1 - np.power(self.b2, self.t)) /
|
||||
(1 - np.power(self.b1, self.t)))
|
||||
for i,t in enumerate(self.params):
|
||||
self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad
|
||||
self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad)
|
||||
self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * t.grad.data
|
||||
self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * np.square(t.grad.data)
|
||||
t.data -= a * self.m[i] / (np.sqrt(self.v[i]) + self.eps)
|
||||
|
||||
|
|
|
@ -1,32 +1,52 @@
|
|||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from functools import partialmethod
|
||||
from inspect import signature
|
||||
import numpy as np
|
||||
try:
|
||||
import pyopencl as cl
|
||||
GPU = True
|
||||
except ImportError:
|
||||
# no GPU support
|
||||
GPU = False
|
||||
|
||||
cl_ctx, cl_queue = None, None
|
||||
def require_init_gpu():
|
||||
global cl_ctx, cl_queue
|
||||
if cl_queue is None:
|
||||
cl_ctx = cl.create_some_context(answers=[0,2]) # change if you don't have mac
|
||||
cl_queue = cl.CommandQueue(cl_ctx)
|
||||
|
||||
# **** start with two base classes ****
|
||||
|
||||
class Tensor:
|
||||
did_float_warning = False
|
||||
|
||||
def __init__(self, data):
|
||||
def __init__(self, data, gpu=False):
|
||||
if isinstance(data, list):
|
||||
data = np.array(data, dtype=np.float32)
|
||||
elif GPU and isinstance(data, cl._cl.Buffer):
|
||||
self.gpu = True
|
||||
elif not isinstance(data, np.ndarray):
|
||||
raise TypeError("Error constructing tensor with %r" % data)
|
||||
|
||||
if data.dtype != np.float32 and not Tensor.did_float_warning:
|
||||
# warning? float64 is actually needed for numerical jacobian
|
||||
print("warning, %r isn't float32" % (data.shape,))
|
||||
Tensor.did_float_warning = True
|
||||
if isinstance(data, np.ndarray):
|
||||
if data.dtype != np.float32 and not Tensor.did_float_warning:
|
||||
# warning? float64 is actually needed for numerical jacobian
|
||||
print("warning, %r isn't float32" % (data.shape,))
|
||||
Tensor.did_float_warning = True
|
||||
self.gpu = False
|
||||
|
||||
self.data = data
|
||||
self.grad = None
|
||||
|
||||
if gpu:
|
||||
self.data = self.cuda().data
|
||||
self.gpu = True
|
||||
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx = None
|
||||
|
||||
def __repr__(self):
|
||||
return "Tensor %r with grad %r" % (self.data, self.grad)
|
||||
return "Tensor %r with grad %r" % (self.data, self.grad.data if self.grad else None)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
|
@ -36,6 +56,10 @@ class Tensor:
|
|||
def zeros(*shape):
|
||||
return Tensor(np.zeros(shape, dtype=np.float32))
|
||||
|
||||
@staticmethod
|
||||
def ones(*shape):
|
||||
return Tensor(np.ones(shape, dtype=np.float32))
|
||||
|
||||
@staticmethod
|
||||
def randn(*shape):
|
||||
return Tensor(np.random.randn(*shape).astype(np.float32))
|
||||
|
@ -43,7 +67,7 @@ class Tensor:
|
|||
@staticmethod
|
||||
def eye(dim):
|
||||
return Tensor(np.eye(dim).astype(np.float32))
|
||||
|
||||
|
||||
def backward(self, allow_fill=True):
|
||||
#print("running backward on", self)
|
||||
if self._ctx is None:
|
||||
|
@ -52,12 +76,12 @@ class Tensor:
|
|||
if self.grad is None and allow_fill:
|
||||
# fill in the first grad with one
|
||||
# this is "implicit gradient creation"
|
||||
assert self.data.size == 1
|
||||
self.grad = np.ones_like(self.data)
|
||||
assert self.data.shape == (1,)
|
||||
self.grad = Tensor(np.ones(self.data.shape, dtype=self.data.dtype), gpu=self.gpu)
|
||||
|
||||
assert(self.grad is not None)
|
||||
|
||||
grads = self._ctx.backward(self._ctx, self.grad)
|
||||
grads = self._ctx.backward(self._ctx, self.grad.data)
|
||||
if len(self._ctx.parents) == 1:
|
||||
grads = [grads]
|
||||
for t,g in zip(self._ctx.parents, grads):
|
||||
|
@ -67,21 +91,49 @@ class Tensor:
|
|||
print("grad shape must match tensor shape in %r, %r != %r" %
|
||||
(self._ctx, g.shape, t.data.shape))
|
||||
assert(False)
|
||||
t.grad = g
|
||||
t.grad = Tensor(g)
|
||||
t.backward(False)
|
||||
|
||||
# ***** tinygrad supports CPU and GPU *****
|
||||
|
||||
def cpu(self):
|
||||
if self.gpu:
|
||||
data = np.empty(self.shape, dtype=np.float32)
|
||||
cl.enqueue_copy(cl_queue, data, self.data)
|
||||
return Tensor(data)
|
||||
else:
|
||||
return self
|
||||
|
||||
def cuda(self):
|
||||
if not GPU:
|
||||
raise Exception("No GPU Support")
|
||||
if not self.gpu:
|
||||
require_init_gpu()
|
||||
assert self.data.dtype == np.float32 # only float32 on GPU
|
||||
data = cl.Buffer(cl_ctx, cl.mem_flags.READ_ONLY | cl.mem_flags.COPY_HOST_PTR, hostbuf=self.data)
|
||||
data.shape = self.shape
|
||||
data.dtype = self.data.dtype
|
||||
return Tensor(data)
|
||||
else:
|
||||
return self
|
||||
|
||||
# ***** put ops in these dicts *****
|
||||
|
||||
ops = {}
|
||||
opsgpu = {}
|
||||
|
||||
# ***** non first class ops *****
|
||||
|
||||
def mean(self):
|
||||
div = Tensor(np.array([1/self.data.size], dtype=self.data.dtype))
|
||||
div = Tensor(np.array([1/np.prod(self.shape)], dtype=self.data.dtype), gpu=self.gpu)
|
||||
return self.sum().mul(div)
|
||||
|
||||
def sqrt(self):
|
||||
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)+0.5)
|
||||
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)+0.5, gpu=self.gpu)
|
||||
return self.pow(root)
|
||||
|
||||
def div(self, y):
|
||||
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)-1)
|
||||
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)-1, gpu=self.gpu)
|
||||
return self.mul(y.pow(root))
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
|
@ -93,15 +145,8 @@ class Function:
|
|||
def save_for_backward(self, *x):
|
||||
self.saved_tensors.extend(x)
|
||||
|
||||
# note that due to how partialmethod works, self and arg are switched
|
||||
def apply(self, arg, *x, **kwargs):
|
||||
# support the args in both orders
|
||||
if type(arg) == Tensor:
|
||||
op = self
|
||||
x = [arg]+list(x)
|
||||
else:
|
||||
op = arg
|
||||
x = [self]+list(x)
|
||||
def apply(self, *x, **kwargs):
|
||||
op = self
|
||||
ctx = op(*x)
|
||||
# use default params
|
||||
params = signature(op.forward).parameters
|
||||
|
@ -115,9 +160,19 @@ class Function:
|
|||
ret._ctx = ctx
|
||||
return ret
|
||||
|
||||
def register(name, fxn):
|
||||
setattr(Tensor, name, partialmethod(fxn.apply, fxn))
|
||||
def register(name, fxn, gpu=False):
|
||||
if gpu:
|
||||
Tensor.opsgpu[name] = fxn
|
||||
else:
|
||||
Tensor.ops[name] = fxn
|
||||
def dispatch(self, *x, **kwargs):
|
||||
f = (Tensor.opsgpu if self.gpu else Tensor.ops)[name]
|
||||
f.cl_ctx, f.cl_queue = cl_ctx, cl_queue
|
||||
return f.apply(f, self, *x, **kwargs)
|
||||
setattr(Tensor, name, dispatch)
|
||||
|
||||
# this registers all the operations
|
||||
import tinygrad.ops
|
||||
if GPU:
|
||||
import tinygrad.opsgpu
|
||||
|
||||
|
|
Loading…
Reference in New Issue