diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 3a48a579c2..81c0f9fd2f 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -80,6 +80,7 @@
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import (
     CLIPModelPatcher,
+    ColPaliModelPatcher,
     FalconModelPatcher,
     MgpstrModelPatcher,
     MistralModelPatcher,
@@ -2621,3 +2622,61 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
 
     DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
+
+
+class PaliGemmaOnnxConfig(GemmaOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator)
+    NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args(
+        text_config="text_config", vision_config="vision_config"
+    )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        dynamic_axis = {0: "batch_size", 1: "sequence_length"}
+        if self.task == "feature-extraction":
+            return {
+                "input_ids": dynamic_axis,
+                "attention_mask": dynamic_axis,
+                "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
+            }
+        elif self.task == "image-to-text":
+            return {
+                "input_ids": dynamic_axis,
+                "attention_mask": dynamic_axis,
+            }
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
+        if framework == "pt":
+            if self.task == "feature-extraction":
+                generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config)
+                # Prepend one image placeholder token per image patch, mirroring the prompt
+                # layout that the PaliGemma processor produces.
+                prefix_tensor = generator.constant_tensor(
+                    shape=[
+                        dummy_inputs["input_ids"].shape[0],
+                        self._normalized_config.vision_config.num_image_tokens,
+                    ],
+                    value=self._normalized_config.image_token_index,
+                    framework=framework,
+                )
+                dummy_inputs["input_ids"] = generator.concat_inputs([prefix_tensor, dummy_inputs["input_ids"]], dim=1)
+                # Extend the attention mask so that it also covers the prepended image tokens.
+                dummy_inputs["attention_mask"] = generator.random_mask_tensor(
+                    shape=[
+                        generator.batch_size,
+                        generator.sequence_length + self._normalized_config.vision_config.num_image_tokens,
+                    ],
+                    padding_side=generator.padding_side,
+                    framework=framework,
+                    dtype="int64",
+                )
+        return dummy_inputs
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        if self.task == "feature-extraction":
+            return ColPaliModelPatcher(self, model, model_kwargs=model_kwargs)
+        else:
+            return super().patch_model_for_export(model, model_kwargs=model_kwargs)
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 083bc12799..bece9a0654 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -535,6 +535,25 @@ def patched_forward(*args, **kwargs):
         self.patched_forward = patched_forward
 
 
+class ColPaliModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        # Expose an explicit signature so that every model input is traced during the ONNX export.
+        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
+            outputs = self.orig_forward(
+                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
+            )
+            return outputs
+
+        self.patched_forward = patched_forward
+
+
 class SAMModelPatcher(ModelPatcher):
     def __init__(
         self,
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 7cb5a31d2d..b726764f50 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -976,6 +976,11 @@ class TasksManager:
             "text-generation-with-past",
             onnx="GraniteOnnxConfig",
         ),
+        "paligemma": supported_tasks_mapping(
+            "feature-extraction",
+            "image-to-text",
+            onnx="PaliGemmaOnnxConfig",
+        ),
         "olmo": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 900b5f3b5c..57a79f2322 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -135,6 +135,9 @@
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "owlv2": "hf-internal-testing/tiny-random-Owlv2Model",
     "owlvit": "hf-tiny-model-private/tiny-random-OwlViTModel",
+    "paligemma": {
+        "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration": ["image-to-text", "feature-extraction"]
+    },
     "pegasus": "hf-internal-testing/tiny-random-PegasusModel",
     "perceiver": {
         "hf-internal-testing/tiny-random-language_perceiver": ["fill-mask", "text-classification"],