from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.nn import Linear, Embedding
import numpy as np
from extra.utils import download_file
from pathlib import Path

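# RNN Transducer (RNN-T) for speech recognition: an Encoder over audio features, a
# Prediction network over previously emitted labels, and a Joint network that combines
# both into per-(frame, label) logits. Weights are loaded from a reference PyTorch
# checkpoint in load_from_pretrained.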
class RNNT:
  def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
    self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
    self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
    self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)

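  # Jitted forward pass: encode the features x, run the prediction network over the
  # label sequence y, and return the joint network logits.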
  @TinyJit
  def __call__(self, x, y, hc=None):
    f, _ = self.encoder(x, None)
    g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
    out = self.joint(f, g)
    return out.realize()

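  # Encode a batch of utterances and greedily decode each one independently.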
  def decode(self, x, x_lens):
    logits, logit_lens = self.encoder(x, x_lens)
    outputs = []
    for b in range(logits.shape[0]):
      inseq = logits[b, :, :].unsqueeze(1)
      logit_len = logit_lens[b]
      seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item()))
      outputs.append(seq)
    return outputs

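  # RNN-T greedy decoding for a single utterance: for every encoder frame, keep emitting
  # the argmax label (blank index is 28) until blank is predicted or 30 symbols have been
  # added for that frame. The prediction state hc only advances on non-blank emissions.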
  def _greedy_decode(self, logits, logit_len):
    hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
    labels = []
    label = Tensor.zeros(1, 1, requires_grad=False)
    mask = Tensor.zeros(1, requires_grad=False)
    for time_idx in range(logit_len):
      logit = logits[time_idx, :, :].unsqueeze(0)
      not_blank = True
      added = 0
      while not_blank and added < 30:
        if len(labels) > 0:
          mask = (mask + 1).clip(0, 1)
          label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
        jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
        k = np.argmax(jhc[0, 0, :29].numpy(), axis=0)
        not_blank = k != 28
        if not_blank:
          labels.append(k)
          hc = jhc[:, :, 29:] + 1 - 1
        added += 1
    return labels

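  # Jitted single decode step: run the prediction network on the previous label, join it
  # with the current encoder frame, then pad the (1, 1, 29) joint logits and concatenate
  # the new hidden state so the jit returns everything as one tensor
  # (jhc[0, 0, :29] are the logits, jhc[:, :, 29:] is the updated hc).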
  @TinyJit
  def _pred_joint(self, logit, label, hc, mask):
    g, hc = self.prediction(label, hc, mask)
    j = self.joint(logit, g)[0]
    j = j.pad(((0, 1), (0, 1), (0, 0)))
    out = j.cat(hc, dim=2)
    return out.realize()

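  # Download the reference PyTorch checkpoint and copy its weights into the tinygrad tensors.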
  def load_from_pretrained(self):
    fn = Path(__file__).parent.parent / "weights/rnnt.pt"
    download_file("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)

    import torch
    with open(fn, "rb") as f:
      state_dict = torch.load(f, map_location="cpu")["state_dict"]

    # encoder
    for i in range(2):
      self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy())
      self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy())
      self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy())
      self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy())
    for i in range(3):
      self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy())
      self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy())
      self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy())
      self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy())

    # prediction
    self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
    for i in range(2):
      self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy())
      self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy())
      self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy())
      self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy())

    # joint
    self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
    self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
    self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
    self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())

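# A single LSTM cell. The hidden state h and cell state c are packed into one tensor
# along dim 0 (h in the first x.shape[0] rows, c in the rest), so each step takes and
# returns a single hc tensor.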
class LSTMCell:
  def __init__(self, input_size, hidden_size, dropout):
    self.dropout = dropout

    self.weights_ih = Tensor.uniform(hidden_size * 4, input_size)
    self.bias_ih = Tensor.uniform(hidden_size * 4)
    self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size)
    self.bias_hh = Tensor.uniform(hidden_size * 4)

  def __call__(self, x, hc):
    gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh)

    i, f, g, o = gates.chunk(4, 1)
    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()

    c = (f * hc[x.shape[0]:]) + (i * g)
    h = (o * c.tanh()).dropout(self.dropout)

    return Tensor.cat(h, c).realize()

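# A multi-layer LSTM that unrolls the input one timestep at a time. The per-step function
# is wrapped in TinyJit (a fresh wrapper on every call), and the packed hc states of all
# layers are stacked along dim 0. Dropout is applied in every layer except the last.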
class LSTM:
  def __init__(self, input_size, hidden_size, layers, dropout):
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.layers = layers

    self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]

  def __call__(self, x, hc):
    @TinyJit
    def _do_step(x_, hc_):
      return self.do_step(x_, hc_)

    if hc is None:
      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)

    output = None
    for t in range(x.shape[0]):
      hc = _do_step(x[t] + 1 - 1, hc)  # TODO: why do we need to do this?
      if output is None:
        output = hc[-1:, :x.shape[1]]
      else:
        output = output.cat(hc[-1:, :x.shape[1]], dim=0).realize()

    return output, hc

  def do_step(self, x, hc):
    new_hc = [x]
    for i, cell in enumerate(self.cells):
      new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i]))
    return Tensor.stack(new_hc[1:]).realize()

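# Time-axis downsampling: pads the sequence length to a multiple of `factor`, then stacks
# every `factor` consecutive frames into the feature dimension, shortening the time axis
# (and the sequence lengths) by that factor.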
class StackTime:
  def __init__(self, factor):
    self.factor = factor

  def __call__(self, x, x_lens):
    x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
    x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor)
    return x, x_lens / self.factor if x_lens is not None else None

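# Audio encoder: a pre LSTM stack, StackTime downsampling, then a post LSTM stack.
# Takes time-major features (T, B, F) and returns batch-major encodings (B, T', H)
# along with the downsampled lengths.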
class Encoder:
  def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
    self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
    self.stack_time = StackTime(stack_time_factor)
    self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)

  def __call__(self, x, x_lens):
    x, _ = self.pre_rnn(x, None)
    x, x_lens = self.stack_time(x, x_lens)
    x, _ = self.post_rnn(x, None)
    return x.transpose(0, 1), x_lens

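# Prediction (decoder) network: embeds the previously emitted labels and runs them through
# an LSTM. The mask m zeroes the embedding for the start-of-sequence step, where no label
# has been emitted yet.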
class Prediction:
  def __init__(self, vocab_size, hidden_size, layers, dropout):
    self.hidden_size = hidden_size

    self.emb = Embedding(vocab_size - 1, hidden_size)
    self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)

  def __call__(self, x, hc, m):
    emb = self.emb(x) * m
    x_, hc = self.rnn(emb.transpose(0, 1), hc)
    return x_.transpose(0, 1), hc

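# Joint network: broadcasts the encoder output f (B, T, H_enc) and prediction output
# g (B, U, H_pred) to (B, T, U, .), concatenates them along the feature axis, and maps
# the result through a two-layer MLP (ReLU + dropout) to per-(frame, label) vocab logits.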
class Joint:
  def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
    self.dropout = dropout

    self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
    self.l2 = Linear(joint_hidden_size, vocab_size)

  def __call__(self, f, g):
    (_, T, H), (B, U, H2) = f.shape, g.shape
    f = f.unsqueeze(2).expand(B, T, U, H)
    g = g.unsqueeze(1).expand(B, T, U, H2)

    inp = f.cat(g, dim=3)
    t = self.l1(inp).relu()
    t = t.dropout(self.dropout)
    return self.l2(t)
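
# Minimal usage sketch (illustrative only, not part of the original file). The dummy
# feature shape (time, batch, input_features) and the lengths tensor below are assumptions
# based on the Encoder/decode signatures above, not a verified preprocessing pipeline.
if __name__ == "__main__":
  model = RNNT()
  model.load_from_pretrained()
  feats = Tensor.randn(16, 1, 240)  # (T, B, input_features) dummy features
  feat_lens = Tensor([16])          # per-utterance feature lengths
  print(model.decode(feats, feat_lens))  # list of decoded label-id sequences, one per batch item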