From 89a0ddda05ba1e0e217b670de78537b5bdc5c2ea Mon Sep 17 00:00:00 2001
From: setzer22
Date: Sat, 18 Mar 2023 13:07:06 +0100
Subject: [PATCH 1/2] Add --seed arg to llama-cli

---
 llama-cli/src/cli_args.rs | 6 ++++++
 llama-cli/src/main.rs     | 8 ++++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs
index 92278762..6342f39d 100644
--- a/llama-cli/src/cli_args.rs
+++ b/llama-cli/src/cli_args.rs
@@ -67,6 +67,12 @@ pub struct Args {
     /// --cache-prompt
     #[arg(long, default_value = None)]
     pub restore_prompt: Option<String>,
+
+    /// Specifies the seed to use during sampling. Note that, depending on
+    /// hardware, the same seed may lead to different results on two separate
+    /// machines.
+    #[arg(long, default_value = None)]
+    pub seed: Option<u64>,
 }
 
 /// CLI args are stored in a lazy static variable so they're accessible from
diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs
index cba57078..3d85d0b7 100644
--- a/llama-cli/src/main.rs
+++ b/llama-cli/src/main.rs
@@ -2,7 +2,7 @@ use std::{convert::Infallible, io::Write};
 
 use cli_args::CLI_ARGS;
 use llama_rs::{InferenceParameters, InferenceSnapshot};
-use rand::thread_rng;
+use rand::{thread_rng, SeedableRng};
 
 mod cli_args;
 
@@ -94,7 +94,11 @@ fn main() {
 
     log::info!("Model fully loaded!");
 
-    let mut rng = thread_rng();
+    let mut rng = if let Some(seed) = CLI_ARGS.seed {
+        rand::rngs::StdRng::seed_from_u64(seed)
+    } else {
+        rand::rngs::StdRng::from_entropy()
+    };
 
     let mut session = if let Some(restore_path) = &args.restore_prompt {
         let snapshot = InferenceSnapshot::load_from_disk(restore_path);

From df16b3e9b7f3f6bdc8c2404e2d1cd2d85700b44f Mon Sep 17 00:00:00 2001
From: setzer22
Date: Sat, 18 Mar 2023 14:17:54 +0100
Subject: [PATCH 2/2] Remove unused code

---
 llama-cli/src/main.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama-cli/src/main.rs b/llama-cli/src/main.rs
index 3d85d0b7..cb3e6cc5 100644
--- a/llama-cli/src/main.rs
+++ b/llama-cli/src/main.rs
@@ -2,7 +2,7 @@ use std::{convert::Infallible, io::Write};
 
 use cli_args::CLI_ARGS;
 use llama_rs::{InferenceParameters, InferenceSnapshot};
-use rand::{thread_rng, SeedableRng};
+use rand::SeedableRng;
 
 mod cli_args;
 
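
Note: below is a minimal standalone sketch (not part of the patch) of the RNG selection this series introduces, assuming rand 0.8; the `make_rng` helper and the hard-coded seed are hypothetical stand-ins for the inline `if let` on `CLI_ARGS.seed` above.

    use rand::{Rng, SeedableRng};

    fn make_rng(seed: Option<u64>) -> rand::rngs::StdRng {
        match seed {
            // A fixed seed yields a reproducible sample stream on a given
            // machine (though, as the doc comment notes, not necessarily
            // across machines, since sampling involves float arithmetic).
            Some(s) => rand::rngs::StdRng::seed_from_u64(s),
            // No seed: fall back to OS entropy for non-deterministic runs.
            None => rand::rngs::StdRng::from_entropy(),
        }
    }

    fn main() {
        let mut a = make_rng(Some(42));
        let mut b = make_rng(Some(42));
        // Same seed, same output stream:
        assert_eq!(a.gen::<u64>(), b.gen::<u64>());
    }

Using `StdRng` for both branches gives the two arms of the `if let` a single concrete type, which is also why `thread_rng` becomes unused after the first commit and is dropped by the second.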