make profiler work for backward pass too

This commit is contained in:
George Hotz 2022-01-16 09:35:10 -08:00
parent 2cae2dfa07
commit 6a5cb6842e
1 changed file with 7 additions and 10 deletions

View File

@@ -24,14 +24,11 @@ class ProfileOp:
return self
def __exit__(self, *junk):
if DEBUG:
if self.output is not None:
self.output.data.toCPU()
elif self.x[0] is not None:
self.x[0].data.toCPU()
self.output[0].data.toCPU()
et = (time.time()-self.st)*1000.
debug_counts[self.name] += 1
debug_times[self.name] += et
print(f"{self.name:>20} : {et:>7.2f} ms {str([y.shape for y in self.x]):>40} {'-> '+str(self.output.shape) if self.output is not None else ''}")
print(f"{self.name:>20} : {et:>7.2f} ms {str([y.shape for y in self.x]):>40} -> {str([y.shape for y in self.output])}")
# **** enumerate supported devices ****
@@ -132,14 +129,13 @@ class Tensor:
assert (t0.grad is not None)
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
if len(t0._ctx.parents) == 1:
grads = [grads]
po.output = grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
if g is not None and t.requires_grad:
assert g.shape == t.shape, \
f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
gt = Tensor(g, device=self.device, requires_grad=False)
t.grad = gt if t.grad is None else (t.grad + gt)
t.grad = g if t.grad is None else (t.grad + g)
# ***** tinygrad supports many devices *****
@@ -355,8 +351,9 @@ class Function:
for k, v in kwargs.items():
setattr(ctx, k, v)
with ProfileOp(ctx.__class__.__name__, x) as po:
po.output = ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
device=ctx.device, requires_grad=any([t.requires_grad for t in x]))
po.output = [ret]
if ret.requires_grad:
ret._ctx = ctx
return ret