mirror of https://github.com/commaai/tinygrad.git
Tensor(False) has dtypes.bool (#2805)
This commit is contained in:
parent
fa84998244
commit
baa94d6142
|
@ -225,6 +225,9 @@ class TestHelpers(unittest.TestCase):
|
|||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def test_creation(self):
|
||||
assert Tensor(True).dtype == dtypes.bool
|
||||
assert Tensor(2).dtype == dtypes.int
|
||||
assert Tensor(2.34).dtype == Tensor.default_type
|
||||
assert Tensor([]).dtype == Tensor.default_type
|
||||
assert Tensor([1]).dtype == dtypes.int
|
||||
assert Tensor([1.1]).dtype == Tensor.default_type
|
||||
|
|
|
@ -59,6 +59,7 @@ class Tensor:
|
|||
# internal variables used for autograd graph construction
|
||||
self._ctx: Optional[Function] = None
|
||||
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
elif isinstance(data, bool): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.bool, device, data)
|
||||
elif isinstance(data, int): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.int32, device, data)
|
||||
elif isinstance(data, float): data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
|
||||
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
|
||||
|
|
Loading…
Reference in New Issue