Tensor.uniform with dtype=int bug fix (#1593)

This commit is contained in:
Jordan Wright 2023-08-26 01:59:53 -04:00 committed by GitHub
parent f702a8f497
commit 25be7f745d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import torch
from tinygrad.tensor import Tensor
import tinygrad.nn as nn
import pytest
from tinygrad.helpers import dtypes
pytestmark = pytest.mark.webgpu
@ -49,6 +50,13 @@ def normal_test(func, shape=(20, 23), alpha=0.05):
y = np.random.randn(*shape).flatten()
return kstest(x, y) >= alpha
def equal_distrib_ints(tiny_func, numpy_func, shape=(20, 23), low=-100, high=100, dtype=dtypes.int32, alpha=0.05):
Tensor.manual_seed(1337)
np.random.seed(1337)
x = tiny_func(*shape, low=low, high=high, dtype=dtype).cpu().numpy().flatten()
y = numpy_func(shape).flatten()
return kstest(x, y) >= alpha
def equal_distribution(tiny_func, torch_func, numpy_func=None, shape=(20, 23), alpha=0.05):
Tensor.manual_seed(1337)
torch.manual_seed(1337)
@ -74,6 +82,7 @@ class TestRandomness(unittest.TestCase):
def test_uniform(self):
self.assertFalse(normal_test(Tensor.uniform))
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1), lambda x: np.random.uniform(low=-1, high=1, size=x)))
self.assertTrue(equal_distrib_ints(Tensor.uniform, lambda x: np.random.randint(low=-100, high=100, size=x)))
def test_scaled_uniform(self):
self.assertFalse(normal_test(Tensor.scaled_uniform))

View File

@ -179,7 +179,9 @@ class Tensor:
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@staticmethod
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
dtype = kwargs.pop("dtype", Tensor.default_type)
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)