diff --git a/lib/python/src/bailo/helper/model.py b/lib/python/src/bailo/helper/model.py index e4367d92d..28be4b0b9 100644 --- a/lib/python/src/bailo/helper/model.py +++ b/lib/python/src/bailo/helper/model.py @@ -18,9 +18,16 @@ try: import mlflow - ml_flow = True + mlflow_installed = True except ImportError: - ml_flow = False + mlflow_installed = False + +try: + import wandb + + wandb_installed = True +except ImportError: + wandb_installed = False logger = logging.getLogger(__name__) @@ -123,7 +130,7 @@ def from_mlflow( files: bool = True, visibility: ModelVisibility | None = None, ) -> Model: - """Import an MLFlow Model into Bailo. + """Import a MLFlow Model into Bailo. :param client: A client object used to interact with Bailo :param mlflow_uri: MLFlow server URI @@ -135,7 +142,7 @@ def from_mlflow( :param visibility: Visibility of model on Bailo, using ModelVisibility enum (e.g Public or Private), defaults to None :return: A model object """ - if not ml_flow: + if not mlflow_installed: raise ImportError("Optional MLFlow dependencies (needed for this method) are not installed.") mlflow_client = mlflow.tracking.MlflowClient(tracking_uri=mlflow_uri) @@ -199,6 +206,74 @@ def from_mlflow( release.upload(mlflow_dir) return model + @classmethod + def from_wandb( + cls, + client: Client, + team_id: str, + name: str, + schema_id: str | None = None, + files: bool = True, + visibility: ModelVisibility | None = None, + ) -> Model: + """Import a W&B artifact into Bailo as a Model. + + :param client: A client object used to interact with Bailo + :param team_id: A unique team ID + :param name: Name of model (on W&B), format 'entity/project/name'. Latter will be used on Bailo + :param schema_id: A unique schema ID, only required when files is True, defaults to None + :param files: Import files?, defaults to True + :param visibility: Visibility of model on Bailo, using ModelVisibility enum (e.g Public or Private), defaults to None + :return: A model object + + ..note:: User must login to W&B first, either using wandb.login() or the equivalent CLI command + """ + # Define Api() object + if not wandb_installed: + raise ImportError("Optional W&B dependencies (needed for this method) are not installed.") + api = wandb.Api() + + # Fetch artifact object + artifact = api.artifact(name=name) + bailo_name = name.split("/")[-1] + description = artifact.description + if description is None: + description = "" + description = description + " Imported from W&B." + + # Create Model object and unpack + bailo_res = client.post_model( + name=bailo_name, kind=EntryKind.MODEL, description=description, team_id=team_id, visibility=visibility + ) + model_id = bailo_res["model"]["id"] + logger.info(f"W&B model successfully imported to Bailo with ID %s.", model_id) + + model = cls( + client=client, + model_id=model_id, + name=bailo_name, + description=description, + visibility=visibility, + ) + model._unpack(bailo_res["model"]) + + # If files, download and upload artifacts + if files: + if schema_id is None: + raise BailoException( + "Unable to upload files to Bailo. schema_id argument is required in order to create a release." + ) + model.card_from_schema(schema_id=schema_id) + release = model.create_release(version=Version("1.0.0"), notes=" ") + + if len(artifact.file_count): + temp_dir = os.path.join(tempfile.gettempdir(), "wandb_model") + wandb_dir = os.path.join(temp_dir, f"wandb_{bailo_name}") + artifact.download(root=wandb_dir) + release.upload(wandb_dir) + + return model + def update_model_card(self, model_card: dict[str, Any] | None = None) -> None: """Upload and retrieve any changes to the model card on Bailo. @@ -421,7 +496,7 @@ def from_mlflow(self, tracking_uri: str, experiment_id: str): :param experiment_id: MLFlow Tracking experiment ID :raises ImportError: Import error if MLFlow not installed """ - if not ml_flow: + if not mlflow_installed: raise ImportError("Optional MLFlow dependencies (needed for this method) are not installed.") client = mlflow.tracking.MlflowClient(tracking_uri=tracking_uri) @@ -471,6 +546,9 @@ def from_mlflow(self, tracking_uri: str, experiment_id: str): logger.info(f"Successfully imported MLFlow experiment %s.", experiment_id) + def from_wandb(experiment_id: str): + pass + def publish( self, mc_loc: str,