
feat: NVIDIA allow non-llama model registration #1859

Open
raspawar wants to merge 4 commits into main from register_custom_model
Conversation

@raspawar (Contributor) commented Apr 2, 2025

What does this PR do?

Adds custom model registration functionality to NVIDIAInferenceAdapter, which lets inference run against custom (non-Llama) models registered at runtime.

Example Usage:

from llama_stack.apis.models import Model, ModelType
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
_ = client.initialize()

# placeholder: any model ID served by your NVIDIA endpoint
model_name = "nvidia/my-custom-model"

client.models.register(
    model_id=model_name,
    model_type=ModelType.llm,
    provider_id="nvidia",
)

response = client.inference.chat_completion(
    model_id=model_name,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a limerick about the wonders of GPU computing."},
    ],
)

Test Plan

pytest tests/unit/providers/nvidia/test_supervised_fine_tuning.py 
========================================================== test session starts ===========================================================
platform linux -- Python 3.10.0, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/ubuntu/llama-stack
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 6 items                                                                                                                        

tests/unit/providers/nvidia/test_supervised_fine_tuning.py ......                                                                  [100%]

============================================================ warnings summary ============================================================
../miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076
  /home/ubuntu/miniconda/envs/nvidia-1/lib/python3.10/site-packages/pydantic/fields.py:1076: PydanticDeprecatedSince20: Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'contentEncoding'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
    warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================================== 6 passed, 1 warning in 1.51s ======================================================

Updated README.md.

cc: @dglogo, @sumitb, @mattf

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 2, 2025
@raspawar raspawar changed the title add register_model method feature: NVIDIA allow non-llama model registration Apr 2, 2025
@raspawar raspawar changed the title feature: NVIDIA allow non-llama model registration feat: NVIDIA allow non-llama model registration Apr 2, 2025
@raspawar raspawar force-pushed the register_custom_model branch 2 times, most recently from 4032efa to 27a1657, on April 2, 2025 10:14
@mattf (Contributor) left a comment
thanks for adding this. few comments inline for you.

if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
    base_url = special_model_urls[provider_model_id]

# add /v1 in case of hosted models
@mattf (Contributor) commented:
this is a behavior change.

current behavior: always add /v1
new behavior: add /v1 for hosted, but not for non-hosted

the behavior should be consistent between hosted and non-hosted. a user should not need to know that because they're talking to https://integrate.api.nv.c they can omit the /v1, or that because they're talking to http://localhost they must supply it.

is there an issue w/ /v1 and customizer?
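
For what it's worth, a minimal sketch of one way to keep the two cases consistent: normalize whatever base URL is configured so /v1 is appended exactly once. _normalize_base_url is a hypothetical helper name, not code from this PR.

def _normalize_base_url(base_url: str) -> str:
    # Hypothetical helper: append /v1 exactly once, so hosted and
    # self-hosted (e.g. http://localhost:8000) endpoints behave the same
    # whether or not the user included /v1 in NVIDIA_BASE_URL.
    base_url = base_url.rstrip("/")
    if not base_url.endswith("/v1"):
        base_url = f"{base_url}/v1"
    return base_url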

@raspawar (Contributor, Author) replied:
Adding /v1 may produce errors, especially when the user specifies NVIDIA_BASE_URL themselves, since some model endpoints on the API catalog are served at /chat/completions without a /v1 prefix. Can we remove the /v1 entirely and add it only to the default base_url? What would we lose in that case?

NOTE: Only supports model endpoints compatible with the AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
    # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
@mattf (Contributor) commented:
this should be handled within the provider model id function
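
A rough sketch of that refactor, assuming a helper along these lines (the helper name and the alias map attribute are assumptions, not code from this PR):

def _get_provider_model_id(self, model) -> str:
    # Hypothetical helper: resolve the provider model id in one place,
    # so register_model() needs no embedding special case.
    if model.model_type == ModelType.embedding:
        # embedding models are registered by their provider model id directly
        return model.provider_resource_id
    # map known Llama aliases; pass custom model ids through unchanged
    return self.alias_to_provider_id_map.get(
        model.provider_resource_id, model.provider_resource_id
    )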

if provider_resource_id:
    model.provider_resource_id = provider_resource_id
else:
    llama_model = model.metadata.get("llama_model")
@mattf (Contributor) commented:
i believe model is a https://github.com/meta-llama/llama-stack/blob/main/llama_stack/apis/models/models.py#L31, which does not have a metadata field.

https://github.com/meta-llama/llama-stack/blob/main/llama_stack/models/llama/datatypes.py#L346 does have a metadata field and, confusingly, the same class name.

suggestion: trust the input config. it should fail at inference time if it's incorrect.
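
A minimal sketch of the "trust the input config" approach (a hypothetical simplification, not this PR's code): skip the llama_model metadata lookup entirely and let a bad id surface as an inference-time error.

async def register_model(self, model: Model) -> Model:
    # Hypothetical simplification: accept the model as configured.
    # If provider_resource_id is wrong, the first chat_completion
    # against the endpoint fails with the provider's own error.
    if model.provider_resource_id is None:
        model.provider_resource_id = model.identifier
    return model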

@raspawar (Contributor, Author) replied:
This is the base register_model() logic: https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/utils/inference/model_registry.py#L76

I modified only the parts needed to allow non-llama models.
