diff --git a/examples/conversation.py b/examples/conversation.py index e3ff7645..fccfbebd 100644 --- a/examples/conversation.py +++ b/examples/conversation.py @@ -232,7 +232,7 @@ if __name__ == "__main__": parser.add_argument("--llama_pre_prompt_path", type=Path, default=Path(__file__).parent / "conversation_data" / "pre_prompt_stacy.yaml", help="Path to yaml file which contains all pre-prompt data needed. ") parser.add_argument("--llama_count", type=int, default=1000, help="Max number of tokens to generate") parser.add_argument("--llama_temperature", type=float, default=0.7, help="Temperature in the softmax") - parser.add_argument("--llama_quantize", action="store_true", help="Quantize the weights to int8 in memory") + parser.add_argument("--llama_quantize", type=str, default=None, help="Quantize the weights to int8 or nf4 in memory") parser.add_argument("--llama_model", type=Path, default=None, help="Folder with the original weights to load, or single .index.json, .safetensors or .bin file") parser.add_argument("--llama_gen", type=str, default="tiny", required=False, help="Generation of the model to use") parser.add_argument("--llama_size", type=str, default="1B-Chat", required=False, help="Size of model to use")