mirror of https://github.com/commaai/tinygrad.git
Tensor.full of bool has dtypes.bool (#2823)
This commit is contained in:
parent
220abcd8ff
commit
8aab19ce3d
|
@ -238,6 +238,7 @@ class TestTypeSpec(unittest.TestCase):
|
|||
assert Tensor.zeros([2,3]).dtype == Tensor.default_type
|
||||
assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
|
||||
assert Tensor.full([2,3], 3).dtype == dtypes.int
|
||||
assert Tensor.full([2,3], True).dtype == dtypes.bool
|
||||
|
||||
def test_reduce_0d_default(self):
|
||||
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
|
||||
|
|
|
@ -67,7 +67,7 @@ class Tensor:
|
|||
elif data is None: data = LazyBuffer.fromCPU(np.array([], dtype=(dtype or Tensor.default_type).np))
|
||||
elif isinstance(data, list):
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool
|
||||
if d and all_int(d): dtype = dtype or dtypes.int32
|
||||
elif d and all_int(d): dtype = dtype or dtypes.int32
|
||||
else: dtype = dtype or Tensor.default_type
|
||||
# NOTE: cast at the end for the types that do not have a numpy dtype
|
||||
data = LazyBuffer.fromCPU(np.array(data, dtype.np)).cast(dtype)
|
||||
|
@ -171,7 +171,8 @@ class Tensor:
|
|||
|
||||
@staticmethod
|
||||
def full(shape:Tuple[sint, ...], fill_value, **kwargs):
|
||||
dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value, float) else dtypes.int32)
|
||||
# TODO: dtypes.default_type and dtypes.from_py
|
||||
dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value,float) else dtypes.bool if isinstance(fill_value,bool) else dtypes.int32)
|
||||
return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue