refactor profiler

This commit is contained in:
George Hotz 2020-11-10 07:31:16 -08:00
parent f27628b21c
commit f513302955
1 changed file with 26 additions and 21 deletions

View File

@ -9,6 +9,8 @@ try:
except ImportError:
# no GPU support
GPU = False
# **** profiler, 10 lines too long ****
DEBUG = os.getenv("DEBUG", None) is not None
if DEBUG:
import collections, atexit
@ -18,6 +20,25 @@ if DEBUG:
for name, _ in sorted(debug_times.items(), key=lambda x: -x[1]):
print("%20s : %3d %10.2f ms" % (name, debug_counts[name], debug_times[name]))
atexit.register(print_debug_exit)
class ProfileOp:
    """Context manager that times one op and records it in the debug tables.

    On exit it adds one call and the elapsed milliseconds to the module-level
    ``debug_counts`` / ``debug_times`` tables under ``self.name``, then prints
    a one-line report that includes the ``.shape`` of every entry of ``x``.
    The timing is recorded whether or not the wrapped body raised; the
    exception itself is not suppressed.

    Args:
        name: Label for the op being timed.
        x: Iterable of objects exposing ``.shape``; only read when printing.
        backward: When true, ``"back_"`` is prepended to ``name``.
    """

    def __init__(self, name, x, backward=False):
        self.name = ("back_" if backward else "") + name
        self.x = x

    def __enter__(self):
        # perf_counter is monotonic, so a wall-clock adjustment in the middle
        # of an op cannot yield a negative or inflated duration, which
        # time.time() allows.
        self.st = time.perf_counter()
        # Hand back the profiler so `with ProfileOp(...) as p:` binds it.
        return self

    def __exit__(self, *junk):
        et = (time.perf_counter() - self.st) * 1000.
        debug_counts[self.name] += 1
        debug_times[self.name] += et
        print("%20s : %7.2f ms %s" % (self.name, et, [y.shape for y in self.x]))
else:
class ProfileOp:
    """No-op stand-in for the profiler, defined on the ``else`` branch.

    It takes the same constructor arguments as the timing variant and ignores
    them, so call sites can always write ``with ProfileOp(...):`` without
    checking whether profiling is enabled. Nothing is timed, recorded or
    printed, and exceptions raised inside the ``with`` body propagate.

    Args:
        name: Ignored.
        x: Ignored.
        backward: Ignored.
    """

    def __init__(self, name, x, backward=False):
        pass

    def __enter__(self):
        # Return the instance so `with ProfileOp(...) as p:` binds an object
        # rather than None.
        return self

    def __exit__(self, *junk):
        # Implicit None return: never swallow an exception from the body.
        pass
cl_ctx, cl_queue = None, None
def require_init_gpu():
@ -101,18 +122,10 @@ class Tensor:
assert(self.grad is not None)
if DEBUG:
st = time.time()
grads = self._ctx.backward(self._ctx, self.grad.data)
with ProfileOp(self._ctx.__class__.__name__, [self.grad], backward=True):
grads = self._ctx.backward(self._ctx, self.grad.data)
if len(self._ctx.parents) == 1:
grads = [grads]
if DEBUG:
global debug_counts, debug_times
name = "back_"+self._ctx.__class__.__name__
et = (time.time()-st)*1000.
debug_counts[name] += 1
debug_times[name] += et
print("%20s : %7.2f ms %s" % (name, et, [y.shape for y in grads]))
for t,g in zip(self._ctx.parents, grads):
if g is None:
continue
@ -195,7 +208,8 @@ class Function:
# overwrite with passed params
for k, v in kwargs.items():
setattr(ctx, k, v)
ret = Tensor(op.forward(ctx, *[t.data for t in x], **kwargs))
with ProfileOp(ctx.__class__.__name__, x):
ret = Tensor(op.forward(ctx, *[t.data for t in x], **kwargs))
ret._ctx = ctx
return ret
@ -207,16 +221,7 @@ def register(name, fxn, gpu=False):
def dispatch(*x, **kwargs):
f = (Tensor.opsgpu if x[0].gpu else Tensor.ops)[name]
f.cl_ctx, f.cl_queue = cl_ctx, cl_queue
if DEBUG:
st = time.time()
ret = f.apply(f, *x, **kwargs)
if DEBUG:
global debug_counts, debug_times
et = (time.time()-st)*1000.
debug_counts[name] += 1
debug_times[name] += et
print("%20s : %7.2f ms %s" % (name, et, [y.shape for y in x]))
return ret
return f.apply(f, *x, **kwargs)
setattr(Tensor, name, dispatch)
if name in ['add', 'sub', 'mul', 'div']:
setattr(Tensor, "__%s__" % name, dispatch)