From 6793db169b7d17014cfd0ce4498a640d681a30a5 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 13 Mar 2024 18:44:05 -0400
Subject: [PATCH] bfloat16 tensor creation from list and numpy (#3724)

---
 test/test_dtype.py | 13 +++++++++++--
 tinygrad/tensor.py |  5 +++--
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index 7f3ebd5c..760079d3 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -127,11 +127,20 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())

+@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP", "METAL"], "bfloat16 not supported")
+class TestBFloat16(unittest.TestCase):
+  def test_bf16_creation_numpy(self):
+    data = [-1, 1, 2]
+    t = Tensor(data, dtype=dtypes.bfloat16)
+    assert t.dtype == dtypes.bfloat16
+    tnp = t.numpy()
+    assert tnp.dtype == np.float32
+    np.testing.assert_allclose(tnp, np.array(data))
+
 @unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP"], "bfloat16 not supported")
 class TestBFloat16DType(unittest.TestCase):
   def test_bf16_to_float(self):
-    with self.assertRaises(AssertionError):
-      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
+    _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)

   def test_float_to_bf16(self):
     with self.assertRaises(AssertionError):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c84b2954..7fd41347 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -97,8 +97,8 @@ class Tensor:
       if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
       elif d and all_int(d): dtype = dtype or dtypes.default_int
       else: dtype = dtype or dtypes.default_float
-      # NOTE: cast at the end for the dtypes that do not have a numpy dtype
-      data = _fromcpu(np.array(data, dtype.np)).cast(dtype)
+      if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _fromcpu(np.array(data, dtype.np))
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
@@ -172,6 +172,7 @@ class Tensor:
     assert self.numel() == 1, "must have one element for item"
     return self._data().cast(self.dtype.fmt)[0]
   def numpy(self) -> np.ndarray:
+    if self.dtype == dtypes.bfloat16: return self.float().numpy()
     assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
     return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
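
For reference, a minimal sketch of the behavior this patch enables. It assumes a tinygrad checkout that includes this change, a backend with bfloat16 support (LLVM, HIP, or METAL, per the skipUnless guards in the tests above), and that Tensor and dtypes are importable from the package top level; it is an illustration, not part of the patch.

# Sketch: exercising bfloat16 creation from a list and readback via numpy().
# Assumes a bfloat16-capable backend (LLVM/HIP/METAL) and top-level exports.
import numpy as np
from tinygrad import Tensor, dtypes

# Creation from a Python list: the data is staged as a float32 numpy array
# (numpy has no bfloat16 dtype), then cast to bfloat16 on device.
t = Tensor([-1, 1, 2], dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16

# numpy() now upcasts bfloat16 to float32 before reading back, so the
# returned array is float32 instead of tripping the missing-np-dtype assert.
out = t.numpy()
assert out.dtype == np.float32
np.testing.assert_allclose(out, [-1.0, 1.0, 2.0])

# Casting bfloat16 to float32 also works now (it previously raised an
# AssertionError); 100000 is rounded to the nearest bfloat16-representable value.
print(Tensor([100000], dtype=dtypes.bfloat16).cast(dtypes.float32).numpy())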