import argparse
import os
import time

import torch
import yaml
from determined.common.experimental import ModelVersion
from determined.experimental import Determined
from determined.pytorch import load_trial_from_checkpoint_path
from google.cloud import storage
from kserve import (
    KServeClient,
    V1beta1InferenceService,
    V1beta1InferenceServiceSpec,
    V1beta1PredictorSpec,
    V1beta1TorchServeSpec,
    constants,
)
from kubernetes import client

# =====================================================================================


def parse_args():
    parser = argparse.ArgumentParser(description="Deploy a model to KServe")
    parser.add_argument("--deployment-name", type=str, help="Name of the resulting KServe InferenceService")
    parser.add_argument("--gcs-model-bucket", type=str, help="GCS bucket name to use for storing model artifacts")
    return parser.parse_args()

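# Example invocation (illustrative only; the script, deployment, and bucket names
# below are hypothetical and depend on your environment):
#
#   python deploy.py --deployment-name mri-deployment --gcs-model-bucket my-model-bucket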

# =====================================================================================


def wait_for_deployment(kclient, k8s_namespace, deployment_name, model_name):
    while not kclient.is_isvc_ready(deployment_name, namespace=k8s_namespace):
        print(f"Inference Service '{deployment_name}' is NOT READY. Waiting...")
        time.sleep(5)
    print(f"Inference Service '{deployment_name}' in Namespace '{k8s_namespace}' is READY.")
    response = kclient.get(deployment_name, namespace=k8s_namespace)
    print(
        f"Model {model_name} is "
        f"{response['status']['modelStatus']['states']['targetModelState']} "
        f"and available at {response['status']['address']['url']} for predictions."
    )


# =====================================================================================


def get_version(det_client, model_name, model_version) -> ModelVersion:
    # Look up the requested version among the registered versions of the model.
    for version in det_client.get_model(model_name).get_versions():
        if version.name == model_version:
            return version

    raise AssertionError(f"Version '{model_version}' not found inside model '{model_name}'")


# =====================================================================================


def create_scriptmodule(det_master, det_user, det_pw, model_name, pach_id):

    print(f"Loading model version '{model_name}/{pach_id}' from master at '{det_master}'...")

    # Work around containers that run with HOME set to "/".
    if os.environ["HOME"] == "/":
        os.environ["HOME"] = "/app"

    os.environ["SERVING_MODE"] = "true"

    start = time.time()
    det_client = Determined(master=det_master, user=det_user, password=det_pw)
    version = get_version(det_client, model_name, pach_id)
    checkpoint = version.checkpoint
    checkpoint_dir = checkpoint.download()
    trial = load_trial_from_checkpoint_path(checkpoint_dir, map_location=torch.device("cpu"))
    end = time.time()
    delta = end - start
    print(f"Checkpoint loaded in {delta} seconds.")

    print("Creating ScriptModule from Determined checkpoint...")

    # Create ScriptModule
    m = torch.jit.script(trial.model)

    # Save ScriptModule to file
    torch.jit.save(m, "scriptmodule.pt")
    print("ScriptModule created successfully.")

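# Optional sanity check (illustrative): the saved ScriptModule can be reloaded
# without the original training code, e.g.
#
#   loaded = torch.jit.load("scriptmodule.pt", map_location="cpu")
#   loaded.eval()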

# =====================================================================================


def create_mar_file(model_name, model_version):
    print(f"Creating .mar file for model '{model_name}'...")
    os.system(
        "torch-model-archiver --model-name %s --version %s --serialized-file ./scriptmodule.pt --handler ./brain_mri_handler.py --force"
        % (model_name, model_version)
    )
    print("Created .mar file successfully.")

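# Note: torch-model-archiver writes "<model_name>.mar" into the current working
# directory (the default export path). "./brain_mri_handler.py" is a custom
# TorchServe handler that is assumed to live next to this script.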

# =====================================================================================


def create_properties_file(model_name, model_version):
    config_properties = """inference_address=http://0.0.0.0:8085
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
grpc_inference_port=7070
grpc_management_port=7071
enable_envvars_config=true
install_py_dep_per_model=true
enable_metrics_api=true
metrics_format=prometheus
NUM_WORKERS=1
number_of_netty_threads=4
job_queue_size=10
model_store=/mnt/models/model-store
model_snapshot={"name":"startup.cfg","modelCount":1,"models":{"%s":{"%s":{"defaultVersion":true,"marName":"%s.mar","minWorkers":1,"maxWorkers":5,"batchSize":1,"maxBatchDelay":5000,"responseTimeout":120}}}}""" % (
        model_name,
        model_version,
        model_name,
    )

    with open("config.properties", "w") as conf_prop:
        conf_prop.write(config_properties)

    model_files = ["config.properties", str(model_name) + ".mar"]

    return model_files

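# For reference, upload_model() below places these files in the layout the KServe
# TorchServe predictor expects under its storage URI:
#
#   gs://<bucket>/<model_name>/config/config.properties
#   gs://<bucket>/<model_name>/model-store/<model_name>.mar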

# =====================================================================================


def upload_model(model_name, files, bucket_name):
    print("Uploading model files to model repository in GCS bucket...")
    storage_client = storage.Client()
    bucket = storage_client.get_bucket(bucket_name)

    for file in files:
        # config.properties goes to the config/ folder, the .mar file to model-store/.
        if "config" in str(file):
            folder = "config"
        else:
            folder = "model-store"
        blob = bucket.blob(model_name + "/" + folder + "/" + file)
        blob.upload_from_filename("./" + file)

    print("Upload to GCS complete.")


# =====================================================================================


def create_inference_service(kclient, k8s_namespace, model_name, deployment_name, pach_id, replace: bool):

    kserve_version = "v1beta1"
    api_version = constants.KSERVE_GROUP + "/" + kserve_version

    # NOTE: the storage URI bucket is hardcoded here; it must match the bucket the
    # model artifacts were uploaded to (see upload_model / --gcs-model-bucket).
    isvc = V1beta1InferenceService(
        api_version=api_version,
        kind=constants.KSERVE_KIND,
        metadata=client.V1ObjectMeta(
            name=deployment_name,
            namespace=k8s_namespace,
            annotations={"sidecar.istio.io/inject": "false", "pach_id": pach_id},
        ),
        spec=V1beta1InferenceServiceSpec(
            predictor=V1beta1PredictorSpec(
                pytorch=V1beta1TorchServeSpec(
                    protocol_version="v2", storage_uri="gs://kserve-models/%s" % (model_name)
                )
            )
        ),
    )

    if replace:
        print("Replacing InferenceService with new version...")
        kclient.replace(deployment_name, isvc)
        print(f"InferenceService replaced with new version '{pach_id}'.")
    else:
        print(f"Creating KServe InferenceService for model '{model_name}'.")
        kclient.create(isvc)
        print(f"Inference Service '{deployment_name}' created.")


# =====================================================================================


def check_existence(kclient, deployment_name, k8s_namespace):

    print(f"Checking if previous version of InferenceService '{deployment_name}' exists...")

    # KServeClient.get raises a RuntimeError if the InferenceService does not exist.
    try:
        kclient.get(deployment_name, namespace=k8s_namespace)
        exists = True
        print(f"Previous version of InferenceService '{deployment_name}' exists.")
    except RuntimeError:
        exists = False
        print(f"Previous version of InferenceService '{deployment_name}' does not exist.")

    return exists


# =====================================================================================


class DeterminedInfo:
    """Determined master connection info, read from environment variables."""

    def __init__(self):
        self.master = os.getenv("DET_MASTER")
        self.username = os.getenv("DET_USER")
        self.password = os.getenv("DET_PASSWORD")

# =====================================================================================


class KServeInfo:
    """KServe deployment target info, read from environment variables."""

    def __init__(self):
        self.namespace = os.getenv("KSERVE_NAMESPACE")

# =====================================================================================


class ModelInfo:
    """Model metadata read from a model-info.yaml file."""

    def __init__(self, file):
        print(f"Reading model info file: {file}")
        with open(file, "r") as stream:
            try:
                info = yaml.safe_load(stream)

                self.name = info["name"]
                self.version = info["version"]
                self.pipeline = info["pipeline"]
                self.repository = info["repo"]

                print(
                    f"Loaded model info: name='{self.name}', version='{self.version}', pipeline='{self.pipeline}', repo='{self.repository}'"
                )
            except yaml.YAMLError as exc:
                print(exc)

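# Illustrative model-info.yaml (the field values are hypothetical; only the keys
# are dictated by the class above):
#
#   name: brain-mri
#   version: 1a2b3c4d
#   pipeline: train-model
#   repo: model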

# =====================================================================================


def main():
    args = parse_args()
    det = DeterminedInfo()
    ksrv = KServeInfo()
    model = ModelInfo("/pfs/data/model-info.yaml")

    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/determined_shared_fs/service-account.json"

    print(f"Starting pipeline: deploy-name='{args.deployment_name}', model='{model.name}', version='{model.version}'")

    # Pull Determined.AI Checkpoint, load it, and create ScriptModule (TorchScript)
    create_scriptmodule(det.master, det.username, det.password, model.name, model.version)

    # Create .mar file from ScriptModule
    create_mar_file(model.name, model.version)

    # Create config.properties for .mar file, return files to upload to GCS bucket
    model_files = create_properties_file(model.name, model.version)

    # Upload model artifacts to GCS bucket in the format for TorchServe
    upload_model(model.name, model_files, args.gcs_model_bucket)

    # Instantiate KServe Client using kubeconfig
    kclient = KServeClient(config_file="/determined_shared_fs/k8s.config")

    # Check if a previous version of the InferenceService exists (return true/false)
    replace = check_existence(kclient, args.deployment_name, ksrv.namespace)

    # Create or replace inference service
    create_inference_service(kclient, ksrv.namespace, model.name, args.deployment_name, model.version, replace)

    # Wait for InferenceService to be ready for predictions
    wait_for_deployment(kclient, ksrv.namespace, args.deployment_name, model.name)

    print(f"Ending pipeline: deploy-name='{args.deployment_name}', model='{model.name}', version='{model.version}'")

# =====================================================================================


if __name__ == "__main__":
    main()