add LSTMCell to nn (#6080)

* add LSTMCell to nn

* lstmcell works with no input on first

* fix no bias 0

* simpler
George Hotz 2024-08-14 12:08:42 -07:00 committed by GitHub
parent 6b3112d525
commit 64563abc90
3 changed files with 60 additions and 1 deletion

@@ -12,6 +12,7 @@
 ::: tinygrad.nn.LayerNorm2d
 ::: tinygrad.nn.RMSNorm
 ::: tinygrad.nn.Embedding
+::: tinygrad.nn.LSTMCell
 
 ## Optimizers

@@ -6,7 +6,7 @@ from tinygrad import Tensor, Device, TinyJit
 from tinygrad.helpers import CI, Context
 from tinygrad.ops import MetaOps
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
-from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
+from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
 from tinygrad.nn.state import load_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
@@ -481,5 +481,39 @@ class TestNN(unittest.TestCase):
     np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
     np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
 
+  def test_lstm_cell(self):
+    layer = LSTMCell(32, 16)
+    with torch.no_grad():
+      torch_layer = torch.nn.LSTMCell(32, 16)
+      layer.weight_hh.assign(torch_layer.weight_hh.numpy())
+      layer.weight_ih.assign(torch_layer.weight_ih.numpy())
+      layer.bias_hh.assign(torch_layer.bias_hh.numpy())
+      layer.bias_ih.assign(torch_layer.bias_ih.numpy())
+
+      inp = Tensor.randn(1, 32)
+      out_h, out_c = layer(inp)
+      torch_out_h, torch_out_c = torch_layer(torch.tensor(inp.numpy()))
+      np.testing.assert_allclose(out_h.numpy(), torch_out_h.numpy(), atol=1e-6)
+      np.testing.assert_allclose(out_c.numpy(), torch_out_c.numpy(), atol=1e-6)
+
+      out_h, out_c = layer(inp, (out_h, out_c))
+      torch_out_h, torch_out_c = torch_layer(torch.tensor(inp.numpy()), (torch_out_h, torch_out_c))
+      np.testing.assert_allclose(out_h.numpy(), torch_out_h.numpy(), atol=1e-6)
+      np.testing.assert_allclose(out_c.numpy(), torch_out_c.numpy(), atol=1e-6)
+
+  def test_lstm_cell_no_bias(self):
+    layer = LSTMCell(32, 16, bias=False)
+    inp = Tensor.randn(1, 32)
+    out_h, out_c = layer(inp)
+    out_h.realize()
+    out_c.realize()
+    h = Tensor.randn(1, 16)
+    c = Tensor.randn(1, 16)
+    out_h, out_c = layer(inp, (h, c))
+    out_h.realize()
+    out_c.realize()
+    assert layer.bias_hh is None
+    assert layer.bias_ih is None
+
 if __name__ == '__main__':
   unittest.main()
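The direct `assign` of the torch parameters in `test_lstm_cell` works because the new `LSTMCell` uses the same parameter layout as `torch.nn.LSTMCell`: `weight_ih` and `weight_hh` stack the four gate weight matrices into one `(4*hidden_size, ...)` tensor, in input/forget/cell/output order, and the biases follow the same packing. A minimal sketch of that layout (the numpy names below are illustrative only, not part of the diff):

```python
import numpy as np

input_size, hidden_size = 32, 16
# both cells store weight_ih with shape (4*hidden_size, input_size);
# the rows are stacked in gate order: input (i), forget (f), cell (g), output (o)
weight_ih = np.random.randn(4 * hidden_size, input_size).astype(np.float32)
w_i, w_f, w_g, w_o = np.split(weight_ih, 4, axis=0)  # each block is (hidden_size, input_size)
assert w_i.shape == (hidden_size, input_size)
```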

@@ -319,3 +319,27 @@ class Embedding:
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
     return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
+
+class LSTMCell:
+  """
+  A long short-term memory (LSTM) cell.
+
+  Args:
+    input_size: The number of expected features in the input `x`
+    hidden_size: The number of features in the hidden state `h`
+    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
+  """
+  def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
+    stdv = 1.0 / math.sqrt(hidden_size)
+    self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
+    self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
+    self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
+
+  def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
+    if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
+    gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
+    i, f, g, o = gates.chunk(4, dim=1)
+    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
+    new_c = f * hc[1] + i * g
+    new_h = o * new_c.tanh()
+    return (new_h.contiguous(), new_c.contiguous())
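For reference, here is a minimal usage sketch of the new cell, unrolling it over a sequence; the sequence shape and the Python loop are illustrative, not part of the diff:

```python
from tinygrad import Tensor
from tinygrad.nn import LSTMCell

cell = LSTMCell(input_size=32, hidden_size=16)
seq = Tensor.randn(10, 1, 32)   # (seq_len, batch, input_size)

hc = None                       # on the first step the hidden and cell state default to zeros
for t in range(seq.shape[0]):
  hc = cell(seq[t], hc)         # returns the (h, c) tuple that feeds the next step
h, c = hc
print(h.shape, c.shape)         # (1, 16) (1, 16)
```

Since the parameter names and shapes mirror `torch.nn.LSTMCell`, pretrained torch weights can be loaded the same way the test above does with `assign`.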