mirror of https://github.com/commaai/tinygrad.git
fp16 support in stable diffusion
This commit is contained in:
parent
ad7d26c393
commit
b9feb1b743
|
@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
from tinygrad.state import torch_load, load_state_dict, get_state_dict
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
|
@ -564,6 +564,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument('--prompt', type=str, default="a horse sized cat eating a bagel", help="Phrase to render")
|
||||
parser.add_argument('--out', type=str, default=os.path.join(tempfile.gettempdir(), "rendered.png"), help="Output filename")
|
||||
parser.add_argument('--noshow', action='store_true', help="Don't show the image")
|
||||
parser.add_argument('--fp16', action='store_true', help="Cast the weights to float16")
|
||||
args = parser.parse_args()
|
||||
|
||||
Tensor.no_grad = True
|
||||
|
@ -573,6 +574,10 @@ if __name__ == "__main__":
|
|||
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
|
||||
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
|
||||
|
||||
if args.fp16:
|
||||
for l in get_state_dict(model).values():
|
||||
l.assign(l.cast(dtypes.float16).realize())
|
||||
|
||||
# run through CLIP to get context
|
||||
tokenizer = ClipTokenizer()
|
||||
prompt = Tensor([tokenizer.encode(args.prompt)])
|
||||
|
|
Loading…
Reference in New Issue