mirror of https://github.com/commaai/tinygrad.git
refactor ops
This commit is contained in:
parent
4320c45c4b
commit
7ae14179d3
|
@ -4,7 +4,7 @@ import numpy as np
|
|||
import unittest
|
||||
import timeit
|
||||
import functools
|
||||
from tinygrad.tensor import Tensor, DEFAULT_DEVICE, Device
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
|
||||
def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_atol=1e-6, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=20):
|
||||
torch.manual_seed(0)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import numpy as np
|
||||
from .tensor import Function
|
||||
from ..tensor import Function
|
||||
|
||||
class CPUBuffer(np.ndarray):
|
||||
def log(x):
|
|
@ -1,7 +1,7 @@
|
|||
import functools
|
||||
import pyopencl as cl
|
||||
import numpy as np
|
||||
from .tensor import Function
|
||||
from ..tensor import Function
|
||||
|
||||
cl_ctx, cl_queue = None, None
|
||||
def require_init_gpu():
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from .tensor import Function
|
||||
from ..tensor import Function
|
||||
|
||||
class TorchBuffer(torch.Tensor):
|
||||
@staticmethod
|
||||
|
@ -8,14 +8,16 @@ class TorchBuffer(torch.Tensor):
|
|||
return TorchBuffer(torch.from_numpy(data).requires_grad_(False))
|
||||
def toCPU(x):
|
||||
return x.numpy()
|
||||
def getdtype(self):
|
||||
return np.float32
|
||||
|
||||
# ************* unary+binary+reduce ops *************
|
||||
|
||||
from tinygrad.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
|
||||
from tinygrad.ops.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
|
||||
|
||||
# ************* movement ops *************
|
||||
|
||||
from tinygrad.ops_cpu import Reshape, Transpose
|
||||
from tinygrad.ops.ops_cpu import Reshape, Transpose
|
||||
|
||||
def inner_slice(x, arg):
|
||||
padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
|
||||
|
@ -35,7 +37,7 @@ class Slice(Function):
|
|||
|
||||
# ************* processing ops *************
|
||||
|
||||
from tinygrad.ops_cpu import Matmul
|
||||
from tinygrad.ops.ops_cpu import Matmul
|
||||
|
||||
class Conv2D(Function):
|
||||
def forward(ctx, x, w, stride=1, groups=1):
|
|
@ -34,16 +34,25 @@ class ProfileOp:
|
|||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
# TODO: make this class creation generic
|
||||
class Device: CPU, GPU, TORCH, buffers, imports = 0, 1, 2, {}, {0:"ops_cpu", 1:"ops_gpu", 2:"ops_torch"}
|
||||
DEFAULT_DEVICE = (Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU) if os.environ.get("TORCH", 0) != "1" else Device.TORCH
|
||||
class Device:
|
||||
buffers = {}
|
||||
imports = {}
|
||||
_ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ops")))
|
||||
DEFAULT = None
|
||||
for i,o in enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]):
|
||||
name = o[len("ops_"):].upper()
|
||||
if os.environ.get(name, 0) == "1":
|
||||
DEFAULT = i
|
||||
vars()[name] = i
|
||||
imports[i] = o
|
||||
DEFAULT = CPU if DEFAULT is None else DEFAULT
|
||||
|
||||
class Tensor:
|
||||
did_float_warning = False
|
||||
training = True
|
||||
ops = defaultdict(dict)
|
||||
|
||||
def __init__(self, data, device=DEFAULT_DEVICE, requires_grad=True):
|
||||
def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
|
||||
self.device, self.data = device, self._move_data(data, device)
|
||||
|
||||
self.grad, self.requires_grad = None, requires_grad
|
||||
|
@ -63,8 +72,8 @@ class Tensor:
|
|||
|
||||
@property
|
||||
def dtype(self):
|
||||
if self.device == Device.TORCH:
|
||||
return np.float32
|
||||
if getattr(self.data, 'getdtype', None):
|
||||
return self.data.getdtype()
|
||||
else:
|
||||
return self.data.dtype
|
||||
|
||||
|
@ -308,6 +317,6 @@ def _register_ops(namespace, device=Device.CPU):
|
|||
import importlib
|
||||
for d,ops in Device.imports.items():
|
||||
try:
|
||||
_register_ops(importlib.import_module('tinygrad.'+ops), d)
|
||||
_register_ops(importlib.import_module('tinygrad.ops.'+ops), d)
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue