mirror of https://github.com/commaai/tinygrad.git
arange default dtype to int and zeros/ones default to float (#2769)
This commit is contained in:
parent 3cf4376ce2
commit 66d9eb10b6
@@ -477,8 +477,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
   output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
   output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
   scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
-  x_out = Tensor.arange(output_shape[-1])
-  y_out = Tensor.arange(output_shape[-2])
+  x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type)
+  y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type)
   if mode == "nearest":
     x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
     x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])
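Since arange now yields int32 by default, the Resize op casts its output coordinate grids back to the float default type before the coordinate transformation, which maps output pixel indices into fractional input coordinates. A minimal sketch of why the cast matters, assuming the standard ONNX "half_pixel" rule and hypothetical sizes (this is not the repo's _coordinate_transformation helper):

from tinygrad.tensor import Tensor

out_w, scale = 8, 2.0                        # hypothetical output width and scale factor
x_out = Tensor.arange(out_w)                 # int32 after this change
# The "half_pixel" mapping back into input coordinates produces fractional values,
# so the grid is promoted to the float default type first.
x_in = (x_out.cast(Tensor.default_type) + 0.5) / scale - 0.5
print(x_in.numpy())                          # [-0.25  0.25  0.75 ...  3.25]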
@@ -236,6 +236,37 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
     # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
 
+  def test_arange(self):
+    assert Tensor.arange(5).dtype == dtypes.int32
+    assert Tensor.arange(5.0).dtype == Tensor.default_type
+    assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16
+    assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64
+    assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.arange(3, 9, 0.7).dtype == Tensor.default_type
+    assert Tensor.arange(3, 8.5, 3).dtype == Tensor.default_type
+
+  def test_zeros(self):
+    assert Tensor.zeros(3, 3).dtype == Tensor.default_type
+    assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
+
+  def test_ones(self):
+    assert Tensor.ones(3, 3).dtype == Tensor.default_type
+    assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
+
+  def test_full(self):
+    assert Tensor.full((3, 3), 3).dtype == dtypes.int
+    assert Tensor.full((3, 3), 3.0).dtype == Tensor.default_type
+    assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64
+
+  def test_eye(self):
+    assert Tensor.eye(0).dtype == Tensor.default_type
+    assert Tensor.eye(3).dtype == Tensor.default_type
+    assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16
+    assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64
+
 core_types = list(DTYPES_DICT.values())
 class TestTypePromotion(unittest.TestCase):
   @given(st.sampled_from(core_types))
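The new tests pin down one dtype-selection rule for the creation ops: an explicit dtype always wins; otherwise any float among the arguments or fill value selects Tensor.default_type, and all-integer arguments select int32. A minimal sketch of that rule as a standalone helper, not tinygrad's own code, assuming dtypes is importable from tinygrad.helpers at this commit:

from tinygrad.helpers import dtypes

def creation_dtype(*args, dtype=None, default_float=dtypes.float32):
  # Explicit dtype wins; any float argument promotes to the default float type;
  # all-int arguments stay int32 -- the rule the tests above exercise.
  if dtype is not None: return dtype
  return default_float if any(isinstance(a, float) for a in args) else dtypes.int32

assert creation_dtype(5) == dtypes.int32             # like Tensor.arange(5)
assert creation_dtype(3, 8.5, 3) == dtypes.float32   # like Tensor.arange(3, 8.5, 3)
assert creation_dtype(3, dtype=dtypes.int64) == dtypes.int64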
@@ -169,22 +169,24 @@ class Tensor:
 
   @staticmethod
   def full(shape:Tuple[sint, ...], fill_value, **kwargs):
-    return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
+    dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value, float) else dtypes.int32)
+    return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
 
   @staticmethod
-  def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
+  def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
 
   @staticmethod
-  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
+  def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
 
   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
-    return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
+    dtype = kwargs.pop("dtype", Tensor.default_type if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.int32)
+    return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)
 
   @staticmethod
   def eye(dim:int, **kwargs):
-    return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
+    return Tensor.full((dim,1),1.0,**kwargs).pad((None,(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
 
   def full_like(self, fill_value, **kwargs):
     return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
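The arange construction above is worth unpacking: it fills a tensor of length ceil((stop-start)/step) with step, takes a cumulative sum to get step, 2*step, 3*step, ..., then shifts everything by (start - step) so the first element lands on start. A worked sketch in plain Python (lists instead of Tensor ops, values chosen for illustration):

import math, itertools

start, stop, step = 3, 9, 2
n = math.ceil((stop - start) / step)           # 3 elements
filled = [step] * n                            # [2, 2, 2]   ~ Tensor.full((n,), step)
summed = list(itertools.accumulate(filled))    # [2, 4, 6]   ~ .cumsum()
result = [x + (start - step) for x in summed]  # [3, 5, 7]   ~ + (start - step)
assert result == [3, 5, 7]

eye relies on a similar movement-op trick: a (dim, 1) column of ones is padded on the right to (dim, dim+1), flattened to length dim*(dim+1), cropped to the first dim*dim elements, and reshaped to (dim, dim); the extra padded column shifts each 1 one position further right per row, which is exactly the identity matrix.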