mirror of https://github.com/commaai/tinygrad.git
don't apply padding on script call (#2585)
* don't apply padding on script call
* no need for new param because batch_size value can be utilized to check
* fixed argument naming
This commit is contained in:
parent
9d7ead84e1
commit
7c427d738c
|
@ -109,21 +109,21 @@ def tts(
|
|||
estimate_max_y_length: bool,
|
||||
text_mapper: TextMapper,
|
||||
model_has_multiple_speakers: bool,
|
||||
batch_size=600,
|
||||
vits_batch_size=1000
|
||||
pad_length=600,
|
||||
vits_pad_length=1000
|
||||
):
|
||||
if model_to_use == "mmts-tts": text_to_synthesize = text_mapper.filter_oov(text_to_synthesize.lower())
|
||||
|
||||
# Convert the input text to a tensor.
|
||||
stn_tst = text_mapper.get_text(text_to_synthesize, hps.data.add_blank, hps.data.text_cleaners)
|
||||
init_shape = stn_tst.shape
|
||||
assert init_shape[0] < batch_size, "text is too long"
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, batch_size - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
assert init_shape[0] < pad_length, "text is too long"
|
||||
x_tst, x_tst_lengths = stn_tst.pad(((0, pad_length - init_shape[0]),), 1).unsqueeze(0), Tensor([init_shape[0]], dtype=dtypes.int64)
|
||||
sid = Tensor([speaker_id], dtype=dtypes.int64) if model_has_multiple_speakers else None
|
||||
|
||||
# Perform inference.
|
||||
audio_tensor = synth.infer(x_tst, x_tst_lengths, sid, noise_scale, length_scale, noise_scale_w, emotion_embedding=emotion_embedding,
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, batch_size=vits_batch_size)[0, 0]
|
||||
max_y_length_estimate_scale=Y_LENGTH_ESTIMATE_SCALARS[model_to_use] if estimate_max_y_length else None, pad_length=vits_pad_length)[0, 0]
|
||||
# Save the audio output.
|
||||
audio_data = (np.clip(audio_tensor.numpy(), -1.0, 1.0) * 32767).astype(np.int16)
|
||||
return audio_data
|
||||
|
|
|
@ -23,14 +23,14 @@ class Synthesizer:
|
|||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) if use_sdp else DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
|
||||
if n_speakers > 1: self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
||||
def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None, batch_size=500):
|
||||
def infer(self, x, x_lengths, sid=None, noise_scale=1.0, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None, max_y_length_estimate_scale=None, pad_length=-1):
|
||||
x, m_p, logs_p, x_mask = self.enc_p.forward(x.realize(), x_lengths.realize(), emotion_embedding.realize() if emotion_embedding is not None else emotion_embedding)
|
||||
g = self.emb_g(sid.reshape(1, 1)).squeeze(1).unsqueeze(-1) if self.n_speakers > 0 else None
|
||||
logw = self.dp.forward(x, x_mask.realize(), g=g.realize(), reverse=self.use_sdp, noise_scale=noise_scale_w if self.use_sdp else 1.0)
|
||||
w_ceil = Tensor.ceil(logw.exp() * x_mask * length_scale)
|
||||
y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64)
|
||||
return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size)
|
||||
def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, batch_size):
|
||||
return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length)
|
||||
def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length):
|
||||
max_y_length = y_lengths.max().numpy() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
|
||||
y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype)
|
||||
attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1)
|
||||
|
@ -38,13 +38,14 @@ class Synthesizer:
|
|||
m_p_2 = attn.squeeze(1).matmul(m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
logs_p_2 = attn.squeeze(1).matmul(logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
z_p = m_p_2 + Tensor.randn(*m_p_2.shape, dtype=m_p_2.dtype) * logs_p_2.exp() * noise_scale
|
||||
# Pad flow forward inputs to enable JIT
|
||||
row_len = y_mask.shape[2]
|
||||
assert batch_size > row_len, "batch size is too small"
|
||||
y_mask = y_mask.pad(((0, 0), (0, 0), (0, batch_size - row_len)), 0).cast(z_p.dtype)
|
||||
# New y_mask tensor to remove sts mask
|
||||
y_mask = Tensor(y_mask.numpy(), device=y_mask.device, dtype=y_mask.dtype, requires_grad=y_mask.requires_grad)
|
||||
z_p = z_p.squeeze(0).pad(((0, 0), (0, batch_size - z_p.shape[2])), 1).unsqueeze(0)
|
||||
if pad_length > -1:
|
||||
# Pad flow forward inputs to enable JIT
|
||||
assert pad_length > row_len, "pad length is too small"
|
||||
y_mask = y_mask.pad(((0, 0), (0, 0), (0, pad_length - row_len)), 0).cast(z_p.dtype)
|
||||
# New y_mask tensor to remove sts mask
|
||||
y_mask = Tensor(y_mask.numpy(), device=y_mask.device, dtype=y_mask.dtype, requires_grad=y_mask.requires_grad)
|
||||
z_p = z_p.squeeze(0).pad(((0, 0), (0, pad_length - z_p.shape[2])), 1).unsqueeze(0)
|
||||
z = self.flow.forward(z_p.realize(), y_mask.realize(), g=g.realize(), reverse=True)
|
||||
result_length = reduce(lambda x, y: x * y, self.dec.upsample_rates, row_len)
|
||||
o = self.dec.forward((z * y_mask)[:, :, :max_len], g=g)[:, :, :result_length]
|
||||
|
|
Loading…
Reference in New Issue