diff --git a/src/speculators/__init__.py b/src/speculators/__init__.py index 662bd57..87838c5 100644 --- a/src/speculators/__init__.py +++ b/src/speculators/__init__.py @@ -6,6 +6,8 @@ TokenProposal, TokenProposalConfig, VerifierConfig, + DraftModelType, + TokenProposalType, ) from .logging import configure_logger, logger from .settings import LoggingSettings, Settings, print_config, reload_settings, settings @@ -25,4 +27,6 @@ "print_config", "reload_settings", "settings", + "DraftModelType", + "TokenProposalType", ] diff --git a/src/speculators/base/__init__.py b/src/speculators/base/__init__.py index 0817627..d4d321f 100644 --- a/src/speculators/base/__init__.py +++ b/src/speculators/base/__init__.py @@ -4,6 +4,9 @@ TokenProposalConfig, VerifierConfig, speculators_config_version, + DraftModelType, + TokenProposalType, + AlgorithmType, ) from .objects import Drafter, SpeculatorModel, TokenProposal @@ -16,4 +19,7 @@ "TokenProposalConfig", "VerifierConfig", "speculators_config_version", + "DraftModelType", + "TokenProposalType", + "AlgorithmType", ] diff --git a/src/speculators/base/config.py b/src/speculators/base/config.py index 29330fd..646bd69 100644 --- a/src/speculators/base/config.py +++ b/src/speculators/base/config.py @@ -1,3 +1,4 @@ +from enum import Enum import json from pathlib import Path from typing import Any, Optional, Union @@ -12,10 +13,22 @@ "TokenProposalConfig", "VerifierConfig", "speculators_config_version", + "DraftModelType", + "TokenProposalType", + "AlgorithmType", ] speculators_config_version = "0.1.0" +class DraftModelType(str, Enum): + LLAMA_EAGLE = "LlamaEagle" + +class TokenProposalType(str, Enum): + GREEDY = "GREEDY" + +class AlgorithmType(str, Enum): + EAGLE = "EAGLE" + class DraftModelConfig(BaseModel): """ @@ -26,7 +39,7 @@ class DraftModelConfig(BaseModel): self-drafting (extensions of the verifier model). """ - type_: str = Field( + type_: DraftModelType = Field( description=( "The type of the speculator architecture. " "This must be a valid architecture name from the speculators library." @@ -55,7 +68,7 @@ class TokenProposalConfig(BaseModel): greedy, nucleus, token tree sampling, etc. """ - type_: str = Field( + type_: TokenProposalType = Field( description=( "The type of the token proposal algorithm. " "This must be a valid proposer from the speculators library." @@ -84,7 +97,7 @@ class VerifierConfig(BaseModel): "transformers library. Ex: `LlamaPreTrainedModel`" ) ) - model: str = Field( + model: Optional[str] = Field( description=( "The name of the verifier model. This must be a valid HF id/name or " "a path to a local model directory. Ex: `meta-llama/Llama-3.3-70B-Instruct`" diff --git a/src/speculators/base/objects.py b/src/speculators/base/objects.py index 234c8e1..dfb522e 100644 --- a/src/speculators/base/objects.py +++ b/src/speculators/base/objects.py @@ -243,6 +243,7 @@ def __init__( self.drafter = drafter self.verifier = verifier # need to make optional and attach later self.proposals = proposals + @abstractmethod @property @@ -254,7 +255,7 @@ def config(self) -> SpeculatorConfig: :return: The config object for the speculator model. """ - ... + return self._config def attach_verifier(self, verifier: Module): """ diff --git a/src/speculators/utils/convert.py b/src/speculators/utils/convert.py index c2c5598..3b8cca2 100644 --- a/src/speculators/utils/convert.py +++ b/src/speculators/utils/convert.py @@ -3,10 +3,18 @@ from torch.nn import Module +from transformers import AutoConfig from speculators.base import ( - SpeculatorConfig, SpeculatorModel, -) # will need to fix circular import + SpeculatorConfig, + DraftModelConfig, + DraftModelType, + TokenProposalConfig, + TokenProposalType, + AlgorithmType, + VerifierConfig, +) +import logging __all__ = [ "SpecDecodeLibraryFormats", @@ -20,7 +28,7 @@ SpecDecodeLibraryFormats = Literal["speculators", "eagle", "eagle2", "eagle3", "hass"] - +logger = logging.getLogger(__name__) def detect_model_format( source: Union[str, Path, Module], # noqa: ARG001 @@ -96,8 +104,41 @@ def from_eagle_format( function signature rather than keeping them in kwargs as needed. :return: The converted speculator model. """ - raise NotImplementedError("Eagle format conversion is not implemented yet.") - + # Load config if it's not already a dict + if not isinstance(config, dict): + config = AutoConfig.from_pretrained(source if config is None else config) + + # Build draft model config + draft_model_config = DraftModelConfig( + type_=DraftModelType.LLAMA_EAGLE, + inputs="model.layers[-1].*", + config=config, + ) + + # Set proposal config + proposal_config = TokenProposalConfig(type_=TokenProposalType.GREEDY) + + # Construct verifier config + verifier_config = VerifierConfig( + architecture=config.get("architectures"), + model=getattr(verifier, "source", None) if verifier else None, + ) + + # Build full speculator config + speculator_config = SpeculatorConfig( + speculators_algorithm=AlgorithmType.EAGLE, + draft_model=draft_model_config, + proposal_methods=proposal_config, + default_proposal_method=proposal_config.type_, + verifier=verifier_config, + ) + + speculator_model = SpeculatorModel.from_config(speculator_config) + logger.info( + "Returning speculator model from config. " + "Note: weights have not been loaded from disk yet." + ) + return speculator_model def from_eagle2_format( source: Union[str, Path, Module], # noqa: ARG001