mirror of https://github.com/commaai/tinygrad.git
back to 1000 lines
commit cfb7a4c41a
parent a482d56229
@@ -3,7 +3,7 @@ import numpy as np
 class BatchNorm2D:
   def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
-    assert affine == True
+    assert affine == True, "BatchNorm2D is only supported with affine"
     self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

     self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
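The only change above is the assert message. A quick illustration of what the message buys when the assert fires (standard Python semantics, not run against tinygrad):

affine = False
try:
  assert affine == True, "BatchNorm2D is only supported with affine"
except AssertionError as e:
  print(e)  # BatchNorm2D is only supported with affine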
@@ -2,22 +2,15 @@ import numpy as np
 from ..tensor import Function

 class CPUBuffer(np.ndarray):
-  def log(x):
-    return np.log(x)
-  def exp(x):
-    return np.exp(x)
-  def relu(x):
-    return np.maximum(x, 0)
-  def expand(x, shp):
-    return np.broadcast_to(x, shp)
+  log = lambda x: np.log(x)
+  exp = lambda x: np.exp(x)
+  relu = lambda x: np.maximum(x, 0)
+  expand = lambda x,shp: np.broadcast_to(x, shp)
+  permute = lambda x,order: x.transpose(order)
+  type = lambda x,tt: x.astype(tt)
+  custompad = lambda x,padding: np.pad(x, padding)
   def amax(x, *args, **kwargs):
     return np.amax(x, *args, **kwargs)
-  def permute(x, order):
-    return x.transpose(order)
-  def type(x, tt):
-    return x.astype(tt)
-  def custompad(x, padding):
-    return np.pad(x, padding)
   def toCPU(x):
     return x
   @staticmethod
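The def-to-lambda swap here is behaviour preserving: a lambda assigned in a class body is still an ordinary function, so it binds to the instance exactly like a def. A minimal sketch with plain numpy (not tinygrad code, Buf is a hypothetical stand-in for CPUBuffer):

import numpy as np

class Buf(np.ndarray):
  relu = lambda x: np.maximum(x, 0)  # x receives the instance, like self

a = np.array([-1.0, 2.0]).view(Buf)
print(a.relu())  # [0. 2.]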
@@ -5,8 +5,7 @@ from ..tensor import Function

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class TorchBuffer(torch.Tensor):
-  def custompad(x, padding):
-    return torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
+  custompad = lambda x,padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist])
   @staticmethod
   def fromCPU(data):
     return TorchBuffer(torch.from_numpy(data).requires_grad_(False)).to(device)
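The list comprehension inside custompad adapts numpy-style padding to torch.nn.functional.pad, which expects a flat list ordered from the last dimension, hence the padding[::-1]. A sketch with a hypothetical NCHW padding spec (illustration only):

padding = [(0, 0), (0, 0), (1, 2), (3, 4)]  # ((before, after), ...) per dimension, numpy order
flat = [item for sublist in padding[::-1] for item in sublist]
print(flat)  # [3, 4, 1, 2, 0, 0, 0, 0] -- last dimension first, as F.pad wants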
@@ -1,5 +1,5 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
-import os, atexit, time, inspect, functools
+import os, atexit, time, inspect, functools, importlib
 from collections import defaultdict
 import numpy as np
@@ -30,12 +30,13 @@ class ProfileOp:
     return self
   def __exit__(self, *junk):
     if GRAPH:
-      saved_tensors = filter(lambda x: any([isinstance(x, v) for v in Device.buffers.values()]), self.ctx.saved_tensors)
       # connect inputs to outputs
       for x in self.x:
         for y in self.output:
           G.add_edge(id(x.data), id(y.data), label=self.name, color="blue" if self.backward else "black")
           G.nodes[id(x.data)]['label'], G.nodes[id(y.data)]['label'] = str(x.shape), str(y.shape)
+      # which saved tensors does this backward depend on?
+      saved_tensors = filter(lambda x: any([isinstance(x, v) for v in Device.buffers.values()]), self.ctx.saved_tensors)
       if self.backward:
         for x in saved_tensors:
           for y in self.output:
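Moving the filter(...) call next to the backward branch reads as a pure readability change, assuming standard Python semantics: a filter object is a lazy, single-use iterator, so building it immediately before its only consumer is equivalent. A reminder sketch (not tinygrad code):

it = filter(lambda v: v > 0, [1, -2, 3])
print(list(it))  # [1, 3]
print(list(it))  # [] -- already exhausted, which is why it is built right before use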
@@ -243,12 +244,12 @@ class Tensor:
   def sum(self, axis=None, keepdim=False):
     axis, out_shape = self._canonicalize_reduce_axis(axis)
     ret = self._sum(axis=axis)
-    return ret if keepdim else ret.reshape(shape=out_shape)
+    return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)

   def max(self, axis=None, keepdim=False):
     axis, out_shape = self._canonicalize_reduce_axis(axis)
     ret = self._max(axis=axis)
-    return ret if keepdim else ret.reshape(shape=out_shape)
+    return ret if keepdim or ret.shape == out_shape else ret.reshape(shape=out_shape)

   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
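The extra ret.shape == out_shape check skips a reshape that would be a no-op when the reduced result already has the target shape. A numpy analogue of the shape comparison (illustration only, not the Tensor API):

import numpy as np
x = np.ones((3, 4))
ret = x.sum(axis=1)            # shape (3,)
out_shape = (3,)
print(ret.shape == out_shape)  # True -> no reshape needed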
@@ -410,7 +411,6 @@ def _register_ops(namespace, device=Device.CPU):
     if name.endswith("Buffer"): Device.buffers[device] = cls
     elif name[0] != "_": register(name.lower(), cls, device=device)

-import importlib
 for d,ops in Device.imports.items():
   try:
     _register_ops(importlib.import_module('tinygrad.ops.'+ops), d)
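The local import importlib goes away because this same commit adds importlib to the top-level import line (the -1,5 +1,5 hunk above). For reference, a minimal sketch of the dynamic-import pattern used here (the module name below is a stand-in, not a tinygrad path):

import importlib
mod = importlib.import_module("math")  # stand-in for 'tinygrad.ops.' + ops
print(mod.sqrt(9.0))                   # 3.0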