Skip to content

W&B Artifact Integration #1403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 83 additions & 5 deletions lib/python/src/bailo/helper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@
try:
import mlflow

ml_flow = True
mlflow_installed = True
except ImportError:
ml_flow = False
mlflow_installed = False

try:
import wandb

wandb_installed = True
except ImportError:
wandb_installed = False

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -123,7 +130,7 @@ def from_mlflow(
files: bool = True,
visibility: ModelVisibility | None = None,
) -> Model:
"""Import an MLFlow Model into Bailo.
"""Import a MLFlow Model into Bailo.

:param client: A client object used to interact with Bailo
:param mlflow_uri: MLFlow server URI
Expand All @@ -135,7 +142,7 @@ def from_mlflow(
:param visibility: Visibility of model on Bailo, using ModelVisibility enum (e.g Public or Private), defaults to None
:return: A model object
"""
if not ml_flow:
if not mlflow_installed:
raise ImportError("Optional MLFlow dependencies (needed for this method) are not installed.")

mlflow_client = mlflow.tracking.MlflowClient(tracking_uri=mlflow_uri)
Expand Down Expand Up @@ -199,6 +206,74 @@ def from_mlflow(
release.upload(mlflow_dir)
return model

@classmethod
def from_wandb(
cls,
client: Client,
team_id: str,
name: str,
schema_id: str | None = None,
files: bool = True,
visibility: ModelVisibility | None = None,
) -> Model:
"""Import a W&B artifact into Bailo as a Model.

:param client: A client object used to interact with Bailo
:param team_id: A unique team ID
:param name: Name of model (on W&B), format 'entity/project/name'. Latter will be used on Bailo
:param schema_id: A unique schema ID, only required when files is True, defaults to None
:param files: Import files?, defaults to True
:param visibility: Visibility of model on Bailo, using ModelVisibility enum (e.g Public or Private), defaults to None
:return: A model object

..note:: User must login to W&B first, either using wandb.login() or the equivalent CLI command
"""
# Define Api() object
if not wandb_installed:
raise ImportError("Optional W&B dependencies (needed for this method) are not installed.")
api = wandb.Api()

# Fetch artifact object
artifact = api.artifact(name=name)
bailo_name = name.split("/")[-1]
description = artifact.description
if description is None:
description = ""
description = description + " Imported from W&B."

# Create Model object and unpack
bailo_res = client.post_model(
name=bailo_name, kind=EntryKind.MODEL, description=description, team_id=team_id, visibility=visibility
)
model_id = bailo_res["model"]["id"]
logger.info(f"W&B model successfully imported to Bailo with ID %s.", model_id)

model = cls(
client=client,
model_id=model_id,
name=bailo_name,
description=description,
visibility=visibility,
)
model._unpack(bailo_res["model"])

# If files, download and upload artifacts
if files:
if schema_id is None:
raise BailoException(
"Unable to upload files to Bailo. schema_id argument is required in order to create a release."
)
model.card_from_schema(schema_id=schema_id)
release = model.create_release(version=Version("1.0.0"), notes=" ")

if len(artifact.file_count):
temp_dir = os.path.join(tempfile.gettempdir(), "wandb_model")
wandb_dir = os.path.join(temp_dir, f"wandb_{bailo_name}")
artifact.download(root=wandb_dir)
release.upload(wandb_dir)

return model

def update_model_card(self, model_card: dict[str, Any] | None = None) -> None:
"""Upload and retrieve any changes to the model card on Bailo.

Expand Down Expand Up @@ -421,7 +496,7 @@ def from_mlflow(self, tracking_uri: str, experiment_id: str):
:param experiment_id: MLFlow Tracking experiment ID
:raises ImportError: Import error if MLFlow not installed
"""
if not ml_flow:
if not mlflow_installed:
raise ImportError("Optional MLFlow dependencies (needed for this method) are not installed.")

client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri)
Expand Down Expand Up @@ -471,6 +546,9 @@ def from_mlflow(self, tracking_uri: str, experiment_id: str):

logger.info(f"Successfully imported MLFlow experiment %s.", experiment_id)

def from_wandb(experiment_id: str):
pass

def publish(
self,
mc_loc: str,
Expand Down
Loading