refactor ops

George Hotz 2021-11-27 11:12:23 -05:00
parent 4320c45c4b
commit 7ae14179d3
5 changed files with 25 additions and 14 deletions

View File

@@ -4,7 +4,7 @@ import numpy as np
 import unittest
 import timeit
 import functools
-from tinygrad.tensor import Tensor, DEFAULT_DEVICE, Device
+from tinygrad.tensor import Tensor, Device
 
 def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_atol=1e-6, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=20):
   torch.manual_seed(0)

View File

@@ -1,5 +1,5 @@
 import numpy as np
-from .tensor import Function
+from ..tensor import Function
 
 class CPUBuffer(np.ndarray):
   def log(x):
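
The only functional change in the CPU backend is the relative import: the ops modules now sit one package level below tensor.py, so they reach it through the parent package. A rough sketch of the layout implied by the new import paths elsewhere in this commit (the exact file set is an assumption, not shown in the diff):

# Assumed layout after the refactor:
#
#   tinygrad/
#     tensor.py        # Tensor, Function, Device
#     ops/
#       ops_cpu.py     # "from ..tensor import Function" -> tinygrad.tensor
#       ops_gpu.py
#       ops_torch.py
#
# A single leading dot resolves inside the module's own package (tinygrad.ops);
# the second dot steps up to tinygrad, where tensor.py lives.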

View File

@@ -1,7 +1,7 @@
 import functools
 import pyopencl as cl
 import numpy as np
-from .tensor import Function
+from ..tensor import Function
 
 cl_ctx, cl_queue = None, None
 def require_init_gpu():

View File

@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from .tensor import Function
+from ..tensor import Function
 
 class TorchBuffer(torch.Tensor):
   @staticmethod
@@ -8,14 +8,16 @@ class TorchBuffer(torch.Tensor):
     return TorchBuffer(torch.from_numpy(data).requires_grad_(False))
   def toCPU(x):
     return x.numpy()
+  def getdtype(self):
+    return np.float32
 
 # ************* unary+binary+reduce ops *************
 
-from tinygrad.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
+from tinygrad.ops.ops_cpu import ReLU, Log, Exp, Add, Sub, Mul, Pow, Sum, Max
 
 # ************* movement ops *************
 
-from tinygrad.ops_cpu import Reshape, Transpose
+from tinygrad.ops.ops_cpu import Reshape, Transpose
 
 def inner_slice(x, arg):
   padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
@@ -35,7 +37,7 @@ class Slice(Function):
 
 # ************* processing ops *************
 
-from tinygrad.ops_cpu import Matmul
+from tinygrad.ops.ops_cpu import Matmul
 
 class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):

View File

@@ -34,16 +34,25 @@ class ProfileOp:
 
 # **** start with two base classes, Tensor and Function ****
 
 # TODO: make this class creation generic
-class Device: CPU, GPU, TORCH, buffers, imports = 0, 1, 2, {}, {0:"ops_cpu", 1:"ops_gpu", 2:"ops_torch"}
-DEFAULT_DEVICE = (Device.CPU if os.environ.get("GPU", 0) != "1" else Device.GPU) if os.environ.get("TORCH", 0) != "1" else Device.TORCH
+class Device:
+  buffers = {}
+  imports = {}
+  _ops = sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ops")))
+  DEFAULT = None
+  for i,o in enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]):
+    name = o[len("ops_"):].upper()
+    if os.environ.get(name, 0) == "1":
+      DEFAULT = i
+    vars()[name] = i
+    imports[i] = o
+  DEFAULT = CPU if DEFAULT is None else DEFAULT
 
 class Tensor:
   did_float_warning = False
   training = True
   ops = defaultdict(dict)
 
-  def __init__(self, data, device=DEFAULT_DEVICE, requires_grad=True):
+  def __init__(self, data, device=Device.DEFAULT, requires_grad=True):
     self.device, self.data = device, self._move_data(data, device)
     self.grad, self.requires_grad = None, requires_grad
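
Instead of hard-coding CPU/GPU/TORCH and a DEFAULT_DEVICE expression, the new Device class enumerates whatever ops_*.py files exist in the ops/ directory at import time and picks the default from environment variables (GPU=1, TORCH=1, ...). A standalone sketch of that discovery loop, run against a hard-coded file list rather than a real os.listdir, so it is not the actual tinygrad class:

import os

# Stand-in for os.listdir(".../tinygrad/ops"); the real class scans the package directory.
_ops = sorted(["ops_cpu.py", "ops_gpu.py", "ops_torch.py"])

names, imports = {}, {}
DEFAULT = None
for i, o in enumerate([os.path.splitext(x)[0] for x in _ops if x.startswith("ops_")]):
  name = o[len("ops_"):].upper()        # "ops_gpu" -> "GPU"
  if os.environ.get(name, 0) == "1":    # e.g. GPU=1 selects that backend as default
    DEFAULT = i
  names[name] = i                       # the real class injects these via vars()[name] = i
  imports[i] = o                        # device id -> module name, used later by importlib
DEFAULT = names["CPU"] if DEFAULT is None else DEFAULT

print(names)    # {'CPU': 0, 'GPU': 1, 'TORCH': 2}
print(imports)  # {0: 'ops_cpu', 1: 'ops_gpu', 2: 'ops_torch'}

Because the device ids are assigned in sorted filename order, adding a new ops_*.py file becomes enough to register a backend, with no edit to tensor.py.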
@@ -63,8 +72,8 @@ class Tensor:
 
   @property
   def dtype(self):
-    if self.device == Device.TORCH:
-      return np.float32
+    if getattr(self.data, 'getdtype', None):
+      return self.data.getdtype()
     else:
       return self.data.dtype
 
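
Tensor.dtype no longer special-cases the TORCH device; it asks the buffer itself via the getdtype() hook that TorchBuffer gained above, and falls back to the buffer's own .dtype. A minimal sketch of that duck-typed lookup, using a made-up buffer class rather than the real TorchBuffer:

import numpy as np

class FakeTorchBuffer:               # hypothetical stand-in for TorchBuffer
  def getdtype(self):
    return np.float32                # torch buffers report float32

def buffer_dtype(data):
  # same getattr pattern as the new Tensor.dtype property
  if getattr(data, 'getdtype', None):
    return data.getdtype()
  else:
    return data.dtype                # plain numpy buffers already carry a dtype

print(buffer_dtype(FakeTorchBuffer()))        # <class 'numpy.float32'>
print(buffer_dtype(np.zeros(3, np.float32)))  # float32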
@@ -308,6 +317,6 @@ def _register_ops(namespace, device=Device.CPU):
 import importlib
 for d,ops in Device.imports.items():
   try:
-    _register_ops(importlib.import_module('tinygrad.'+ops), d)
+    _register_ops(importlib.import_module('tinygrad.ops.'+ops), d)
   except ImportError:
     pass
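
The registration loop at the bottom of tensor.py keeps its try/except shape but now imports the backends from the tinygrad.ops package, so a backend whose dependency (pyopencl, torch) is missing is skipped rather than breaking the import. A generic sketch of that import-and-skip pattern, using placeholder standard-library module names so it runs anywhere; it is not the tinygrad registration code itself:

import importlib

# device id -> module name, analogous to Device.imports
candidates = {0: "json", 1: "sqlite3", 2: "definitely_not_installed"}

loaded = {}
for device, modname in candidates.items():
  try:
    loaded[device] = importlib.import_module(modname)
  except ImportError:
    pass  # optional backend unavailable on this machine; skip it

print(sorted(loaded))  # [0, 1] -- the unimportable module is silently skipped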