This API can also be used as an exception for VLM models: although transformers supports loading InternChatVL models via the AutoModel API, we support them via the AutoModelForCausalLM API.
Args:
    pretrained_model_name_or_path (str): The name or path of the pre-trained model.
    pooling (Optional[Union[str, Callable]], optional): The pooling method to use. Defaults to None.
        Options:
            - "mean": Mean pooling
            - "max": Max pooling
            - "cls": CLS token pooling
            - "avg": Average pooling
            - Callable: A custom pooling function
            - None: No pooling applied

.. code-block:: python

    from QEfficient import QEFFAutoModel
    from transformers import AutoTokenizer

    # Initialize the model using from_pretrained similar to transformers.AutoModel.
    model = QEFFAutoModel.from_pretrained("model_name", pooling="mean")

    # Now you can directly compile the model for Cloud AI 100
    model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
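Since `pooling` also accepts a Callable, a custom pooling function can be passed instead of a named strategy. The sketch below implements mask-aware mean pooling in plain Python; the exact signature QEfficient expects for a pooling callable is an assumption here, as are the helper name and the toy inputs.

```python
def mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: per-token embedding vectors (seq_len x hidden_dim)
    # attention_mask: 0/1 flags marking real (non-padding) tokens
    # Returns the mean of the embeddings of the unmasked tokens only.
    hidden_dim = len(last_hidden_state[0])
    totals = [0.0] * hidden_dim
    count = 0
    for token_vec, keep in zip(last_hidden_state, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(token_vec):
                totals[i] += value
    return [t / max(count, 1) for t in totals]

# Two real tokens plus one padding token; padding is excluded from the mean.
states = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
pooled = mean_pool(states, mask)  # [2.0, 3.0]
```

A callable like this would then be passed as `pooling=mean_pool` in place of a string option.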
In ``def compile(``:

    :str: Path of the compiled ``qpc`` package.
    """

    if isinstance(seq_len, list) and len(seq_len) >= 15:
        warnings.warn("Recommended: `seq_len` should contain fewer than 15 items.")
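The guard above can be exercised in isolation. The standalone helper below is hypothetical (the real check lives inside `compile`), but it reproduces the same condition: warn when `seq_len` is a list with 15 or more entries, and stay silent otherwise.

```python
import warnings

def check_seq_len(seq_len):
    # Hypothetical standalone version of the guard in `compile`:
    # a list of 15+ sequence lengths triggers a recommendation warning.
    if isinstance(seq_len, list) and len(seq_len) >= 15:
        warnings.warn("Recommended: `seq_len` should contain fewer than 15 items.")
        return False
    return True

check_seq_len([128, 256, 512])   # small list: no warning, returns True
check_seq_len(list(range(15)))   # 15 entries: warns, returns False
check_seq_len(4096)              # a plain int is never warned about
```

Note that a bare integer `seq_len` skips the check entirely; only list inputs are measured against the threshold.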