fp16 support in stable diffusion

George Hotz 2023-08-20 05:36:25 +00:00
parent ad7d26c393
commit b9feb1b743
1 changed file with 6 additions and 1 deletion


@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes, GlobalCounters
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
-from tinygrad.state import torch_load, load_state_dict
+from tinygrad.state import torch_load, load_state_dict, get_state_dict
 
 class AttnBlock:
   def __init__(self, in_channels):
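
For context, a minimal sketch of what the newly imported get_state_dict provides, using a hypothetical Toy class that is not part of this commit: it walks a model object's attributes and returns a flat {name: Tensor} mapping, which is what the fp16 path further down iterates over.

from tinygrad.nn import Linear
from tinygrad.state import get_state_dict

# hypothetical toy model, only to show the shape of get_state_dict's output
class Toy:
  def __init__(self):
    self.proj = Linear(8, 4)

for name, t in get_state_dict(Toy()).items():
  print(name, t.shape)  # e.g. "proj.weight" (4, 8), "proj.bias" (4,)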
@@ -564,6 +564,7 @@ if __name__ == "__main__":
   parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
   parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename")
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
+  parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
   args = parser.parse_args()
 
   Tensor.no_grad = True
@@ -573,6 +574,10 @@ if __name__ == "__main__":
   download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
   load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
 
+  if args.fp16:
+    for l in get_state_dict(model).values():
+      l.assign(l.cast(dtypes.float16).realize())
+
   # run through CLIP to get context
   tokenizer = ClipTokenizer()
   prompt = Tensor([tokenizer.encode(args.prompt)])
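
Taken on its own, the new --fp16 path amounts to casting every tensor in the model's state dict to float16 in place. A minimal sketch of that pattern, assuming the tinygrad API behaves as used in the diff above (TinyModel and its single weight are hypothetical stand-ins for the real model):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad.state import get_state_dict

# hypothetical stand-in for the full stable diffusion model
class TinyModel:
  def __init__(self):
    self.w = Tensor.ones(2, 2)

model = TinyModel()
# same pattern as the diff: cast each weight to float16 and realize it
for t in get_state_dict(model).values():
  t.assign(t.cast(dtypes.float16).realize())

assert all(t.dtype == dtypes.float16 for t in get_state_dict(model).values())

Assuming the patched file is the stable diffusion example script (its path is not shown in this view), the flag is passed on the command line as --fp16 alongside the existing --prompt and --out options, halving the memory taken by the weights at some cost in numerical precision.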