mirror of https://github.com/commaai/tinygrad.git
add LSTMCell to nn (#6080)
* add LSTMCell to nn
* lstmcell works with no input on first
* fix no bias 0
* simpler
This commit is contained in:
parent 6b3112d525
commit 64563abc90
@@ -12,6 +12,7 @@
 ::: tinygrad.nn.LayerNorm2d
 ::: tinygrad.nn.RMSNorm
 ::: tinygrad.nn.Embedding
+::: tinygrad.nn.LSTMCell
 
 ## Optimizers
 
@@ -6,7 +6,7 @@ from tinygrad import Tensor, Device, TinyJit
 from tinygrad.helpers import CI, Context
 from tinygrad.ops import MetaOps
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
-from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
+from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
 from tinygrad.nn.state import load_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
@@ -481,5 +481,39 @@ class TestNN(unittest.TestCase):
     np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
     np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
 
+  def test_lstm_cell(self):
+    layer = LSTMCell(32, 16)
+    with torch.no_grad():
+      torch_layer = torch.nn.LSTMCell(32, 16)
+      layer.weight_hh.assign(torch_layer.weight_hh.numpy())
+      layer.weight_ih.assign(torch_layer.weight_ih.numpy())
+      layer.bias_hh.assign(torch_layer.bias_hh.numpy())
+      layer.bias_ih.assign(torch_layer.bias_ih.numpy())
+
+      inp = Tensor.randn(1, 32)
+      out_h, out_c = layer(inp)
+      torch_out_h, torch_out_c = torch_layer(torch.tensor(inp.numpy()))
+      np.testing.assert_allclose(out_h.numpy(), torch_out_h.numpy(), atol=1e-6)
+      np.testing.assert_allclose(out_c.numpy(), torch_out_c.numpy(), atol=1e-6)
+
+      out_h, out_c = layer(inp, (out_h, out_c))
+      torch_out_h, torch_out_c = torch_layer(torch.tensor(inp.numpy()), (torch_out_h, torch_out_c))
+      np.testing.assert_allclose(out_h.numpy(), torch_out_h.numpy(), atol=1e-6)
+      np.testing.assert_allclose(out_c.numpy(), torch_out_c.numpy(), atol=1e-6)
+
+  def test_lstm_cell_no_bias(self):
+    layer = LSTMCell(32, 16, bias=False)
+    inp = Tensor.randn(1, 32)
+    out_h, out_c = layer(inp)
+    out_h.realize()
+    out_c.realize()
+    h = Tensor.randn(1, 16)
+    c = Tensor.randn(1, 16)
+    out_h, out_c = layer(inp, (h, c))
+    out_h.realize()
+    out_c.realize()
+    assert layer.bias_hh is None
+    assert layer.bias_ih is None
+
 if __name__ == '__main__':
   unittest.main()
@@ -319,3 +319,27 @@ class Embedding:
     if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
     arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.reshape(weight_shp).expand(big_shp)
     return (arange == idx).mul(vals).sum(2, acc_dtype=vals.dtype)
+
+class LSTMCell:
+  """
+  A long short-term memory (LSTM) cell.
+
+  Args:
+    input_size: The number of expected features in the input `x`
+    hidden_size: The number of features in the hidden state `h`
+    bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`
+  """
+  def __init__(self, input_size:int, hidden_size:int, bias:bool=True):
+    stdv = 1.0 / math.sqrt(hidden_size)
+    self.weight_ih = Tensor.uniform(hidden_size*4, input_size, low=-stdv, high=stdv)
+    self.weight_hh = Tensor.uniform(hidden_size*4, hidden_size, low=-stdv, high=stdv)
+    self.bias_ih, self.bias_hh = (Tensor.zeros(hidden_size*4), Tensor.zeros(hidden_size*4)) if bias else (None, None)
+
+  def __call__(self, x:Tensor, hc:Optional[Tuple[Tensor, Tensor]]=None) -> Tuple[Tensor, Tensor]:
+    if hc is None: hc = (Tensor.zeros(x.size(0), self.weight_hh.size(1), dtype=x.dtype, device=x.device),)*2
+    gates = x.linear(self.weight_ih.T, self.bias_ih) + hc[0].linear(self.weight_hh.T, self.bias_hh)
+    i, f, g, o = gates.chunk(4, dim=1)
+    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
+    new_c = f * hc[1] + i * g
+    new_h = o * new_c.tanh()
+    return (new_h.contiguous(), new_c.contiguous())
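For readers skimming the hunk, __call__ above is the textbook LSTM cell update; restated in equation form (a direct transcription of the code, nothing beyond it):

\[
\begin{aligned}
(i,\,f,\,g,\,o) &= \mathrm{chunk}_4\big(W_{ih}\,x_t + b_{ih} + W_{hh}\,h_{t-1} + b_{hh}\big) \\
c_t &= \sigma(f)\odot c_{t-1} + \sigma(i)\odot\tanh(g) \\
h_t &= \sigma(o)\odot\tanh(c_t)
\end{aligned}
\]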
|
||||
|
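A minimal usage sketch of the new module follows (not part of the commit; the batch size, sequence length, and the unroll loop are illustrative assumptions):

from tinygrad import Tensor
from tinygrad.nn import LSTMCell

cell = LSTMCell(32, 16)        # 32 input features -> 16 hidden units
x = Tensor.randn(2, 10, 32)    # assumed layout: (batch=2, timesteps=10, features=32)

hc = None                      # first call builds zero-initialized (h, c) internally
for t in range(10):
  hc = cell(x[:, t], hc)       # each step returns a (hidden_state, cell_state) tuple
h_last, c_last = hc
print(h_last.shape, c_last.shape)  # (2, 16) (2, 16)

Because each call returns fresh (h, c) Tensors, the same cell instance can be unrolled over a sequence; only the final states are kept here.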