DEBUG flag

This commit is contained in:
George Hotz 2020-11-10 00:36:59 -08:00
parent 6a56d5d030
commit f7d10d5639
1 changed files with 11 additions and 3 deletions

View File

@ -1,12 +1,14 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from inspect import signature
import numpy as np
import time
try:
import pyopencl as cl
GPU = True
except ImportError:
# no GPU support
GPU = False
DEBUG = False
cl_ctx, cl_queue = None, None
def require_init_gpu():
@ -184,10 +186,16 @@ def register(name, fxn, gpu=False):
Tensor.opsgpu[name] = fxn
else:
Tensor.ops[name] = fxn
def dispatch(self, *x, **kwargs):
f = (Tensor.opsgpu if self.gpu else Tensor.ops)[name]
def dispatch(*x, **kwargs):
f = (Tensor.opsgpu if x[0].gpu else Tensor.ops)[name]
f.cl_ctx, f.cl_queue = cl_ctx, cl_queue
return f.apply(f, self, *x, **kwargs)
if DEBUG:
st = time.time()
ret = f.apply(f, *x, **kwargs)
if DEBUG:
et = time.time()-st
print("%20s : %5.2f ms %s" % (name, et*1000.0, [y.shape for y in x]))
return ret
setattr(Tensor, name, dispatch)
if name in ['add', 'sub', 'mul', 'div']:
setattr(Tensor, "__%s__" % name, dispatch)