Tensor.full of bool has dtypes.bool (#2823)

This commit is contained in:
chenyu 2023-12-18 10:51:17 -05:00 committed by GitHub
parent 220abcd8ff
commit 8aab19ce3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 2 deletions

View File

@ -238,6 +238,7 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor.zeros([2,3]).dtype == Tensor.default_type
assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
assert Tensor.full([2,3], 3).dtype == dtypes.int
assert Tensor.full([2,3], True).dtype == dtypes.bool
def test_reduce_0d_default(self):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type

View File

@ -67,7 +67,7 @@ class Tensor:
elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or Tensor.default_type).np))
elif isinstance(data, list):
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
if d and all_int(d): dtype = dtype or dtypes.int32
elif d and all_int(d): dtype = dtype or dtypes.int32
else: dtype = dtype or Tensor.default_type
# NOTE: cast at the end for the types that do not have a numpy dtype
data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
@ -171,7 +171,8 @@ class Tensor:
@staticmethod
def full(shape:Tuple[sint, ...], fill_value, **kwargs):
dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value, float) else dtypes.int32)
# TODO: dtypes.default_type and dtypes.from_py
dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.int32)
return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod