From 942c58be90a22dd4a9d517f086c67c404f1e89bc Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 13 Jul 2024 13:53:43 -0700
Subject: [PATCH] BEAM_COMPARE=2 validates the correctness of BEAM kernels (#5458)

* beam compare 2

* found issue maybe

* correct, not fail

* full rand

* less numpy

* extra simplify doesn't fix it

* reorder

* no numpy

* check in reverse

* test new tensor behavior

* better error msg
---
 test/test_tensor_data.py   | 13 ++++++++++++-
 tinygrad/codegen/kernel.py |  2 +-
 tinygrad/dtype.py          |  4 ++++
 tinygrad/engine/realize.py | 28 +++++++++++++++++++++++++---
 tinygrad/tensor.py         |  4 ++--
 5 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/test/test_tensor_data.py b/test/test_tensor_data.py
index 4ab945de..a6f272e1 100644
--- a/test/test_tensor_data.py
+++ b/test/test_tensor_data.py
@@ -1,8 +1,19 @@
-import unittest
+import unittest, struct
 from tinygrad import Tensor, dtypes
 
 # format types: https://docs.python.org/3/library/struct.html
 
+class TestTensorBytes(unittest.TestCase):
+  def test_bytes(self):
+    lst = Tensor(bytes(b"\xaa\xbb\xcc\xdd"))
+    assert lst.tolist() == [170, 187, 204, 221]
+
+  def test_float_bytes(self):
+    lst = Tensor(bytes(struct.pack("ff", 0.234, 0.8585)), dtype=dtypes.float32)
+    assert lst.shape == (2,)
+    assert abs(lst.tolist()[0] - 0.234) < 1e-6
+    assert abs(lst.tolist()[1] - 0.8585) < 1e-6
+
 class TestTensorData(unittest.TestCase):
   def test_data(self):
     a = Tensor([1,2,3,4], dtype=dtypes.int32)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 43067ab9..d126f0c1 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -682,7 +682,7 @@ class Kernel:
       permaxis += list(range(wd+len(warp_dims), tcd))
       for x,y in pattern_2: permaxis.append(y + (wd if x == 0 else tcd))
       permaxis += list(range(tcd+len(tcd_expand), self.shape_len+len(tcd_expand)-len(tcd_dims)))
-      return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape)
+      return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
 
     if self.opts.device == "AMD":
       reduce_axes = [self.shape_len-self.upcasted]
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index c555122c..f988e4cc 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -57,6 +57,10 @@ class dtypes:
     if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
     return -float("inf") if dtypes.is_float(dtype) else False
   @staticmethod
+  def max(dtype:DType):
+    if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
+    return float("inf") if dtypes.is_float(dtype) else True
+  @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   bigint: Final[DType] = DType(-1, 0, "bigint", None, 1) # arbitrary precision integer
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
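For reference, the new dtypes.max mirrors the existing dtypes.min and is what
bounds randint when BEAM_COMPARE=2 fills integer buffers below. A quick sanity
sketch of the values the two formulas yield (illustrative, not part of the patch):

    from tinygrad import dtypes
    assert dtypes.min(dtypes.int8) == -128 and dtypes.max(dtypes.int8) == 127    # signed: -2**(bits-1) .. 2**(bits-1)-1
    assert dtypes.min(dtypes.uint8) == 0 and dtypes.max(dtypes.uint8) == 255     # unsigned: 0 .. 2**bits-1
    assert dtypes.max(dtypes.float32) == float("inf")                            # floats are unbounded
    assert dtypes.min(dtypes.bool) is False and dtypes.max(dtypes.bool) is True  # bool spans False..True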
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index e68ba270..441f417e 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -1,8 +1,9 @@
 from typing import List, Dict, Optional, cast, Generator, Tuple
 import time, pprint
 from dataclasses import dataclass, replace
-from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata
+from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context
 from tinygrad.ops import MetaOps, LazyOp
+from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.renderer import Renderer, Program
@@ -26,7 +27,7 @@ def get_linearizer(renderer:Renderer, ast:LazyOp) -> Kernel:
       kb.required_optimizations()
       rawbufs = bufs_from_lin(kb, allocate=False)
       k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
-      if getenv("BEAM_COMPARE", 1):
+      if beam_compare:=getenv("BEAM_COMPARE", 1):
         # TODO: move the HC/TC/BEAM compare to beam_search so it can be optionally cached which choice is better
         lins: List[Tuple[str, Kernel]] = [(f"beam{BEAM.value}", k), (("tc" if used_tensor_cores else "hc"), k_opt)]
         if used_tensor_cores:
@@ -36,7 +37,28 @@ def get_linearizer(renderer:Renderer, ast:LazyOp) -> Kernel:
         if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
         k = timed[0][1]
         if logkerns is not None and logkerns_level > 1: logkerns.writelines([f"{(lin.ast, lin.applied_opts)}\n" for (_,lin,_) in timed[1:]])
-        # TODO: check the correctness inline once compare_linearizer is in core
+        if beam_compare == 2:
+          from tinygrad import Tensor
+          all_outs: List[List[Tensor]] = []
+          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
+            rand_bufs = [Tensor.normal(buf.size, std=0.1, dtype=buf.dtype).data() if dtypes.is_float(buf.dtype) else \
+                        (Tensor.randint(buf.size, low=0, high=2).cast(buf.dtype).data() if buf.dtype == dtypes.bool else \
+                         Tensor.randint(buf.size, low=dtypes.min(buf.dtype), high=dtypes.max(buf.dtype), dtype=buf.dtype).data()) \
+                         for buf in rawbufs]
+          for _, tk in lins[::-1]:
+            for buf,data in zip(rawbufs, rand_bufs): buf.ensure_allocated().copyin(data)
+            time_linearizer(tk, rawbufs, allow_test_size=False, clear_l2=True, disable_cache=True)
+            all_outs.append([Tensor(bytes(buf.as_buffer()), dtype=buf.dtype) for buf in rawbufs[:len(ast.src)]])
+          with Context(DEBUG=0, BEAM=0, CAPTURING=0):
+            for bufs in zip(*all_outs):
+              for b in bufs[1:]:
+                if dtypes.is_float(bufs[0].dtype):
+                  # we check both atol and rtol here
+                  diff_count = (((b-bufs[0]).abs() > 1e-3) * (((b-bufs[0])/bufs[0]).abs() > 1e-3)).sum().item()
+                else:
+                  diff_count = (b != bufs[0]).sum().item()
+                if diff_count != 0:
+                  raise RuntimeError(f"mismatch of {diff_count}/{b.numel()} items with type {b.dtype}, max {(b-bufs[0]).abs().max().item()}")
         if logkerns is not None: logkerns.writelines([f"{(k.ast, k.applied_opts)}\n"])
   if DEBUG >= 5: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
   return k
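In short: with BEAM_COMPARE=2, every candidate kernel (beam, plus hand-coded or
tensor-core) is run on the same randomly filled buffers, in reverse order, and
the output buffers are compared element-wise; a float element only counts as a
mismatch when both its absolute and its relative difference exceed 1e-3. One
plausible way to exercise the check from the shell (the script name is
illustrative, any workload run with BEAM works):

    BEAM=2 BEAM_COMPARE=2 python3 examples/beautiful_mnist.py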
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c6fd3ab0..0155a8f3 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -56,7 +56,7 @@ def _fromnp(x: np.ndarray) -> LazyBuffer:
   return ret
 
 def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
-  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x),), dtype, "PYTHON"), x
+  if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
   else:
     ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON")
     assert dtype.fmt is not None, f"{dtype=} has None fmt"
@@ -124,7 +124,7 @@ class Tensor:
     if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
     elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
     elif isinstance(data, Variable): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
-    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8)
+    elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
     elif isinstance(data, (list, tuple)):
       if dtype is None:
         if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
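Net effect of the two tensor.py changes: Tensor(bytes, dtype=...) now
reinterprets the raw buffer in the requested dtype, with the element count
derived as byte length // itemsize, instead of always yielding uint8. A small
sketch of the new behavior, assuming native byte order throughout (struct.pack
without a prefix uses it):

    import struct
    from tinygrad import Tensor, dtypes
    raw = struct.pack("ii", 3, 4)                 # 8 raw bytes
    assert Tensor(raw).shape == (8,)              # dtype omitted: uint8, one element per byte
    assert Tensor(raw, dtype=dtypes.int32).tolist() == [3, 4]  # 8 bytes // 4-byte itemsize = 2 elements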