diff --git a/test/test_ops.py b/test/test_ops.py index 325b5f41..06133cca 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,5 +1,6 @@ import time, math, unittest import numpy as np +from typing import List, Callable import torch from tinygrad.helpers import getenv, IMAGE, DEBUG, CI from tinygrad import Tensor, Device, dtypes @@ -97,6 +98,42 @@ class TestOps(unittest.TestCase): def test_full(self): helper_test_op([], lambda: torch.full((45,65), 4, dtype=torch.int32), lambda: Tensor.full((45,65), 4), forward_only=True) + def test_negative_dims(self): + creation_methods: List[Callable[..., Tensor]] = [ + Tensor.empty, + Tensor.rand, + Tensor.zeros, + Tensor.ones, + Tensor.randn, + Tensor.randint, + Tensor.normal, + Tensor.uniform, + Tensor.scaled_uniform, + Tensor.glorot_uniform + ] + + for method in creation_methods: + with self.assertRaises(RuntimeError): method(-3, 2) + with self.assertRaises(RuntimeError): method((2, -3)) + with self.assertRaises(RuntimeError): method((2, -3, 0)) + + def test_negative_dims_full(self): + with self.assertRaises(RuntimeError): Tensor.full(-3, 2) + with self.assertRaises(RuntimeError): Tensor.full((2, -3), 4) + with self.assertRaises(RuntimeError): Tensor.full((2, -3, 0), 4) + + def test_negative_dims_eye(self): + with self.assertRaises(RuntimeError): Tensor.eye(-3, 3) + with self.assertRaises(AssertionError): Tensor.eye(3, -3) + with self.assertRaises(RuntimeError): Tensor.eye(-3, -3) + + def test_negative_dims_kaiming(self): + creation_methods = [Tensor.kaiming_uniform, Tensor.kaiming_normal] + for method in creation_methods: + with self.assertRaises(RuntimeError): method(-3, 3) + with self.assertRaises(ValueError): method((-3, 3), 3) + with self.assertRaises(ValueError): method((-3, -3), 3) + def test_zeros(self): helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True) helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True) diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index a194b3a4..4e88c145 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -99,6 +99,7 @@ class View: @staticmethod @functools.lru_cache(maxsize=None) def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offset:sint=0, mask:Optional[Tuple[Tuple[sint, sint], ...]]=None): + if not all(s >= 0 for s in shape): raise RuntimeError(f"Trying to create View with negative dimension: {shape=}") strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape) # canonicalize 0 in shape if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)