mirror of https://github.com/commaai/tinygrad.git
fix examples vits / converstion.py (#4239)
it was passing a const numpy array into Tensor.arange
This commit is contained in:
parent
31c9d9a228
commit
3f126c7664
|
@ -31,7 +31,7 @@ class Synthesizer:
|
|||
y_lengths = Tensor.maximum(w_ceil.sum([1, 2]), 1).cast(dtypes.int64)
|
||||
return self.generate(g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length)
|
||||
def generate(self, g, logs_p, m_p, max_len, max_y_length_estimate_scale, noise_scale, w_ceil, x, x_mask, y_lengths, pad_length):
|
||||
max_y_length = y_lengths.max().numpy() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
|
||||
max_y_length = y_lengths.max().item() if max_y_length_estimate_scale is None else max(15, x.shape[-1]) * max_y_length_estimate_scale
|
||||
y_mask = sequence_mask(y_lengths, max_y_length).unsqueeze(1).cast(x_mask.dtype)
|
||||
attn_mask = x_mask.unsqueeze(2) * y_mask.unsqueeze(-1)
|
||||
attn = generate_path(w_ceil, attn_mask)
|
||||
|
|
|
@ -289,7 +289,7 @@ class Tensor:
|
|||
@staticmethod
|
||||
def arange(start, stop=None, step=1, **kwargs):
|
||||
if stop is None: stop, start = start, 0
|
||||
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported"
|
||||
assert all(isinstance(s, (int, float)) for s in (start, stop, step)), f"symbolic arange not supported {start=}, {stop=}, {step=}"
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int)
|
||||
return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
|
||||
|
||||
|
|
Loading…
Reference in New Issue