diff --git a/inference.py b/inference.py
index 4492aed..bde373a 100644
--- a/inference.py
+++ b/inference.py
@@ -42,8 +42,6 @@ def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_l
         if device == "cuda":
             torch.set_default_tensor_type(torch.cuda.HalfTensor)
-        else:
-            torch.set_default_tensor_type(torch.BFloat16Tensor)
 
         model = Transformer(model_args).to(device)
@@ -56,6 +54,7 @@ def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_l
         return LLaMA(model, tokenizer, model_args)
 
     def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
+        device = self.args.device
        if max_gen_len is None:
             max_gen_len = self.args.max_seq_len - 1
         # Convert each prompt into tokens
@@ -156,7 +155,7 @@ def _sample_top_p(self, probs, p):
         checkpoints_dir='llama-2-7b/',
         tokenizer_path='tokenizer.model',
         load_model=True,
-        max_seq_len=1024,
+        max_seq_len=128,
         max_batch_size=len(prompts),
         device=device
     )
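A minimal sketch of the effect of the first hunk (not part of the patch; the helper name is made up for illustration): with the `else` branch removed, only the CUDA path overrides the default tensor type to half precision, and CPU runs keep PyTorch's float32 default instead of bfloat16.

import torch

def configure_default_dtype(device: str) -> None:
    # Mirrors the patched build(): only CUDA switches the default tensor type.
    # On CPU nothing is set, so new tensors default to float32.
    if device == "cuda":
        torch.set_default_tensor_type(torch.cuda.HalfTensor)

device = "cuda" if torch.cuda.is_available() else "cpu"
configure_default_dtype(device)
print(torch.empty(1).dtype)  # torch.float16 on CUDA, torch.float32 on CPU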