From c91b3c4079d096cc304048d4d2715d6918d8356a Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Wed, 26 Jun 2024 18:50:57 +0000
Subject: [PATCH] shard llama3 on 0 sometimes (#5157)

---
 examples/llama3.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/llama3.py b/examples/llama3.py
index 27a65d21..e118dfee 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -166,6 +166,8 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
     for k,v in nn.state.get_state_dict(model).items():
       if 'scale' in k: v.shard_(device, axis=None) # from quantized
       elif '.attention.' in k: v.shard_(device, axis=-1)
+      elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+      elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
       elif '.feed_forward.' in k: v.shard_(device, axis=-1)
       elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
       elif 'output.weight' in k: v.shard_(device, axis=0)
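
Note on the sharding axes (not part of the patch): in tinygrad's Llama feed-forward block, w1 and w3 are the gate/up projections, nn.Linear(dim, hidden_dim), so their weights are laid out (hidden_dim, dim) and axis=0 splits the hidden dimension across devices; w2 is the down projection, nn.Linear(hidden_dim, dim), still caught by the generic '.feed_forward.' branch, and axis=-1 splits the same hidden dimension there. Partitioning the hidden dimension on both sides keeps w1(x).silu() * w3(x) local to each device. Below is a rough sketch of that pattern with Tensor.shard_, assuming two devices and illustrative Llama-3-8B shapes (dim=4096, hidden_dim=14336); the shapes and device count are assumptions for the example, not values taken from the patch.

from tinygrad import Tensor, Device

# same device-tuple idiom examples/llama3.py uses for --shard
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

w1 = Tensor.empty(14336, 4096)   # gate/up projection weight: (hidden_dim, dim)
w1.shard_(devices, axis=0)       # each device holds a slice of the hidden dim (rows)

w2 = Tensor.empty(4096, 14336)   # down projection weight: (dim, hidden_dim)
w2.shard_(devices, axis=-1)      # split the matching hidden dim (columns)

print(w1.device, w2.device)      # each reports the tuple of devices the tensor is sharded over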