shard llama3 on 0 sometimes (#5157)

This commit is contained in:
wozeparrot 2024-06-26 18:50:57 +00:00 committed by GitHub
parent 294bd1a9ff
commit c91b3c4079
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 0 deletions

View File

@ -166,6 +166,8 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
for k,v in nn.state.get_state_dict(model).items():
if 'scale' in k: v.shard_(device, axis=None) # from quantized
elif '.attention.' in k: v.shard_(device, axis=-1)
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
elif '.feed_forward.' in k: v.shard_(device, axis=-1)
elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
elif 'output.weight' in k: v.shard_(device, axis=0)