Add Tensor.multinomial (#2295)

* add Tensor.multinomial only with replacement

* add support for 2D input in Tensor.multinomial

* fix multinomial output shape

* allow passing replacement=False to Tensor.multinomial when num_samples=1

* improve tests for Tensor.multinomial

* fix edge case in Tensor.multinomial

* Tensor.multinomial is no longer a staticmethod
This commit is contained in:
Marcello Fuschi 2023-11-15 20:38:39 +01:00 committed by GitHub
parent cb6cfcc8f8
commit b8d460d203
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 0 deletions

View File

@ -96,6 +96,33 @@ class TestRandomness(unittest.TestCase):
for shape in [(128, 64, 3, 3), (20, 24)]:
self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
def test_multinomial(self):
    """Check Tensor.multinomial against torch's multinomial on several weight layouts.

    Shapes must match exactly; the drawn indices are compared per row with
    the equal_distribution helper defined elsewhere in this test module.
    """

    def _check_with_torch(p, num_samples, replacement):
        # Draw with identical arguments from both libraries.
        ours = Tensor(p).multinomial(num_samples, replacement=replacement)
        theirs = torch.tensor(p).multinomial(num_samples, replacement=replacement)
        self.assertEqual(ours.shape, theirs.shape)
        # Lift 1-D results to a single row so the per-row loop covers both layouts.
        if theirs.ndim == 1:
            ours, theirs = ours.unsqueeze(0), theirs.unsqueeze(0)
        for row in range(theirs.shape[0]):
            matches = equal_distribution(lambda *_: ours[row], lambda _: theirs[row])
            self.assertTrue(matches)

    with_replacement_cases = [
        [0.231, 0., 1., 0.5],
        [[0.2, 0.8]],  # 2D but only 1 row
        [[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]],
    ]
    for weights in with_replacement_cases:
        _check_with_torch(p=weights, num_samples=2000, replacement=True)

    # no-replacement isn't supported, unless taking only one sample
    p = [0.1, 0.9]
    with self.assertRaises(AssertionError):
        Tensor(p).multinomial(100, replacement=False)

    # Single draws without replacement are allowed: collect many of them and
    # compare the collected indices between the two libraries.
    n_draws = 1000
    tiny_samples = [
        Tensor(p).multinomial(1, replacement=False).numpy().item()
        for _ in range(n_draws)
    ]
    torch_samples = [
        torch.tensor(p).multinomial(1, replacement=False).item()
        for _ in range(n_draws)
    ]
    matches = equal_distribution(
        lambda *_: Tensor(tiny_samples),
        lambda _: torch.tensor(torch_samples),
    )
    self.assertTrue(matches)
def test_multinomial_counterexample(self):
    """Check that the distribution comparison can fail, not only pass.

    Samples drawn from the same weights must compare as equal, while torch
    samples drawn from different weights must compare as unequal.
    """
    weights = [0.3, 0.6, 0.1]
    other_weights = [0.2, 0.7, 0.1]
    n_samples = 2000

    tiny_draws = Tensor(weights).multinomial(n_samples, replacement=True)

    # Same weights on the torch side: expected to match.
    torch_same = torch.tensor(weights).multinomial(n_samples, replacement=True)
    self.assertTrue(equal_distribution(lambda *_: tiny_draws, lambda _: torch_same))

    # Different weights on the torch side: expected NOT to match.
    torch_other = torch.tensor(other_weights).multinomial(n_samples, replacement=True)
    self.assertFalse(equal_distribution(lambda *_: tiny_draws, lambda _: torch_other))
def test_conv2d_init(self):
params = (128, 256, (3,3))
assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach())

View File

@ -216,6 +216,19 @@ class Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor:
    """Sample category indices using the weights stored in this tensor.

    The tensor is read as one weight vector (1-D) or one weight vector per
    row (2-D). Weights do not need to sum to one: each row's cumulative sum
    is divided by its final entry before sampling. A row whose weights sum
    to zero therefore divides by zero during that normalisation.

    Args:
        num_samples: number of indices to draw per row; must be positive.
        replacement: whether draws are made with replacement. Drawing
            without replacement is only supported when num_samples == 1.

    Returns:
        An int32 tensor of indices with shape (num_samples,) for 1-D input
        or (rows, num_samples) for 2-D input.

    Raises:
        AssertionError: if the tensor is not 1-D or 2-D, if num_samples is
            not positive, or if replacement is False with num_samples != 1.
    """
    # A 0-dim tensor has no category axis; reject it here rather than letting
    # cumsum(1) fail later with an unrelated error.
    assert 1 <= self.ndim <= 2, "p must be 1 or 2 dim"
    assert num_samples > 0, "num_samples must be positive"
    assert replacement or num_samples == 1, "supported only with replacement"

    # Work on a (rows, categories) layout regardless of the input rank.
    p = self.unsqueeze(0) if self.ndim == 1 else self

    # Normalised cumulative weights per row; the last column becomes 1.
    cumulative = p.cumsum(1)
    cdf = cumulative / cumulative[:, -1].unsqueeze(1)

    # One uniform draw per (sample, row). The sampled index is the number of
    # cdf entries that are <= the draw. Using >= (not >) means a draw equal to
    # a cdf entry moves past that category, so a leading zero-weight category
    # is skipped even when the draw is exactly 0.
    # NOTE(review): assumes Tensor.rand yields values in [0, 1) -- confirm.
    unif_samples = Tensor.rand(num_samples, p.shape[0], 1)
    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2)

    # (num_samples, rows) -> (rows, num_samples).
    indices = indices.permute((1, 0))
    if self.ndim == 1:
        indices = indices.squeeze(0)
    return indices.cast(dtypes.int32)
# ***** toposort and backward pass *****
def deepwalk(self):
def _deepwalk(node, visited, nodes):