diff --git a/src/pipelines.js b/src/pipelines.js index afb627a4a..8b1b663b6 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -3399,7 +3399,7 @@ export async function pipeline( revision = 'main', device = null, dtype = null, - subfolder = 'onnx', + subfolder = null, use_external_data_format = null, model_file_name = null, session_options = {}, @@ -3466,6 +3466,7 @@ export async function pipeline( * @private */ async function loadItems(mapping, model, pretrainedOptions) { + const { subfolder, ...rest } = pretrainedOptions; const result = Object.create(null); @@ -3474,6 +3475,8 @@ async function loadItems(mapping, model, pretrainedOptions) { for (const [name, cls] of mapping.entries()) { if (!cls) continue; + const options = name === 'model' ? { ...rest, subfolder: subfolder ?? 'onnx' } : pretrainedOptions; + /**@type {Promise} */ let promise; if (Array.isArray(cls)) { @@ -3487,7 +3490,7 @@ async function loadItems(mapping, model, pretrainedOptions) { return; } try { - resolve(await c.from_pretrained(model, pretrainedOptions)); + resolve(await c.from_pretrained(model, options)); return; } catch (err) { if (err.message?.includes('Unsupported model type')) { @@ -3506,7 +3509,7 @@ async function loadItems(mapping, model, pretrainedOptions) { reject(e); }) } else { - promise = cls.from_pretrained(model, pretrainedOptions); + promise = cls.from_pretrained(model, options); } result[name] = promise; diff --git a/src/tokenizers.js b/src/tokenizers.js index 83e33cc52..7cb191b79 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -65,8 +65,8 @@ import { async function loadTokenizer(pretrained_model_name_or_path, options) { const info = await Promise.all([ - getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options), - getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options), + getModelJSON(pretrained_model_name_or_path, `${options.subfolder || ''}/tokenizer.json`, true, options), + getModelJSON(pretrained_model_name_or_path, `${options.subfolder || ''}/tokenizer_config.json`, true, options), ]) // Override legacy option if `options.legacy` is not null