Skip to content

[WIP] fill: from_eagle_format w/o weight loading #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/speculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
TokenProposal,
TokenProposalConfig,
VerifierConfig,
DraftModelType,
TokenProposalType,
)
from .logging import configure_logger, logger
from .settings import LoggingSettings, Settings, print_config, reload_settings, settings
Expand All @@ -25,4 +27,6 @@
"print_config",
"reload_settings",
"settings",
"DraftModelType",
"TokenProposalType",
]
6 changes: 6 additions & 0 deletions src/speculators/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
TokenProposalConfig,
VerifierConfig,
speculators_config_version,
DraftModelType,
TokenProposalType,
AlgorithmType,
)
from .objects import Drafter, SpeculatorModel, TokenProposal

Expand All @@ -16,4 +19,7 @@
"TokenProposalConfig",
"VerifierConfig",
"speculators_config_version",
"DraftModelType",
"TokenProposalType",
"AlgorithmType",
]
19 changes: 16 additions & 3 deletions src/speculators/base/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
import json
from pathlib import Path
from typing import Any, Optional, Union
Expand All @@ -12,10 +13,22 @@
"TokenProposalConfig",
"VerifierConfig",
"speculators_config_version",
"DraftModelType",
"TokenProposalType",
"AlgorithmType",
]

speculators_config_version = "0.1.0"

class DraftModelType(str, Enum):
    """
    Closed set of values accepted for the ``type_`` field of a draft model
    config.

    The ``str`` mixin makes each member compare equal to its plain string
    value, so ``DraftModelType.LLAMA_EAGLE == "LlamaEagle"`` holds.
    """

    LLAMA_EAGLE = "LlamaEagle"

class TokenProposalType(str, Enum):
    """
    Closed set of values accepted for the ``type_`` field of a token
    proposal config.

    The ``str`` mixin makes each member compare equal to its plain string
    value, so ``TokenProposalType.GREEDY == "GREEDY"`` holds.
    """

    GREEDY = "GREEDY"

class AlgorithmType(str, Enum):
    """
    Closed set of speculative decoding algorithm identifiers.

    The ``str`` mixin makes each member compare equal to its plain string
    value, so ``AlgorithmType.EAGLE == "EAGLE"`` holds.
    """

    EAGLE = "EAGLE"


class DraftModelConfig(BaseModel):
"""
Expand All @@ -26,7 +39,7 @@ class DraftModelConfig(BaseModel):
self-drafting (extensions of the verifier model).
"""

type_: str = Field(
type_: DraftModelType = Field(
description=(
"The type of the speculator architecture. "
"This must be a valid architecture name from the speculators library."
Expand Down Expand Up @@ -55,7 +68,7 @@ class TokenProposalConfig(BaseModel):
greedy, nucleus, token tree sampling, etc.
"""

type_: str = Field(
type_: TokenProposalType = Field(
description=(
"The type of the token proposal algorithm. "
"This must be a valid proposer from the speculators library."
Expand Down Expand Up @@ -84,7 +97,7 @@ class VerifierConfig(BaseModel):
"transformers library. Ex: `LlamaPreTrainedModel`"
)
)
model: str = Field(
model: Optional[str] = Field(
description=(
"The name of the verifier model. This must be a valid HF id/name or "
"a path to a local model directory. Ex: `meta-llama/Llama-3.3-70B-Instruct`"
Expand Down
3 changes: 2 additions & 1 deletion src/speculators/base/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(
self.drafter = drafter
self.verifier = verifier # need to make optional and attach later
self.proposals = proposals


@abstractmethod
@property
Expand All @@ -254,7 +255,7 @@ def config(self) -> SpeculatorConfig:

:return: The config object for the speculator model.
"""
...
return self._config

def attach_verifier(self, verifier: Module):
"""
Expand Down
51 changes: 46 additions & 5 deletions src/speculators/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@

from torch.nn import Module

from transformers import AutoConfig
from speculators.base import (
SpeculatorConfig,
SpeculatorModel,
) # will need to fix circular import
SpeculatorConfig,
DraftModelConfig,
DraftModelType,
TokenProposalConfig,
TokenProposalType,
AlgorithmType,
VerifierConfig,
)
import logging

__all__ = [
"SpecDecodeLibraryFormats",
Expand All @@ -20,7 +28,7 @@


SpecDecodeLibraryFormats = Literal["speculators", "eagle", "eagle2", "eagle3", "hass"]

logger = logging.getLogger(__name__)

def detect_model_format(
source: Union[str, Path, Module], # noqa: ARG001
Expand Down Expand Up @@ -96,8 +104,41 @@ def from_eagle_format(
function signature rather than keeping them in kwargs as needed.
:return: The converted speculator model.
"""
raise NotImplementedError("Eagle format conversion is not implemented yet.")

# Load config if it's not already a dict
if not isinstance(config, dict):
config = AutoConfig.from_pretrained(source if config is None else config)

# Build draft model config
draft_model_config = DraftModelConfig(
type_=DraftModelType.LLAMA_EAGLE,
inputs="model.layers[-1].*",
config=config,
)

# Set proposal config
proposal_config = TokenProposalConfig(type_=TokenProposalType.GREEDY)

# Construct verifier config
verifier_config = VerifierConfig(
architecture=config.get("architectures"),
model=getattr(verifier, "source", None) if verifier else None,
)

# Build full speculator config
speculator_config = SpeculatorConfig(
speculators_algorithm=AlgorithmType.EAGLE,
draft_model=draft_model_config,
proposal_methods=proposal_config,
default_proposal_method=proposal_config.type_,
verifier=verifier_config,
)

speculator_model = SpeculatorModel.from_config(speculator_config)
logger.info(
"Returning speculator model from config. "
"Note: weights have not been loaded from disk yet."
)
return speculator_model

def from_eagle2_format(
source: Union[str, Path, Module], # noqa: ARG001
Expand Down
Loading