
Small fixes to get things working on Kubeflow #509


Merged · 10 commits · May 31, 2022
49 changes: 22 additions & 27 deletions dask_kubernetes/classic/tests/test_sync.py
@@ -5,7 +5,7 @@
 import dask
 import pytest
 from dask.distributed import Client, wait
-from distributed.utils_test import loop, captured_logger # noqa: F401
+from distributed.utils_test import captured_logger
 from dask.utils import tmpfile

 from dask_kubernetes import KubeCluster, make_pod_spec
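The `loop` fixture import carried a `# noqa: F401` because pytest injects fixtures by name, so the import looked unused to flake8 even though the tests depended on it. With this PR the tests let KubeCluster and Client manage their own event loop, so both the fixture import and the explicit `loop=` arguments can go. A minimal before/after sketch (the test body is illustrative):

    # Before: the fixture must be imported so pytest can inject it by name.
    from distributed.utils_test import loop  # noqa: F401

    def test_env(pod_spec, loop):
        with KubeCluster(pod_spec, loop=loop) as cluster:
            ...

    # After: KubeCluster and Client manage their own event loop.
    def test_env(pod_spec):
        with KubeCluster(pod_spec) as cluster:
            ...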
@@ -75,17 +75,17 @@ def test_ipython_display(cluster):
         sleep(0.5)


-def test_env(pod_spec, loop):
-    with KubeCluster(pod_spec, env={"ABC": "DEF"}, loop=loop) as cluster:
+def test_env(pod_spec):
+    with KubeCluster(pod_spec, env={"ABC": "DEF"}) as cluster:
         cluster.scale(1)
-        with Client(cluster, loop=loop) as client:
+        with Client(cluster) as client:
             while not cluster.scheduler_info["workers"]:
                 sleep(0.1)
             env = client.run(lambda: dict(os.environ))
             assert all(v["ABC"] == "DEF" for v in env.values())


-def dont_test_pod_template_yaml(docker_image, loop):
+def dont_test_pod_template_yaml(docker_image):
     test_yaml = {
         "kind": "Pod",
         "metadata": {"labels": {"app": "dask", "component": "dask-worker"}},
@@ -109,9 +109,9 @@ def dont_test_pod_template_yaml(docker_image, loop):
     with tmpfile(extension="yaml") as fn:
         with open(fn, mode="w") as f:
             yaml.dump(test_yaml, f)
-        with KubeCluster(f.name, loop=loop) as cluster:
+        with KubeCluster(f.name) as cluster:
             cluster.scale(2)
-            with Client(cluster, loop=loop) as client:
+            with Client(cluster) as client:
                 future = client.submit(lambda x: x + 1, 10)
                 result = future.result(timeout=10)
                 assert result == 11
@@ -128,7 +128,7 @@ def dont_test_pod_template_yaml(docker_image, loop):
                 assert all(client.has_what().values())


-def test_pod_template_yaml_expand_env_vars(docker_image, loop):
+def test_pod_template_yaml_expand_env_vars(docker_image):
     try:
         os.environ["FOO_IMAGE"] = docker_image

@@ -155,13 +155,13 @@ def test_pod_template_yaml_expand_env_vars(docker_image, loop):
         with tmpfile(extension="yaml") as fn:
             with open(fn, mode="w") as f:
                 yaml.dump(test_yaml, f)
-            with KubeCluster(f.name, loop=loop) as cluster:
+            with KubeCluster(f.name) as cluster:
                 assert cluster.pod_template.spec.containers[0].image == docker_image
     finally:
         del os.environ["FOO_IMAGE"]


-def test_pod_template_dict(docker_image, loop):
+def test_pod_template_dict(docker_image):
     spec = {
         "metadata": {},
         "restartPolicy": "Never",
@@ -185,9 +185,9 @@
         },
     }

-    with KubeCluster(spec, loop=loop) as cluster:
+    with KubeCluster(spec) as cluster:
         cluster.scale(2)
-        with Client(cluster, loop=loop) as client:
+        with Client(cluster) as client:
             future = client.submit(lambda x: x + 1, 10)
             result = future.result()
             assert result == 11
@@ -202,7 +202,7 @@ def test_pod_template_dict(docker_image, loop):
             assert all(client.has_what().values())


-def test_pod_template_minimal_dict(docker_image, loop):
+def test_pod_template_minimal_dict(docker_image):
     spec = {
         "spec": {
             "containers": [
@@ -224,9 +224,9 @@
         }
     }

-    with KubeCluster(spec, loop=loop) as cluster:
+    with KubeCluster(spec) as cluster:
         cluster.adapt()
-        with Client(cluster, loop=loop) as client:
+        with Client(cluster) as client:
             future = client.submit(lambda x: x + 1, 10)
             result = future.result()
             assert result == 11
@@ -264,9 +264,9 @@ def test_bad_args():
         KubeCluster({"kind": "Pod"})


-def test_constructor_parameters(pod_spec, loop):
+def test_constructor_parameters(pod_spec):
     env = {"FOO": "BAR", "A": 1}
-    with KubeCluster(pod_spec, name="myname", loop=loop, env=env) as cluster:
+    with KubeCluster(pod_spec, name="myname", env=env) as cluster:
         pod = cluster.pod_template

         var = [v for v in pod.spec.containers[0].env if v.name == "FOO"]
@@ -380,15 +380,14 @@ def test_maximum(cluster):
         assert "scale beyond maximum number of workers" in result.lower()


-def test_extra_pod_config(docker_image, loop):
+def test_extra_pod_config(docker_image):
     """
     Test that our pod config merging process works fine
     """
     with KubeCluster(
         make_pod_spec(
             docker_image, extra_pod_config={"automountServiceAccountToken": False}
         ),
-        loop=loop,
         n_workers=0,
     ) as cluster:

@@ -397,7 +396,7 @@ def test_extra_pod_config(docker_image, loop):
         assert pod.spec.automount_service_account_token is False


-def test_extra_container_config(docker_image, loop):
+def test_extra_container_config(docker_image):
     """
     Test that our container config merging process works fine
     """
@@ -409,7 +408,6 @@ def test_extra_container_config(docker_image, loop):
                 "securityContext": {"runAsUser": 0},
             },
         ),
-        loop=loop,
         n_workers=0,
     ) as cluster:

@@ -419,15 +417,14 @@ def test_extra_container_config(docker_image, loop):
         assert pod.spec.containers[0].security_context == {"runAsUser": 0}


-def test_container_resources_config(docker_image, loop):
+def test_container_resources_config(docker_image):
     """
     Test container resource requests / limits being set properly
     """
     with KubeCluster(
         make_pod_spec(
             docker_image, memory_request="0.5G", memory_limit="1G", cpu_limit="1"
         ),
-        loop=loop,
         n_workers=0,
     ) as cluster:

@@ -439,7 +436,7 @@ def test_container_resources_config(docker_image, loop):
         assert "cpu" not in pod.spec.containers[0].resources.requests


-def test_extra_container_config_merge(docker_image, loop):
+def test_extra_container_config_merge(docker_image):
     """
     Test that our container config merging process works recursively fine
     """
@@ -452,7 +449,6 @@ def test_extra_container_config_merge(docker_image, loop):
                 "args": ["last-item"],
             },
         ),
-        loop=loop,
         n_workers=0,
     ) as cluster:

@@ -464,7 +460,7 @@ def test_extra_container_config_merge(docker_image, loop):
         assert pod.spec.containers[0].args[-1] == "last-item"


-def test_worker_args(docker_image, loop):
+def test_worker_args(docker_image):
     """
     Test that dask-worker arguments are added to the container args
     """
@@ -474,7 +470,6 @@
             memory_limit="5000M",
             resources="FOO=1 BAR=2",
         ),
-        loop=loop,
         n_workers=0,
     ) as cluster:

26 changes: 22 additions & 4 deletions dask_kubernetes/common/networking.py
@@ -6,6 +6,9 @@
 from weakref import finalize

 import kubernetes_asyncio as kubernetes
+from tornado.iostream import StreamClosedError
+
+from distributed.core import rpc

 from .utils import check_dependency

@@ -15,7 +18,7 @@ async def get_external_address_for_scheduler_service(
     service,
     port_forward_cluster_ip=None,
     service_name_resolution_retries=20,
-    port_name="comm",
+    port_name="tcp-comm",
 ):
     """Take a service object and return the scheduler address."""
     [port] = [
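The rename from `comm` to `tcp-comm` follows Istio's explicit protocol selection convention, where a Service port name of the form `<protocol>[-<suffix>]` tells the mesh which protocol to expect on that port; Kubeflow runs behind Istio, so the unprefixed names were not recognised. The lookup above then matches on the new name, roughly as in this sketch (assuming `service` is a `V1Service`):

    # Sketch of the port lookup, assuming `service` is a V1Service whose
    # scheduler port is now named "tcp-comm" (Istio protocol-prefix style).
    [port] = [p for p in service.spec.ports if p.name == "tcp-comm"]
    address = f"tcp://{service.metadata.name}.{service.metadata.namespace}:{port.port}"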
@@ -108,7 +111,7 @@ async def port_forward_dashboard(service_name, namespace):
     return port


-async def get_scheduler_address(service_name, namespace, port_name="comm"):
+async def get_scheduler_address(service_name, namespace, port_name="tcp-comm"):
     async with kubernetes.client.api_client.ApiClient() as api_client:
         api = kubernetes.client.CoreV1Api(api_client)
         service = await api.read_namespaced_service(service_name, namespace)
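A hypothetical call, assuming a cluster whose Service is named `foo-cluster-service` in the `default` namespace; the default `port_name` now matches the renamed comm port, and the dashboard can be resolved the same way:

    # Hypothetical usage; the service and namespace names are placeholders.
    address = await get_scheduler_address("foo-cluster-service", "default")
    dashboard = await get_scheduler_address(
        "foo-cluster-service", "default", port_name="http-dashboard"
    )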
@@ -132,6 +135,21 @@ async def wait_for_scheduler(cluster_name, namespace):
             label_selector=f"dask.org/cluster-name={cluster_name},dask.org/component=scheduler",
             timeout_seconds=60,
         ):
-            if event["object"].status.phase == "Running":
-                watch.stop()
+            if event["object"].status.conditions:
+                conditions = {
+                    c.type: c.status for c in event["object"].status.conditions
+                }
+                if "Ready" in conditions and conditions["Ready"] == "True":
+                    watch.stop()
             await asyncio.sleep(0.1)
+
+
+async def wait_for_scheduler_comm(address):
+    while True:
+        try:
+            async with rpc(address) as scheduler_comm:
+                await scheduler_comm.versions()
+        except (StreamClosedError, OSError):
+            await asyncio.sleep(0.1)
+            continue
+        break
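Two separate readiness signals are involved here. A pod phase of `Running` only means the containers have started; the `Ready` condition additionally requires the readiness probe to pass, so waiting on it avoids handing out an address the scheduler is not yet serving on. `wait_for_scheduler_comm` then confirms the comm path end to end by polling a cheap RPC (`versions`) until it answers. A hedged usage sketch, with placeholder names:

    # Hypothetical startup wait; cluster, namespace, and address are placeholders.
    await wait_for_scheduler("foo-cluster", "default")   # pod reports Ready
    address = "tcp://foo-cluster-service.default:8786"
    await wait_for_scheduler_comm(address)               # scheduler answers RPCs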
50 changes: 29 additions & 21 deletions dask_kubernetes/experimental/kubecluster.py
@@ -20,15 +20,16 @@
 )

 from dask_kubernetes.common.auth import ClusterAuth
+from dask_kubernetes.common.utils import namespace_default
 from dask_kubernetes.operator import (
     build_cluster_spec,
     wait_for_service,
 )

 from dask_kubernetes.common.networking import (
     get_scheduler_address,
-    port_forward_dashboard,
     wait_for_scheduler,
+    wait_for_scheduler_comm,
 )

@@ -121,7 +122,7 @@ class KubeCluster(Cluster):
     def __init__(
         self,
         name,
-        namespace="default",
+        namespace=None,
         image="ghcr.io/dask/dask:latest",
         n_workers=3,
         resources={},
@@ -133,8 +134,7 @@ def __init__(
         **kwargs,
     ):
         self.name = name
-        # TODO: Set namespace to None and get default namespace from user's context
-        self.namespace = namespace
+        self.namespace = namespace or namespace_default()
         self.image = image
         self.n_workers = n_workers
         self.resources = resources
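`namespace_default` resolves the namespace from the caller's environment instead of hard-coding `"default"`, which closes out the TODO removed above. A plausible sketch of such a helper, assuming the standard in-cluster service account mount (the real implementation lives in `dask_kubernetes.common.utils` and may also consult the active kubeconfig context):

    def namespace_default():
        # Sketch only: inside a pod the service account mount names the
        # namespace; otherwise fall back to "default".
        try:
            with open(
                "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
            ) as f:
                return f.read().strip()
        except FileNotFoundError:
            return "default"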
@@ -208,10 +208,15 @@ async def _create_cluster(self):
                 ) from e
             await wait_for_scheduler(cluster_name, self.namespace)
             await wait_for_service(core_api, f"{cluster_name}-service", self.namespace)
-            self.scheduler_comm = rpc(await self._get_scheduler_address())
-            self.forwarded_dashboard_port = await port_forward_dashboard(
-                f"{self.name}-cluster-service", self.namespace
+            scheduler_address = await self._get_scheduler_address()
+            await wait_for_scheduler_comm(scheduler_address)
+            self.scheduler_comm = rpc(scheduler_address)
+            dashboard_address = await get_scheduler_address(
+                f"{self.name}-cluster-service",
+                self.namespace,
+                port_name="http-dashboard",
             )
+            self.forwarded_dashboard_port = dashboard_address.split(":")[-1]

     async def _connect_cluster(self):
         if self.shutdown_on_close is None:
@@ -230,10 +235,15 @@ async def _connect_cluster(self):
             service_name = f'{cluster_spec["metadata"]["name"]}-service'
             await wait_for_scheduler(self.cluster_name, self.namespace)
             await wait_for_service(core_api, service_name, self.namespace)
-            self.scheduler_comm = rpc(await self._get_scheduler_address())
-            self.forwarded_dashboard_port = await port_forward_dashboard(
-                f"{self.name}-cluster-service", self.namespace
+            scheduler_address = await self._get_scheduler_address()
+            await wait_for_scheduler_comm(scheduler_address)
+            self.scheduler_comm = rpc(scheduler_address)
+            dashboard_address = await get_scheduler_address(
+                service_name,
+                self.namespace,
+                port_name="http-dashboard",
             )
+            self.forwarded_dashboard_port = dashboard_address.split(":")[-1]

     async def _get_cluster(self):
         async with kubernetes.client.api_client.ApiClient() as api_client:
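Both code paths now follow the same ordering: wait for the scheduler pod to report Ready, wait for its Service, verify the comm address actually answers RPCs, and only then build the `rpc` handle; the dashboard port is read from the Service's named `http-dashboard` port instead of opening a local port-forward. Condensed, the shared handshake looks like this sketch (identifiers are illustrative):

    # Sketch of the shared startup handshake; identifiers are illustrative.
    await wait_for_scheduler(cluster_name, namespace)          # pod Ready
    await wait_for_service(core_api, service_name, namespace)  # Service exists
    scheduler_address = await self._get_scheduler_address()
    await wait_for_scheduler_comm(scheduler_address)           # RPC reachable
    self.scheduler_comm = rpc(scheduler_address)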
@@ -465,30 +475,28 @@ def _build_scheduler_spec(self, cluster_name):
                     {
                         "name": "scheduler",
                         "image": self.image,
-                        "args": [
-                            "dask-scheduler",
-                        ],
+                        "args": ["dask-scheduler", "--host", "0.0.0.0"],
                         "env": env,
                         "resources": self.resources,
                         "ports": [
                             {
-                                "name": "comm",
+                                "name": "tcp-comm",
                                 "containerPort": 8786,
                                 "protocol": "TCP",
                             },
                             {
-                                "name": "dashboard",
+                                "name": "http-dashboard",
                                 "containerPort": 8787,
                                 "protocol": "TCP",
                             },
                         ],
                         "readinessProbe": {
-                            "tcpSocket": {"port": "comm"},
+                            "httpGet": {"port": "http-dashboard", "path": "/health"},
                             "initialDelaySeconds": 5,
                             "periodSeconds": 10,
                         },
                         "livenessProbe": {
-                            "tcpSocket": {"port": "comm"},
+                            "httpGet": {"port": "http-dashboard", "path": "/health"},
                             "initialDelaySeconds": 15,
                             "periodSeconds": 20,
                         },
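Switching the probes from `tcpSocket` on the comm port to `httpGet` against the dashboard's `/health` endpoint gives the kubelet a stronger signal: a TCP connect succeeds as soon as a socket is listening (and behind an Istio sidecar it may reach the proxy rather than the application), while the HTTP check only passes once the scheduler's HTTP server is actually serving. A rough manual equivalent of what the probe asserts, with a placeholder hostname:

    # Rough manual equivalent of the new probe; the host is a placeholder.
    import urllib.request

    resp = urllib.request.urlopen("http://scheduler.example:8787/health", timeout=5)
    assert resp.status == 200  # the scheduler's HTTP server is up and serving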
@@ -503,16 +511,16 @@
                 },
                 "ports": [
                     {
-                        "name": "comm",
+                        "name": "tcp-comm",
                         "protocol": "TCP",
                         "port": 8786,
-                        "targetPort": "comm",
+                        "targetPort": "tcp-comm",
                     },
                     {
-                        "name": "dashboard",
+                        "name": "http-dashboard",
                         "protocol": "TCP",
                         "port": 8787,
-                        "targetPort": "dashboard",
+                        "targetPort": "http-dashboard",
                     },
                 ],
             },
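Because the Service's `targetPort` refers to the container port by name, the rename has to be applied consistently in both the container spec and the Service spec, or the Service would point at a nonexistent port. Selecting by name keeps the mapping stable even if the container port number ever changes; a minimal sketch of the pairing:

    # Sketch: the Service targetPort references the container port by name,
    # so both sides must use the same (renamed) identifiers.
    container_port = {"name": "tcp-comm", "containerPort": 8786, "protocol": "TCP"}
    service_port = {"name": "tcp-comm", "port": 8786, "targetPort": "tcp-comm"}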
4 changes: 2 additions & 2 deletions dask_kubernetes/kubernetes.yaml
@@ -30,11 +30,11 @@ kubernetes:
         dask.org/cluster-name: "" # Cluster name will be added automatically
         dask.org/component: scheduler
       ports:
-        - name: comm
+        - name: tcp-comm
           protocol: TCP
           port: 8786
           targetPort: 8786
-        - name: dashboard
+        - name: http-dashboard
           protocol: TCP
           port: 8787
           targetPort: 8787