arange default dtype to int and zeros/ones default to float (#2769)

This commit is contained in:
chenyu 2023-12-14 17:53:00 -05:00 committed by GitHub
parent 3cf4376ce2
commit 66d9eb10b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 7 deletions

View File

@ -477,8 +477,8 @@ def Resize(X:Tensor, roi=None, scales=None, sizes=None, antialias=0, axes=None,
output_shape = sizes if sizes else [math.floor(x*s) for x,s in zip(X.shape, scales)]
output_shape_ = sizes if sizes else [x*s for x,s in zip(X.shape, scales)]
scales_lol = [os/xs for xs, os in zip(X.shape, output_shape)]
x_out = Tensor.arange(output_shape[-1])
y_out = Tensor.arange(output_shape[-2])
x_out = Tensor.arange(output_shape[-1]).cast(Tensor.default_type)
y_out = Tensor.arange(output_shape[-2]).cast(Tensor.default_type)
if mode == "nearest":
x_out, y_out = _coordinate_transformation(x_out, y_out, output_shape, scales_lol, roi)
x_out = _nearest_mode(x_out, nearest_mode, X.shape[-1])

View File

@ -236,6 +236,37 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
def test_arange(self):
assert Tensor.arange(5).dtype == dtypes.int32
assert Tensor.arange(5.0).dtype == Tensor.default_type
assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16
assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64
assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16
assert Tensor.arange(3, 9, 0.7).dtype == Tensor.default_type
assert Tensor.arange(3, 8.5, 3).dtype == Tensor.default_type
def test_zeros(self):
assert Tensor.zeros(3, 3).dtype == Tensor.default_type
assert Tensor.zeros(3, 3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.zeros(3, 3, dtype= dtypes.int64).dtype == dtypes.int64
def test_ones(self):
assert Tensor.ones(3, 3).dtype == Tensor.default_type
assert Tensor.ones(3, 3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.ones(3, 3, dtype= dtypes.int64).dtype == dtypes.int64
def test_full(self):
assert Tensor.full((3, 3), 3).dtype == dtypes.int
assert Tensor.full((3, 3), 3.0).dtype == Tensor.default_type
assert Tensor.full((3, 3), 3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.full((3, 3), 3, dtype= dtypes.int64).dtype == dtypes.int64
def test_eye(self):
assert Tensor.eye(0).dtype == Tensor.default_type
assert Tensor.eye(3).dtype == Tensor.default_type
assert Tensor.eye(3, dtype= dtypes.float16).dtype == dtypes.float16
assert Tensor.eye(3, dtype= dtypes.int64).dtype == dtypes.int64
core_types = list(DTYPES_DICT.values())
class TestTypePromotion(unittest.TestCase):
@given(st.sampled_from(core_types))

View File

@ -169,22 +169,24 @@ class Tensor:
@staticmethod
def full(shape:Tuple[sint, ...], fill_value, **kwargs):
return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
dtype = kwargs.pop("dtype", Tensor.default_type if isinstance(fill_value, float) else dtypes.int32)
return Tensor(fill_value, dtype=dtype, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs)
@staticmethod
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs)
def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs)
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
dtype = kwargs.pop("dtype", Tensor.default_type if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.int32)
return Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)
@staticmethod
def eye(dim:int, **kwargs):
return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
return Tensor.full((dim,1),1.0,**kwargs).pad((None,(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
def full_like(self, fill_value, **kwargs):
return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)