mirror of https://github.com/commaai/tinygrad.git
bfloat16 tensor creation from list and numpy (#3724)
This commit is contained in:
parent f30fb192b7
commit 6793db169b
test/test_dtype.py
@@ -127,11 +127,20 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
 
+@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP", "METAL"], "bfloat16 not supported")
+class TestBFloat16(unittest.TestCase):
+  def test_bf16_creation_numpy(self):
+    data = [-1, 1, 2]
+    t = Tensor(data, dtype=dtypes.bfloat16)
+    assert t.dtype == dtypes.bfloat16
+    tnp = t.numpy()
+    assert tnp.dtype == np.float32
+    np.testing.assert_allclose(tnp, np.array(data))
+
 @unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP"], "bfloat16 not supported")
 class TestBFloat16DType(unittest.TestCase):
   def test_bf16_to_float(self):
     with self.assertRaises(AssertionError):
       _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
 
   def test_float_to_bf16(self):
     with self.assertRaises(AssertionError):
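For context on the new test: bfloat16 is float32 with the mantissa cut from 23 bits to 7 (same sign bit and 8-bit exponent), so values like [-1, 1, 2] survive the bfloat16 round trip exactly, which is why assert_allclose above can expect the original data back. A minimal NumPy sketch of that round trip via bit truncation (illustrative only, not tinygrad's cast kernel, which may round rather than truncate; the helper names here are made up):

import numpy as np

def f32_to_bf16_bits(x):
  # reinterpret float32 as uint32 and keep the high 16 bits (sign + exponent + 7 mantissa bits)
  return (np.asarray(x, np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_f32(b):
  # put the stored 16 bits back into the high half of a uint32 and reinterpret as float32
  return (b.astype(np.uint32) << 16).view(np.float32)

data = np.array([-1, 1, 2], dtype=np.float32)
np.testing.assert_allclose(bf16_bits_to_f32(f32_to_bf16_bits(data)), data)  # exact for these values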
tinygrad/tensor.py
@@ -97,8 +97,8 @@ class Tensor:
       if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
       elif d and all_int(d): dtype = dtype or dtypes.default_int
       else: dtype = dtype or dtypes.default_float
-      # NOTE: cast at the end for the dtypes that do not have a numpy dtype
-      data = _fromcpu(np.array(data, dtype.np)).cast(dtype)
+      if dtype == dtypes.bfloat16: data = Tensor(_fromcpu(np.array(data, np.float32)), device=device).cast(dtypes.bfloat16).lazydata
+      else: data = _fromcpu(np.array(data, dtype.np))
     elif isinstance(data, np.ndarray):
       if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
       else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
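The creation change above stages the Python list as a float32 NumPy array on the CPU (NumPy has no bfloat16 dtype), then casts to bfloat16 on the device. A short usage sketch, assuming a backend from the test's skip list (LLVM, HIP, or METAL) and the top-level tinygrad imports:

from tinygrad import Tensor, dtypes

# list -> float32 numpy array on CPU -> on-device cast to bfloat16
t = Tensor([-1, 1, 2], dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16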
tinygrad/tensor.py
@@ -172,6 +172,7 @@ class Tensor:
     assert self.numel() == 1, "must have one element for item"
     return self._data().cast(self.dtype.fmt)[0]
   def numpy(self) -> np.ndarray:
+    if self.dtype == dtypes.bfloat16: return self.float().numpy()
     assert self.dtype.np is not None, f"no np dtype for {self.dtype}"
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
     return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
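Reading the data back goes the other way: since np.frombuffer has no bfloat16 dtype to target, numpy() now first widens a bfloat16 tensor to float32 on-device and returns a float32 array. A small sketch of the resulting behavior (same device assumptions as above):

import numpy as np
from tinygrad import Tensor, dtypes

t = Tensor.ones(4, dtype=dtypes.bfloat16)
out = t.numpy()                 # routed through self.float().numpy()
assert out.dtype == np.float32  # bfloat16 comes back widened to float32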