From 8aab19ce3d441cfc2f1ffb384cf1e76f7f4c3050 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 18 Dec 2023 10:51:17 -0500 Subject: [PATCH] Tensor.full of bool has dtypes.bool (#2823) --- test/test_dtype.py | 1 + tinygrad/tensor.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 657cd91d..719999df 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -238,6 +238,7 @@ class TestTypeSpec(unittest.TestCase): assert Tensor.zeros([2,3]).dtype == Tensor.default_type assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type assert Tensor.full([2,3], 3).dtype == dtypes.int + assert Tensor.full([2,3], True).dtype == dtypes.bool def test_reduce_0d_default(self): assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e6aa8a2d..2dcb80a7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -67,7 +67,7 @@ class Tensor: elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or Tensor.default_type).np)) elif isinstance(data, list): if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool - if d and all_int(d): dtype = dtype or dtypes.int32 + elif d and all_int(d): dtype = dtype or dtypes.int32 else: dtype = dtype or Tensor.default_type # NOTE: cast at the end for the types that do not have a numpy dtype data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype) @@ -171,7 +171,8 @@ class Tensor: @staticmethod def full(shape:Tuple[sint, ...], fill_value, **kwargs): - dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value, float) else dtypes.int32) + # TODO: dtypes.default_type and dtypes.from_py + dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.int32) return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod