mirror of https://github.com/commaai/tinygrad.git
make profiler work for backward pass too
This commit is contained in:
parent
2cae2dfa07
commit
6a5cb6842e
|
@ -24,14 +24,11 @@ class ProfileOp:
|
|||
return self
|
||||
def __exit__(self, *junk):
|
||||
if DEBUG:
|
||||
if self.output is not None:
|
||||
self.output.data.toCPU()
|
||||
elif self.x[0] is not None:
|
||||
self.x[0].data.toCPU()
|
||||
self.output[0].data.toCPU()
|
||||
et = (time.time()-self.st)*1000.
|
||||
debug_counts[self.name] += 1
|
||||
debug_times[self.name] += et
|
||||
print(f"{self.name:>20} : {et:>7.2f} ms {str([y.shape for y in self.x]):>40} {'-> '+str(self.output.shape) if self.output is not None else ''}")
|
||||
print(f"{self.name:>20} : {et:>7.2f} ms {str([y.shape for y in self.x]):>40} -> {str([y.shape for y in self.output])}")
|
||||
|
||||
# **** enumerate supported devices ****
|
||||
|
||||
|
@ -132,14 +129,13 @@ class Tensor:
|
|||
assert (t0.grad is not None)
|
||||
with ProfileOp(t0._ctx.__class__.__name__, [t0.grad], backward=True) as po:
|
||||
grads = t0._ctx.backward(t0._ctx, t0.grad.data)
|
||||
if len(t0._ctx.parents) == 1:
|
||||
grads = [grads]
|
||||
po.output = grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
|
||||
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
|
||||
for t, g in zip(t0._ctx.parents, grads):
|
||||
if g is not None and t.requires_grad:
|
||||
assert g.shape == t.shape, \
|
||||
f"grad shape must match tensor shape in {self._ctx!r}, {g.shape!r} != {t.shape!r}"
|
||||
gt = Tensor(g, device=self.device, requires_grad=False)
|
||||
t.grad = gt if t.grad is None else (t.grad + gt)
|
||||
t.grad = g if t.grad is None else (t.grad + g)
|
||||
|
||||
# ***** tinygrad supports many devices *****
|
||||
|
||||
|
@ -355,8 +351,9 @@ class Function:
|
|||
for k, v in kwargs.items():
|
||||
setattr(ctx, k, v)
|
||||
with ProfileOp(ctx.__class__.__name__, x) as po:
|
||||
po.output = ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
|
||||
ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs),
|
||||
device=ctx.device, requires_grad=any([t.requires_grad for t in x]))
|
||||
po.output = [ret]
|
||||
if ret.requires_grad:
|
||||
ret._ctx = ctx
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue