Add check for negative dimension in view (#5790)

* add check for negative dimension in view

* add negative dim tests

* move check to tensor level

* fix error message

* move check to view create

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
P4ssenger 2024-07-30 14:26:27 -03:00 committed by GitHub
parent 2b7b7591d2
commit 6742a4789a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 0 deletions

View File

@ -1,5 +1,6 @@
import time, math, unittest
import numpy as np
from typing import List, Callable
import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad import Tensor, Device, dtypes
@ -97,6 +98,42 @@ class TestOps(unittest.TestCase):
def test_full(self):
helper_test_op([], lambda: torch.full((45,65), 4, dtype=torch.int32), lambda: Tensor.full((45,65), 4), forward_only=True)
def test_negative_dims(self):
creation_methods: List[Callable[..., Tensor]] = [
Tensor.empty,
Tensor.rand,
Tensor.zeros,
Tensor.ones,
Tensor.randn,
Tensor.randint,
Tensor.normal,
Tensor.uniform,
Tensor.scaled_uniform,
Tensor.glorot_uniform
]
for method in creation_methods:
with self.assertRaises(RuntimeError): method(-3, 2)
with self.assertRaises(RuntimeError): method((2, -3))
with self.assertRaises(RuntimeError): method((2, -3, 0))
def test_negative_dims_full(self):
with self.assertRaises(RuntimeError): Tensor.full(-3, 2)
with self.assertRaises(RuntimeError): Tensor.full((2, -3), 4)
with self.assertRaises(RuntimeError): Tensor.full((2, -3, 0), 4)
def test_negative_dims_eye(self):
with self.assertRaises(RuntimeError): Tensor.eye(-3, 3)
with self.assertRaises(AssertionError): Tensor.eye(3, -3)
with self.assertRaises(RuntimeError): Tensor.eye(-3, -3)
def test_negative_dims_kaiming(self):
creation_methods = [Tensor.kaiming_uniform, Tensor.kaiming_normal]
for method in creation_methods:
with self.assertRaises(RuntimeError): method(-3, 3)
with self.assertRaises(ValueError): method((-3, 3), 3)
with self.assertRaises(ValueError): method((-3, -3), 3)
def test_zeros(self):
helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)

View File

@ -99,6 +99,7 @@ class View:
@staticmethod
@functools.lru_cache(maxsize=None)
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
if not all(s >= 0 for s in shape): raise RuntimeError(f"Trying to create View with negative dimension: {shape=}")
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
# canonicalize 0 in shape
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)