remove TODO: remove explicit dtypes after broadcast fix in stable_diffusion (#4241)

this is done
This commit is contained in:
chenyu 2024-04-21 00:31:24 -04:00 committed by GitHub
parent a1940ced77
commit 30fc1ad415
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 2 deletions

View File

@ -250,8 +250,7 @@ class Upsample:
def timestep_embedding(timesteps, dim, max_period=10000):
half = dim // 2
# TODO: remove explicit dtypes after broadcast fix
freqs = (-math.log(max_period) * Tensor.arange(half, dtype=dtypes.float32) / half).exp()
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
args = timesteps * freqs
return Tensor.cat(args.cos(), args.sin()).reshape(1, -1)