fix examples vits / converstion.py (#4239)

it was passing a const numpy array into Tensor.arange
This commit is contained in:
chenyu 2024-04-20 23:29:12 -04:00 committed by GitHub
parent 31c9d9a228
commit 3f126c7664
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 2 deletions

View File

@ -31,7 +31,7 @@ class Synthesizer:
y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64)
return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length)
def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length):
max_y_length = y_lengths.max().numpy() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
max_y_length = y_lengths.max().item() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype)
attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1)
attn = generate_path(w_ceil, attn_mask)

View File

@ -289,7 +289,7 @@ class Tensor:
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)