mirror of https://github.com/commaai/tinygrad.git
shard llama3 on 0 sometimes (#5157)
This commit is contained in:
parent
294bd1a9ff
commit
c91b3c4079
|
@ -166,6 +166,8 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
|
|||
for k,v in nn.state.get_state_dict(model).items():
|
||||
if 'scale' in k: v.shard_(device, axis=None) # from quantized
|
||||
elif '.attention.' in k: v.shard_(device, axis=-1)
|
||||
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
|
||||
elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
|
||||
elif '.feed_forward.' in k: v.shard_(device, axis=-1)
|
||||
elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
|
||||
elif 'output.weight' in k: v.shard_(device, axis=0)
|
||||
|
|
Loading…
Reference in New Issue