mirror of https://github.com/commaai/tinygrad.git

fix casting behavior for interpreted buffers (#1525)

commit b6937acb7e (parent 13659ac6fa)
@@ -4,7 +4,7 @@ import math
 import numpy as np
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes
+from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
 from tinygrad.lazy import Device

 if CI:
@@ -15,7 +15,7 @@ FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 PRINT_TENSORS = getenv("PRINT_TENSORS", 0)
 def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
   if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
-  ts, tst = prepare_test_op(a, b, shps, vals)
+  ts, tst = prepare_test_op(a, b, shps, vals, forward_only)

   st = time.monotonic()
   out = torch_fxn(*ts)
@@ -55,12 +55,12 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra

   if not CI: print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")

-def prepare_test_op(a, b, shps, vals):
+def prepare_test_op(a, b, shps, vals, forward_only=False):
   torch.manual_seed(0)
   np.random.seed(0)
-  if shps is None: ts = [torch.tensor(x, requires_grad=True) for x in vals]
-  else: ts = [torch.tensor((np.random.random(size=x) + a) * b, requires_grad=True, dtype=torch.float32) for x in shps]
-  tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]
+  if shps is None: ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
+  else: ts = [torch.tensor((np.random.random(size=x) + a) * b, requires_grad=(not forward_only), dtype=torch.float32) for x in shps]
+  tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
   return ts, tst

 class TestOps(unittest.TestCase):
@@ -240,6 +240,7 @@ class TestOps(unittest.TestCase):
   def test_div(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
     helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
+    helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]])
   def test_div_const(self):
     helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
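The new test_div case feeds integer values with forward_only=True, and that is also why prepare_test_op now derives requires_grad from forward_only: PyTorch refuses to track gradients on integer tensors, so the reference inputs for this case cannot be built with requires_grad=True. A standalone sketch of the constraint (not part of the commit, assuming only torch is installed):

# Standalone sketch (not part of this commit): why the integer-division test
# must be forward-only. PyTorch only allows requires_grad on floating-point
# (and complex) tensors, so prepare_test_op now derives requires_grad from
# forward_only instead of hard-coding True.
import torch

try:
    torch.tensor([5], requires_grad=True)    # int64 tensor -> RuntimeError
except RuntimeError as e:
    print("integer tensor:", e)

x = torch.tensor([5])                        # forward-only integer inputs are fine
y = torch.tensor([1])
print(x / y, (x / y).dtype)                  # true division promotes to float32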
@@ -97,6 +97,8 @@ class dtypes:
   half = float16
   float32: Final[DType] = DType(4, 4, "float", np.float32)
   float = float32
+  float64: Final[DType] = DType(0, 8, "double", np.float64)
+  double = float64
   int8: Final[DType] = DType(0, 1, "char", np.int8)
   int16: Final[DType] = DType(1, 2, "short", np.int16)
   int32: Final[DType] = DType(2, 4, "int", np.int32)
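dtypes.float64 and its double alias are the only additions here; judging from the neighbouring entries, a DType packs a promotion priority, an item size in bytes, a C-style name, and the matching numpy type. A small standalone model of that record (not tinygrad's own DType class; the field names are inferred from the diff) just to show what the two new entries describe:

# Standalone model of a dtypes entry (not tinygrad's DType class).
# Field names are inferred from the surrounding lines of the diff.
from typing import NamedTuple, Optional
import numpy as np

class DType(NamedTuple):
    priority: int          # promotion priority (assumed meaning of the first field)
    itemsize: int          # size of one element in bytes
    name: str              # C-style name used in generated code
    np: Optional[type]     # matching numpy scalar type

float32 = DType(4, 4, "float", np.float32)
float64 = DType(0, 8, "double", np.float64)   # newly added in this commit
double = float64                              # alias, also new

assert double.itemsize == 8 and double.np is np.float64
print(float64)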
@@ -81,11 +81,8 @@ class LazyOp:

 class Interpreted:
   def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_lazybuffer=lambda x: x.realized, to_underlying=lambda x: x._buf, from_underlying=None):
-    self.buffer = buffer
-    self.fxn_for_op = fxn_for_op
-    self.from_lazybuffer = from_lazybuffer
+    self.buffer, self.fxn_for_op, self.from_lazybuffer, self.to_underlying = buffer, fxn_for_op, from_lazybuffer, to_underlying
     self.from_underlying = buffer if from_underlying is None else from_underlying
-    self.to_underlying = to_underlying
     self.synchronize = lambda: None
     self.codegen = None

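The constructor change is a pure consolidation of the attribute assignments; in particular, from_underlying still falls back to the buffer class when a backend does not pass its own converter. A tiny standalone sketch of that default (NumpyBuffer is a made-up stand-in, not a tinygrad class):

# Standalone sketch of the from_underlying default (NumpyBuffer is made up):
# when a backend gives no converter, raw op results are wrapped with the
# buffer class itself, mirroring
#   self.from_underlying = buffer if from_underlying is None else from_underlying
import numpy as np

class NumpyBuffer:
    def __init__(self, buf): self._buf = buf

def pick_from_underlying(buffer_cls, from_underlying=None):
    return buffer_cls if from_underlying is None else from_underlying

wrap = pick_from_underlying(NumpyBuffer)
print(type(wrap(np.arange(4))).__name__)   # NumpyBuffer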
@@ -98,6 +95,7 @@ class Interpreted:
     srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src]
     if DEBUG >= 3: st = time.perf_counter()
     ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
+    if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
     if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
     if not created_context: context[ast] = ret
     if output is not None and output.output_buffer is not None:
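The added line is the actual fix: after an interpreted op executes, if the result's dtype no longer matches the dtype the output LazyBuffer expects and the backend registers a UnaryOps.CAST implementation, the result is cast back before being returned. Below is a minimal standalone sketch of that control flow with plain numpy arrays standing in for interpreted buffers; the names (fxn_for_op, run_op, CAST) are illustrative, and only the dtype check mirrors the added line.

# Minimal standalone sketch of the new exec_ast behavior, with plain numpy
# arrays standing in for interpreted buffers. run_op and the op table are
# illustrative; only the dtype check mirrors the line added in this commit.
import numpy as np

CAST = "CAST"
fxn_for_op = {
    "DIV": lambda a, b: a / b,                # numpy promotes int/int to float64
    CAST: lambda a, arg: a.astype(arg[0]),    # arg = (target_dtype, bitcast_flag)
}

def run_op(op, srcs, output_dtype=None):
    ret = fxn_for_op[op](*srcs)
    # the fix: cast ret back to the dtype the output buffer expects
    if output_dtype is not None and ret.dtype != output_dtype and CAST in fxn_for_op:
        ret = fxn_for_op[CAST](ret, (output_dtype, False))
    return ret

out = run_op("DIV", [np.array([5]), np.array([1])], output_dtype=np.float32)
print(out, out.dtype)   # [5.] float32, although numpy's int/int division gives float64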