mirror of https://github.com/commaai/tinygrad.git
Add Tensor.multinomial (#2295)
* add Tensor.multinomial only with replacement
* add support for 2D input in Tensor.multinomial
* fix multinomial output shape
* allow passing replacement=False to Tensor.multinomial when num_samples=1
* improve tests for Tensor.multinomial
* fix edge case in Tensor.multinomial
* Tensor.multinomial no more staticmethod
This commit is contained in:
parent cb6cfcc8f8
commit b8d460d203
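For orientation, a quick usage sketch of the API this commit adds, based on the method signature and return type visible in the diff below (the input weights do not need to sum to 1, since the implementation normalizes the cumulative sum):

from tinygrad.tensor import Tensor

# 1D weights: draw 5 indices with replacement; result has shape (5,) and dtype int32
probs = Tensor([0.1, 0.3, 0.6])
samples = probs.multinomial(5, replacement=True)

# 2D weights: one row of samples per input row; result has shape (2, 4)
batched = Tensor([[0.2, 0.8], [0.5, 0.5]])
batched_samples = batched.multinomial(4, replacement=True)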
@@ -96,6 +96,33 @@ class TestRandomness(unittest.TestCase):
    for shape in [(128, 64, 3, 3), (20, 24)]:
      self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))

  def test_multinomial(self):
    def _check_with_torch(p, num_samples, replacement):
      tiny_res = Tensor(p).multinomial(num_samples, replacement=replacement)
      torch_res = torch.tensor(p).multinomial(num_samples, replacement=replacement)
      self.assertEqual(tiny_res.shape, torch_res.shape)
      if torch_res.ndim == 1:
        tiny_res = tiny_res.unsqueeze(0)
        torch_res = torch_res.unsqueeze(0)
      for i in range(torch_res.shape[0]):
        self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i]))
    _check_with_torch(p=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True)
    _check_with_torch(p=[[0.2, 0.8]], num_samples=2000, replacement=True)  # 2D but only 1 row
    _check_with_torch(p=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True)
    # no-replacement isn't supported, unless taking only one sample
    p = [0.1, 0.9]
    self.assertRaises(AssertionError, lambda: Tensor(p).multinomial(100, replacement=False))
    tiny_samples = [Tensor(p).multinomial(1, replacement=False).numpy().item() for _ in range(1000)]
    torch_samples = [torch.tensor(p).multinomial(1, replacement=False).item() for _ in range(1000)]
    self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))

  def test_multinomial_counterexample(self):
    tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)
    torch_res = torch.tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)
    self.assertTrue(equal_distribution(lambda *_: tiny_res, lambda _: torch_res))
    torch_res = torch.tensor([0.2, 0.7, 0.1]).multinomial(2000, replacement=True)
    self.assertFalse(equal_distribution(lambda *_: tiny_res, lambda _: torch_res))

  def test_conv2d_init(self):
    params = (128, 256, (3,3))
    assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach())
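The tests lean on an equal_distribution helper defined earlier in this test file, which is not part of this hunk. Purely as an illustration of the kind of check it performs, a hypothetical stand-in that compares two empirical samples with a two-sample Kolmogorov-Smirnov test could look like this (the repo's actual helper may differ):

import numpy as np
from scipy.stats import ks_2samp

def same_distribution(sample_a, sample_b, alpha=0.05):
  # Returns True when we fail to reject the null hypothesis that both
  # samples were drawn from the same underlying distribution.
  a = np.asarray(sample_a, dtype=np.float64).flatten()
  b = np.asarray(sample_b, dtype=np.float64).flatten()
  return ks_2samp(a, b).pvalue >= alpha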
@@ -216,6 +216,19 @@ class Tensor:
    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
    return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

  def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor:
    assert self.ndim <= 2, "p must be 1 or 2 dim"
    assert replacement or num_samples == 1, "supported only with replacement"
    p = self.unsqueeze(0) if self.ndim == 1 else self
    cdf = p.cumsum(1)
    cdf /= cdf[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, p.shape[0], 1)
    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2)
    indices = indices.permute((1, 0))
    if self.ndim == 1:
      indices = indices.squeeze(0)
    return indices.cast(dtypes.int32)

  # ***** toposort and backward pass *****
  def deepwalk(self):
    def _deepwalk(node, visited, nodes):
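The new method samples by inverting the cumulative distribution: it normalizes the cumulative sum of the weights, draws uniform values, and counts how many CDF entries each uniform value meets or exceeds, which is the sampled index. A minimal NumPy sketch of the same idea, for illustration only (not tinygrad code):

import numpy as np

def multinomial_with_replacement(weights, num_samples, rng=None):
  # weights: 1D array of non-negative values; they do not need to sum to 1
  if rng is None: rng = np.random.default_rng()
  w = np.asarray(weights, dtype=np.float64)
  cdf = np.cumsum(w)
  cdf /= cdf[-1]                       # normalize so the last entry is 1.0
  u = rng.random(num_samples)          # uniform draws in [0, 1)
  # each sampled index is the count of CDF entries the uniform value meets or exceeds
  return (u[:, None] >= cdf[None, :]).sum(axis=1).astype(np.int32)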