diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 67e0c02c..ceaf3bf0 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -477,8 +477,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None, output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)] output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)] scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)] - x_out = Tensor.arange(output_shape[-1]) - y_out = Tensor.arange(output_shape[-2]) + x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type) + y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type) if mode == "nearest": x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi) x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1]) diff --git a/test/test_dtype.py b/test/test_dtype.py index aaf364c2..40796d81 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -236,6 +236,37 @@ class TestTypeSpec(unittest.TestCase): assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int + def test_arange(self): + assert Tensor.arange(5).dtype == dtypes.int32 + assert Tensor.arange(5.0).dtype == Tensor.default_type + assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16 + assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64 + assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.arange(3, 9, 0.7).dtype == Tensor.default_type + assert Tensor.arange(3, 8.5, 3).dtype == Tensor.default_type + + def test_zeros(self): + assert Tensor.zeros(3, 3).dtype == Tensor.default_type + assert Tensor.zeros(3, 3, dtype= dtypes.float16).dtype == dtypes.float16 + assert Tensor.zeros(3, 3, dtype= dtypes.int64).dtype == dtypes.int64 + + def test_ones(self): + assert Tensor.ones(3, 3).dtype == Tensor.default_type + assert Tensor.ones(3, 3, dtype= dtypes.float16).dtype == dtypes.float16 + assert Tensor.ones(3, 3, dtype= dtypes.int64).dtype == dtypes.int64 + + def test_full(self): + assert Tensor.full((3, 3), 3).dtype == dtypes.int + assert Tensor.full((3, 3), 3.0).dtype == Tensor.default_type + assert Tensor.full((3, 3), 3, dtype= dtypes.float16).dtype == dtypes.float16 + assert Tensor.full((3, 3), 3, dtype= dtypes.int64).dtype == dtypes.int64 + + def test_eye(self): + assert Tensor.eye(0).dtype == Tensor.default_type + assert Tensor.eye(3).dtype == Tensor.default_type + assert Tensor.eye(3, dtype= dtypes.float16).dtype == dtypes.float16 + assert Tensor.eye(3, dtype= dtypes.int64).dtype == dtypes.int64 + core_types = list(DTYPES_DICT.values()) class TestTypePromotion(unittest.TestCase): @given(st.sampled_from(core_types)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 02bed595..784f77e5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -169,22 +169,24 @@ class Tensor: @staticmethod def full(shape:Tuple[sint, ...], fill_value, **kwargs): - return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) + dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value, float) else dtypes.int32) + return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod - def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs) + def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs) @staticmethod - def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs) + def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs) @staticmethod def arange(start, stop=None, step=1, **kwargs): if stop is None: stop, start = start, 0 - return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step) + dtype = kwargs.pop("dtype", Tensor.default_type if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.int32) + return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step) @staticmethod def eye(dim:int, **kwargs): - return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) + return Tensor.full((dim,1),1.0,**kwargs).pad((None,(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)