diff --git a/metaflow/plugins/kubernetes/kube_utils.py b/metaflow/plugins/kubernetes/kube_utils.py new file mode 100644 index 00000000000..979b7b9d273 --- /dev/null +++ b/metaflow/plugins/kubernetes/kube_utils.py @@ -0,0 +1,25 @@ +from metaflow.exception import CommandException +from metaflow.util import get_username, get_latest_run_id + + +def parse_cli_options(flow_name, run_id, user, my_runs, echo): + if user and my_runs: + raise CommandException("--user and --my-runs are mutually exclusive.") + + if run_id and my_runs: + raise CommandException("--run-id and --my-runs are mutually exclusive.") + + if my_runs: + user = get_username() + + latest_run = True + + if user and not run_id: + latest_run = False + + if not run_id and latest_run: + run_id = get_latest_run_id(echo, flow_name) + if run_id is None: + raise CommandException("A previous run id was not found. Specify --run-id.") + + return flow_name, run_id, user diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index f572c6e09f2..3b5035d1f59 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -3,10 +3,12 @@ import time import traceback +from metaflow.plugins.kubernetes.kube_utils import parse_cli_options +from metaflow.plugins.kubernetes.kubernetes_client import KubernetesClient import metaflow.tracing as tracing from metaflow import JSONTypeClass, util from metaflow._vendor import click -from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException +from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException from metaflow.metadata.util import sync_local_metadata_from_datastore from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS from metaflow.mflog import TASK_LOG_SOURCE @@ -305,3 +307,84 @@ def _sync_metadata(): sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) finally: _sync_metadata() + + +@kubernetes.command(help="List unfinished Kubernetes 
tasks of this flow.") +@click.option( + "--my-runs", + default=False, + is_flag=True, + help="List all my unfinished tasks.", +) +@click.option("--user", default=None, help="List unfinished tasks for the given user.") +@click.option( + "--run-id", + default=None, + help="List unfinished tasks corresponding to the run id.", +) +@click.pass_obj +def list(obj, run_id, user, my_runs): + flow_name, run_id, user = parse_cli_options( + obj.flow.name, run_id, user, my_runs, obj.echo + ) + kube_client = KubernetesClient() + pods = kube_client.list(obj.flow.name, run_id, user) + + def format_timestamp(timestamp=None): + if timestamp is None: + return "-" + return timestamp.strftime("%Y-%m-%d %H:%M:%S") + + for pod in pods: + obj.echo( + "Run: *{run_id}* " + "Pod: *{pod_id}* " + "Started At: {startedAt} " + "Status: *{status}*".format( + run_id=pod.metadata.annotations.get( + "metaflow/run_id", + pod.metadata.labels.get("workflows.argoproj.io/workflow"), + ), + pod_id=pod.metadata.name, + startedAt=format_timestamp(pod.status.start_time), + status=pod.status.phase, + ) + ) + + if not pods: + obj.echo("No active Kubernetes pods found.") + + +@kubernetes.command( + help="Terminate unfinished Kubernetes tasks of this flow. Killed pods may result in newer attempts when using @retry." +) +@click.option( + "--my-runs", + default=False, + is_flag=True, + help="Kill all my unfinished tasks.", +) +@click.option( + "--user", + default=None, + help="Terminate unfinished tasks for the given user.", +) +@click.option( + "--run-id", + default=None, + help="Terminate unfinished tasks corresponding to the run id.", +) +@click.pass_obj +def kill(obj, run_id, user, my_runs): + flow_name, run_id, user = parse_cli_options( + obj.flow.name, run_id, user, my_runs, obj.echo + ) + + if run_id is not None and run_id.startswith("argo-") or user == "argo-workflows": + raise MetaflowException( + "Killing pods launched by Argo Workflows is not supported. " + "Use *argo-workflows terminate* instead." 
+ ) + + kube_client = KubernetesClient() + kube_client.kill_pods(flow_name, run_id, user, obj.echo) diff --git a/metaflow/plugins/kubernetes/kubernetes_client.py b/metaflow/plugins/kubernetes/kubernetes_client.py index 2ad0388e359..3ec539394a8 100644 --- a/metaflow/plugins/kubernetes/kubernetes_client.py +++ b/metaflow/plugins/kubernetes/kubernetes_client.py @@ -1,8 +1,10 @@ +from concurrent.futures import ThreadPoolExecutor import os import sys import time from metaflow.exception import MetaflowException +from metaflow.metaflow_config import KUBERNETES_NAMESPACE from .kubernetes_job import KubernetesJob, KubernetesJobSet @@ -28,6 +30,7 @@ def __init__(self): % sys.executable ) self._refresh_client() + self._namespace = KUBERNETES_NAMESPACE def _refresh_client(self): from kubernetes import client, config @@ -60,6 +63,100 @@ def get(self): return self._client + def _find_active_pods(self, flow_name, run_id=None, user=None): + def _request(_continue=None): + # handle paginated responses + return self._client.CoreV1Api().list_namespaced_pod( + namespace=self._namespace, + # limited selector support for K8S api. We want to cover multiple statuses: Running / Pending / Unknown + field_selector="status.phase!=Succeeded,status.phase!=Failed", + limit=1000, + _continue=_continue, + ) + + results = _request() + + if run_id is not None: + # handle argo prefixes in run_id + run_id = run_id[run_id.startswith("argo-") and len("argo-") :] + + while results.metadata._continue or results.items: + for pod in results.items: + match = ( + # arbitrary pods might have no annotations at all. 
+ pod.metadata.annotations + and pod.metadata.labels + and ( + run_id is None + or (pod.metadata.annotations.get("metaflow/run_id") == run_id) + # we want to also match pods launched by argo-workflows + or ( + pod.metadata.labels.get("workflows.argoproj.io/workflow") + == run_id + ) + ) + and ( + user is None + or pod.metadata.annotations.get("metaflow/user") == user + ) + and ( + pod.metadata.annotations.get("metaflow/flow_name") == flow_name + ) + ) + if match: + yield pod + if not results.metadata._continue: + break + results = _request(results.metadata._continue) + + def list(self, flow_name, run_id, user): + results = self._find_active_pods(flow_name, run_id, user) + + return list(results) + + def kill_pods(self, flow_name, run_id, user, echo): + from kubernetes.stream import stream + + api_instance = self._client.CoreV1Api() + job_api = self._client.BatchV1Api() + pods = self._find_active_pods(flow_name, run_id, user) + + def _kill_pod(pod): + echo("Killing Kubernetes pod %s\n" % pod.metadata.name) + try: + stream( + api_instance.connect_get_namespaced_pod_exec, + name=pod.metadata.name, + namespace=pod.metadata.namespace, + command=[ + "/bin/sh", + "-c", + "/sbin/killall5", + ], + stderr=True, + stdin=False, + stdout=True, + tty=False, + ) + except Exception: + # best effort kill for pod can fail. + try: + job_name = pod.metadata.labels.get("job-name", None) + if job_name is None: + raise Exception("Could not determine job name") + + job_api.patch_namespaced_job( + name=job_name, + namespace=pod.metadata.namespace, + field_manager="metaflow", + body={"spec": {"parallelism": 0}}, + ) + except Exception as e: + echo("failed to kill pod %s - %s" % (pod.metadata.name, str(e))) + + with ThreadPoolExecutor() as executor: + executor.map(_kill_pod, list(pods)) + def jobset(self, **kwargs): return KubernetesJobSet(self, **kwargs)