From e135523a160b3ffbe1a9962bb8fa3ca3a62516a1 Mon Sep 17 00:00:00 2001
From: "Wang, Chang"
Date: Fri, 23 Aug 2024 15:57:04 +0800
Subject: [PATCH] Update utils.py

Signed-off-by: Wang, Chang
---
 .../transformers/llm/quantization/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/intel_extension_for_transformers/transformers/llm/quantization/utils.py b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
index a8a5b88baf9..d23971cfc11 100644
--- a/intel_extension_for_transformers/transformers/llm/quantization/utils.py
+++ b/intel_extension_for_transformers/transformers/llm/quantization/utils.py
@@ -438,7 +438,6 @@ def default_run_fn(
     model, tokenizer, dataset, max_length=512, n_samples=100, batch_size=8, algo="rtn"
 ):
     from torch.utils.data import DataLoader
-
     if isinstance(dataset, (str, bytes, os.PathLike)):
         calib_dataset = load_dataset(dataset, split="train")
         calib_dataset = calib_dataset.shuffle(seed=42)
@@ -513,7 +512,7 @@ def collate_batch(batch):
 
         try:
             model(
-                input_ids=input_ids,
+                input_ids=input_ids.to("xpu:0"),
             )
         except ValueError:
             pass
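
For context, the second hunk (around line 513) moves each calibration batch onto
the Intel GPU ("xpu:0") before the forward pass, so the input tensors match the
device of a model that has already been placed on XPU; feeding CPU tensors to an
XPU model would otherwise fail with a device-mismatch error. The sketch below is
illustrative only, not the patched function itself: run_calibration, the
calib_dataloader argument, and the early exit on n_samples are stand-ins for
what default_run_fn builds from its own arguments, and the availability of the
"xpu" device assumes intel_extension_for_pytorch is installed. It also derives
the target device from the model instead of hard-coding "xpu:0", a variation on
what the patch does.

    import torch
    # Importing intel_extension_for_pytorch registers the "xpu" device with
    # PyTorch (assumed installed; without it only CPU/CUDA devices exist).
    import intel_extension_for_pytorch  # noqa: F401

    @torch.no_grad()
    def run_calibration(model, calib_dataloader, n_samples=100):
        """Run forward passes over calibration batches to collect
        quantization statistics (hypothetical helper, not from the patch)."""
        # Derive the device from the model so inputs always land wherever the
        # model lives; the patch achieves the same by hard-coding "xpu:0".
        device = next(model.parameters()).device
        seen = 0
        for input_ids in calib_dataloader:
            try:
                # Same pattern as the patched hunk: move the token ids to the
                # model's device, then run the forward pass.
                model(input_ids=input_ids.to(device))
            except ValueError:
                # default_run_fn likewise swallows ValueError raised by models
                # that reject a particular calibration batch.
                pass
            seen += input_ids.shape[0]
            if seen >= n_samples:
                break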