mirror of https://github.com/commaai/tinygrad.git
Add check for negative dimension in view (#5790)
* add check for negative dimension in view * add negative dim tests * move check to tensor level * fix error message * move check to view create --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
2b7b7591d2
commit
6742a4789a
|
@ -1,5 +1,6 @@
|
|||
import time, math, unittest
|
||||
import numpy as np
|
||||
from typing import List, Callable
|
||||
import torch
|
||||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
|
@ -97,6 +98,42 @@ class TestOps(unittest.TestCase):
|
|||
def test_full(self):
|
||||
helper_test_op([], lambda: torch.full((45,65), 4, dtype=torch.int32), lambda: Tensor.full((45,65), 4), forward_only=True)
|
||||
|
||||
def test_negative_dims(self):
|
||||
creation_methods: List[Callable[..., Tensor]] = [
|
||||
Tensor.empty,
|
||||
Tensor.rand,
|
||||
Tensor.zeros,
|
||||
Tensor.ones,
|
||||
Tensor.randn,
|
||||
Tensor.randint,
|
||||
Tensor.normal,
|
||||
Tensor.uniform,
|
||||
Tensor.scaled_uniform,
|
||||
Tensor.glorot_uniform
|
||||
]
|
||||
|
||||
for method in creation_methods:
|
||||
with self.assertRaises(RuntimeError): method(-3, 2)
|
||||
with self.assertRaises(RuntimeError): method((2, -3))
|
||||
with self.assertRaises(RuntimeError): method((2, -3, 0))
|
||||
|
||||
def test_negative_dims_full(self):
|
||||
with self.assertRaises(RuntimeError): Tensor.full(-3, 2)
|
||||
with self.assertRaises(RuntimeError): Tensor.full((2, -3), 4)
|
||||
with self.assertRaises(RuntimeError): Tensor.full((2, -3, 0), 4)
|
||||
|
||||
def test_negative_dims_eye(self):
|
||||
with self.assertRaises(RuntimeError): Tensor.eye(-3, 3)
|
||||
with self.assertRaises(AssertionError): Tensor.eye(3, -3)
|
||||
with self.assertRaises(RuntimeError): Tensor.eye(-3, -3)
|
||||
|
||||
def test_negative_dims_kaiming(self):
|
||||
creation_methods = [Tensor.kaiming_uniform, Tensor.kaiming_normal]
|
||||
for method in creation_methods:
|
||||
with self.assertRaises(RuntimeError): method(-3, 3)
|
||||
with self.assertRaises(ValueError): method((-3, 3), 3)
|
||||
with self.assertRaises(ValueError): method((-3, -3), 3)
|
||||
|
||||
def test_zeros(self):
|
||||
helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
|
||||
helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
|
||||
|
|
|
@ -99,6 +99,7 @@ class View:
|
|||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None):
|
||||
if not all(s >= 0 for s in shape): raise RuntimeError(f"Trying to create View with negative dimension: {shape=}")
|
||||
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
|
||||
# canonicalize 0 in shape
|
||||
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
|
||||
|
|
Loading…
Reference in New Issue