Tensor(False) has dtypes.bool (#2805)

This commit is contained in:
chenyu 2023-12-16 19:04:08 -05:00 committed by GitHub
parent fa84998244
commit baa94d6142
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 0 deletions

View File

@ -225,6 +225,9 @@ class TestHelpers(unittest.TestCase):
class TestTypeSpec(unittest.TestCase):
def test_creation(self):
assert Tensor(True).dtype == dtypes.bool
assert Tensor(2).dtype == dtypes.int
assert Tensor(2.34).dtype == Tensor.default_type
assert Tensor([]).dtype == Tensor.default_type
assert Tensor([1]).dtype == dtypes.int
assert Tensor([1.1]).dtype == Tensor.default_type

View File

@ -59,6 +59,7 @@ class Tensor:
# internal variables used for autograd graph construction
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, bool): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.bool, device, data)
elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.int32, device, data)
elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))