docs: showcase remove mnist_gan and add conversation.py (#4757)

fixed both examples, and i think it's better to show conversation
This commit is contained in:
chenyu 2024-05-28 11:09:26 -04:00 committed by GitHub
parent 019f4680e5
commit e614b7c696
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 8 deletions

View File

@ -33,12 +33,6 @@ SMALL=1 python3 examples/whisper.py
## Generative
### Generative Adversarial Networks
Take a look at [mnist_gan.py](https://github.com/tinygrad/tinygrad/tree/master/examples/mnist_gan.py).
![mnist gan by tinygrad](https://github.com/tinygrad/tinygrad/blob/master/docs/showcase/mnist_by_tinygrad.jpg?raw=true)
### Stable Diffusion
```sh
@ -57,3 +51,12 @@ Then you can have a chat with Stacy:
```sh
python3 examples/llama.py
```
### Conversation
Make sure you have espeak installed and `PHONEMIZER_ESPEAK_LIBRARY` set.
Then you can talk to Stacy:
```sh
python3 examples/conversation.py
```

View File

@ -88,6 +88,7 @@ if __name__ == "__main__":
optim_g = optim.Adam(get_parameters(generator),lr=0.0002, b1=0.5) # 0.0002 for equilibrium!
optim_d = optim.Adam(get_parameters(discriminator),lr=0.0002, b1=0.5)
# training loop
Tensor.training = True
for epoch in (t := trange(epochs)):
loss_g, loss_d = 0.0, 0.0
for _ in range(n_steps):

View File

@ -487,7 +487,7 @@ def split(tensor, split_sizes, dim=0): # if split_sizes is an integer, convert
slice_range = [(start, start + size) if j == dim else None for j in range(len(tensor.shape))]
slices.append(slice_range)
start += size
return [tensor.slice(s) for s in slices]
return [tensor._slice(s) for s in slices]
def gather(x, indices, axis):
indices = (indices < 0).where(indices + x.shape[axis], indices).transpose(0, axis)
permute_args = list(range(x.ndim))

View File

@ -4,7 +4,7 @@ if "THREEFRY" not in os.environ: os.environ["THREEFRY"] = "1"
from tinygrad import Tensor, GlobalCounters
for N in [10, 100, 1_000, 10_000, 100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]:
for N in [10_000_000, 100_000_000, 1_000_000_000]:
GlobalCounters.reset()
Tensor.rand(N).realize()
print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")