BEAM_COMPARE=2 validates the correctness of BEAM kernels (#5458)

* beam compare 2

* maybe found an issue

* correct, not fail

* fully random inputs

* less numpy

* an extra simplify doesn't fix it

* reorder

* no numpy

* check in reverse order

* test new tensor behavior

* better error message
commit 942c58be90 (parent 6943ea5f29)
George Hotz, 2024-07-13 13:53:43 -07:00, committed by GitHub
5 changed files with 44 additions and 7 deletions
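With BEAM search enabled, BEAM_COMPARE=1 (the default) keeps the existing behavior: the hand-coded, tensor-core, and beam kernels are timed against each other and the fastest is picked. The new BEAM_COMPARE=2 additionally runs every candidate kernel on identical random inputs and raises a RuntimeError if their outputs disagree. A typical invocation (the script path is illustrative; any tinygrad program works):

BEAM=2 BEAM_COMPARE=2 python3 examples/beautiful_mnist.py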

test/test_tensor_data.py

@@ -1,8 +1,19 @@
-import unittest
+import unittest, struct
 from tinygrad import Tensor, dtypes
+# format types: https://docs.python.org/3/library/struct.html
 class TestTensorBytes(unittest.TestCase):
   def test_bytes(self):
     lst = Tensor(bytes(b"\xaa\xbb\xcc\xdd"))
     assert lst.tolist() == [170, 187, 204, 221]
+  def test_float_bytes(self):
+    lst = Tensor(bytes(struct.pack("ff", 0.234, 0.8585)), dtype=dtypes.float32)
+    assert lst.shape == (2,)
+    assert abs(lst.tolist()[0] - 0.234) < 1e-6
+    assert abs(lst.tolist()[1] - 0.8585) < 1e-6
 class TestTensorData(unittest.TestCase):
   def test_data(self):
     a = Tensor([1,2,3,4], dtype=dtypes.int32)
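The same mechanism should extend to the other struct format characters; a minimal sketch, assuming the int16 path ("h") behaves like the float32 one exercised in the test above:

import struct
from tinygrad import Tensor, dtypes

t = Tensor(bytes(struct.pack("hh", -2, 3)), dtype=dtypes.int16)
assert t.shape == (2,)        # 4 bytes // itemsize 2
assert t.tolist() == [-2, 3]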

tinygrad/codegen/kernel.py

@@ -682,7 +682,7 @@ class Kernel:
       permaxis += list(range(wd+len(warp_dims), tcd))
       for x,y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd))
       permaxis += list(range(tcd+len(tcd_expand), self.shape_len+len(tcd_expand)-len(tcd_dims)))
-      return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape)
+      return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
     if self.opts.device == "AMD":
       reduce_axes = [self.shape_len-self.upcasted]

tinygrad/dtype.py

@@ -57,6 +57,10 @@ class dtypes:
     if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
     return -float("inf") if dtypes.is_float(dtype) else False
   @staticmethod
+  def max(dtype:DType):
+    if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+    return float("inf") if dtypes.is_float(dtype) else True
+  @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   bigint: Final[DType] = DType(-1, 0, "bigint", None, 1) # arbitrary precision integer
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
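As a sanity check, these identities follow directly from the min/max formulas above (a sketch, not taken from the commit's tests):

from tinygrad import dtypes

assert dtypes.max(dtypes.uint8) == 255             # 2**8 - 1
assert dtypes.max(dtypes.int8) == 127              # 2**(8-1) - 1
assert dtypes.min(dtypes.int8) == -128             # -2**(8-1)
assert dtypes.max(dtypes.float32) == float("inf")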

tinygrad/engine/realize.py

@@ -1,8 +1,9 @@
 from typing import List, Dict, Optional, cast, Generator, Tuple
 import time, pprint
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata
+from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context
 from tinygrad.ops import MetaOps, LazyOp
+from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.renderer import Renderer, Program
@@ -26,7 +27,7 @@ def get_linearizer(renderer:Renderer, ast:LazyOp) -> Kernel:
       kb.required_optimizations()
       rawbufs = bufs_from_lin(kb, allocate=False)
       k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-      if getenv("BEAM_COMPARE", 1):
+      if beam_compare:=getenv("BEAM_COMPARE", 1):
         # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
         lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
         if used_tensor_cores:
@@ -36,7 +37,28 @@ def get_linearizer(renderer:Renderer, ast:LazyOp) -> Kernel:
         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
         k = timed[0][1]
         if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
+        # TODO: check the correctness inline once compare_linearizer is in core
+        if beam_compare == 2:
+          from tinygrad import Tensor
+          all_outs: List[List[Tensor]] = []
+          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
+            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
+                         (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
+                          Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
+                         for buf in rawbufs]
+          for _, tk in lins[::-1]:
+            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
+            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
+            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
+          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
+            for bufs in zip(*all_outs):
+              for b in bufs[1:]:
+                if dtypes.is_float(bufs[0].dtype):
+                  # we check both atol and rtol here
+                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
+                else:
+                  diff_count = (b != bufs[0]).sum().item()
+                if diff_count != 0:
+                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
   if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
   if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k
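Distilled from the float branch above: an element only counts as a mismatch when it fails both the absolute and the relative tolerance, so large values get relative slack and near-zero values get absolute slack. A sketch (mismatch_count is an illustrative name, not part of this commit):

from tinygrad import Tensor

def mismatch_count(out: Tensor, ref: Tensor, atol: float = 1e-3, rtol: float = 1e-3) -> int:
  # an element fails only if it is outside atol AND outside rtol
  return (((out-ref).abs() > atol) * (((out-ref)/ref).abs() > rtol)).sum().item()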

tinygrad/tensor.py

@@ -56,7 +56,7 @@ def _fromnp(x: np.ndarray) -> LazyBuffer:
   return ret
 def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
-  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
     ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
@@ -124,7 +124,7 @@ class Tensor:
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
     elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
     elif isinstance(data, Variable): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
-    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
     elif isinstance(data, (list, tuple)):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
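Together with the _frompy change above, bytes now respect an explicit dtype while the default stays uint8; a quick sketch of both paths:

from tinygrad import Tensor, dtypes

# with no dtype, bytes still decode as uint8: one element per byte
assert Tensor(b"\x00\x00\x80\x3f").tolist() == [0, 0, 128, 63]
# with an explicit dtype, the same four bytes decode as one float32 (1.0 in little-endian IEEE 754)
assert Tensor(b"\x00\x00\x80\x3f", dtype=dtypes.float32).tolist() == [1.0]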