Source code for serverless_gpu.ray

"""Ray integration for distributed serverless GPU compute.

This module provides integration with Ray for distributed computing on serverless
GPU infrastructure. It includes:

- Ray cluster setup and management on distributed GPU nodes
- Integration with serverless GPU launcher for Ray workloads
- Utilities for Ray head node detection and connection management
- Support for Ray distributed training and inference patterns

The module enables users to run Ray-based distributed workloads on serverless
GPU compute resources seamlessly.
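
Example (illustrative sketch): the snippet below assumes Ray is installed and that
the code runs under the serverless GPU launcher; the ``gpu_type`` value shown is a
hypothetical placeholder (see ``GPUType`` for the supported types)::

    from serverless_gpu.ray import ray_launch

    @ray_launch(gpus=8, gpu_type="H100", remote=True)  # "H100" is illustrative
    def train():
        # Runs only on global rank 0, after the decorator has provisioned the
        # Ray cluster and set RAY_ADDRESS, so Ray tasks can be submitted directly.
        import ray

        @ray.remote
        def square(x):
            return x * x

        return ray.get([square.remote(i) for i in range(8)])

    train()  # launches the workload on the requested GPUs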
"""

import logging
import os
import socket
import subprocess
import time
import uuid
from typing import Optional, Union

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import ResourceDoesNotExist

from serverless_gpu import runtime as rt
from serverless_gpu.compute import GPUType
from serverless_gpu.launcher import distributed
from serverless_gpu.utils import get_default_pkl_dir

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger(__name__)


def _check_port(host, port, timeout=5):
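    """Return True if a TCP connection to ``host:port`` can be opened within ``timeout`` seconds."""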
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except (socket.timeout, socket.error):
        return False


def _wait_for_port_open(host, port, timeout=60, check_interval=2):
    """
    Wait for Ray cluster to be ready by checking if we can connect.
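
    Returns True once ``host:port`` accepts connections, or False if ``timeout``
    seconds elapse first.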
    """
    import ray  # noqa: F401

    start_time = time.time()

    while time.time() - start_time < timeout:
        if _check_port(host, port):
            log.info(f"Ray service is available at {host}:{port}")
            time.sleep(3)
            return True
        log.debug(f"Ray service not yet ready at {host}:{port}")
        time.sleep(check_interval)

    log.error(f"Ray service failed to start at {host}:{port} within {timeout} seconds")
    return False


def _robust_ray_init(address, max_retries=3, retry_delay=5):
    """
    Robustly initialize Ray with retry logic and proper error handling.
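
    Returns True once Ray is connected to ``address``, or False after
    ``max_retries`` failed attempts.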
    """
    import ray

    for attempt in range(max_retries):
        try:
            log.info(f"Attempting to connect to Ray cluster at {address} (attempt {attempt + 1}/{max_retries})")

            if ray.is_initialized():
                ray.shutdown()

            ray.init(address=address, ignore_reinit_error=True)

            if ray.is_initialized():
                log.info(f"Successfully connected to Ray cluster at {address}")
                return True
            else:
                log.warning("Ray init succeeded but cluster not properly initialized")

        except Exception as e:
            log.warning(f"Ray init attempt {attempt + 1} failed: {e}")
            try:
                ray.shutdown()
            except Exception:
                pass

            if attempt < max_retries - 1:
                log.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                log.error(f"Failed to initialize Ray after {max_retries} attempts")

    return False


def _initialize_ray_cluster() -> Optional[str]:
    """
    Initialize a Ray cluster across all nodes, coordinating startup via the
    sgc_dist process-group communication utilities.
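
    Returns the head node's IP address on success, or None if the head node or
    any worker fails to start.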
    """
    from serverless_gpu.utils.sgc_dist import initialize_dist, all_gather_object, barrier
    from serverless_gpu.consts import _RAY_METRICS_EXPORT_PORT
    import torch

    head_proc = worker_proc = None
    head_failed = worker_failed = False

    # Stop any existing Ray cluster first
    try:
        log.info("Stopping any existing Ray cluster...")
        result = subprocess.run(
            "ray stop --force", shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30
        )
        if result.returncode == 0:
            log.info("Successfully stopped existing Ray cluster")
        else:
            log.warning(f"Ray stop returned code {result.returncode}")

        # Wait for Ray to shut down cleanly
        time.sleep(5)
    except subprocess.TimeoutExpired:
        log.warning("Ray stop command timed out")
    except Exception as e:
        log.warning(f"Error stopping Ray cluster: {e}")

    initialize_dist()

    # Get IP address with better error handling
    command = "ip addr show eth0 | grep 'inet ' | awk '{print $2}' | cut -d/ -f1"
    try:
        result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to get IP address: {result.stderr}")
        ip_address = result.stdout.strip()
        if not ip_address:
            raise RuntimeError("Empty IP address returned")
        log.info(f"Got IP address: {ip_address}")
    except Exception as e:
        log.error(f"Failed to get IP address: {e}")
        raise

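    # All ranks gather their IP addresses; the first gathered entry is used as
    # the Ray head node address.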
    head_ip_address = all_gather_object(ip_address)[0]
    log.info(f"Head node IP: {head_ip_address}")

    try:
        if rt.get_local_rank() == 0 and rt.get_global_rank() == 0:
            log.info("Starting Ray head node...")
            head_proc = subprocess.Popen(
                f"ray start --head --metrics-export-port={_RAY_METRICS_EXPORT_PORT} --block",
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            # Wait for head node to be ready
            if not _wait_for_port_open(head_ip_address, port=6379, timeout=60):
                log.error("Head node failed to start within timeout")
                head_failed = True

            # Need to set `RAY_ADDRESS` environment variable so that
            # `ray.init()` can connect to the started Ray cluster.
            os.environ["RAY_ADDRESS"] = f"{head_ip_address}:6379"

            if not _robust_ray_init("auto"):
                log.error("Failed to initialize Ray on head node")
                head_failed = True

    except Exception as e:
        log.error(f"Exception while launching Ray head: {e}")
        head_failed = True

    head_status = not head_failed
    head_status_all = all_gather_object(head_status)

    if not all(head_status_all):
        log.error("Ray head node failure detected; aborting Ray cluster initialization.")
        return None

    barrier()

    try:
        if rt.get_local_rank() == 0 and rt.get_global_rank() != 0:
            log.info(f"Starting Ray worker node, connecting to {head_ip_address}:6379...")
            worker_proc = subprocess.Popen(
                f"ray start --address {head_ip_address}:6379 --metrics-export-port={_RAY_METRICS_EXPORT_PORT} --block",
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            # Confirm the head node's Ray port is reachable from this worker
            if not _wait_for_port_open(head_ip_address, port=6379, timeout=60):
                log.error("Worker node failed to connect within timeout")
                worker_failed = True

            if not _robust_ray_init(f"{head_ip_address}:6379"):
                log.error("Failed to initialize Ray on worker node")
                worker_failed = True
    except Exception as e:
        log.error(f"Exception while launching Ray worker: {e}")
        worker_failed = True

    worker_status = not worker_failed
    worker_status_all = all_gather_object(worker_status)
    if not all(worker_status_all):
        log.error("Ray worker startup failed on one or more nodes; aborting Ray cluster initialization.")
        return None

    barrier()

    try:
        torch.distributed.destroy_process_group()
        log.info("Ray cluster initialization completed successfully")
        return head_ip_address

    except Exception as e:
        log.error(f"Ray cluster initialization failed: {e}")
        # Clean up processes
        for proc in [head_proc, worker_proc]:
            if proc:
                try:
                    proc.terminate()
                    proc.wait(timeout=10)
                except Exception:
                    proc.kill()
        return None


def _check_ray_installed():
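    """Raise an ImportError if the optional ``ray`` dependency is not installed."""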
    try:
        import ray  # noqa: F401
    except ImportError:
        raise ImportError("Ray is not installed. Please install ray to use this decorator.")


def ray_launch(
    gpus: int,
    gpu_type: Optional[Union[GPUType, str]] = None,
    remote: bool = False,
    run_async: bool = False,
):
    """
    Experimental decorator to launch a function with a Ray cluster.

    Args:
        gpus (int): Number of GPUs to launch Ray on.
        gpu_type (Optional[Union[GPUType, str]]): The GPU type to use. Defaults to None.
            Required if remote is True.
        remote (bool): Use remote GPUs.
        run_async (bool): Whether to run the function asynchronously. Defaults to False.
    """
    _check_ray_installed()

    enable_ray_metrics_logging = (
        remote
        and os.environ.get("DATABRICKS_INTERNAL_ENABLE_SKYRUN_RAY_SYSTEM_METRICS_LOGGING", "true").lower() == "true"
    )

    def decorator(func):
        id = str(uuid.uuid4())
        done_file = os.path.join(get_default_pkl_dir(remote), "ray", func.__name__, id, "done")
        os.makedirs(os.path.dirname(done_file), exist_ok=True)

        # Always launch on all the GPUs.
        @distributed(gpus=gpus, gpu_type=gpu_type, remote=remote, run_async=run_async)
        def wrapper(*args, **kwargs):
            from serverless_gpu.utils.ray_metrics_monitor import RayMetricsMonitor

            # Initialize the Ray cluster across all nodes.
            head_ip_address = _initialize_ray_cluster()
            if not head_ip_address:
                raise RuntimeError("Failed to start a ray cluster. Please contact our support team.")

            result = None

            if rt.get_local_rank() == 0 and enable_ray_metrics_logging:
                # Every skyrun node needs its own Ray monitor to scrape metrics from
                # the local Ray node, so a `RayMetricsMonitor` is started on every
                # "local-rank-zero" process.
                metrics_monitor = RayMetricsMonitor(
                    os.environ["MLFLOW_RUN_ID"], is_head_node=(rt.get_global_rank() == 0)
                )
                metrics_monitor.start()
            else:
                metrics_monitor = None

            if rt.get_global_rank() == 0:
                try:
                    # Only run the Ray driver program on global rank zero; it submits
                    # Ray tasks to the Ray cluster.
                    result = func(*args, **kwargs)
                    with open(done_file, "wb") as f:
                        f.write(b"done")
                finally:
                    try:
                        import ray

                        ray.shutdown()
                    except Exception as e:
                        log.warning(f"Failed to shutdown ray: {e}")

            if rt.get_local_rank() == 0 and rt.get_global_rank() != 0:
                is_done = False
                while not is_done:
                    try:
                        # We need to use the Workspace API because file system operations are cached.
                        # See https://databricks.slack.com/archives/C05U0QLV95Y/p1737583951631169
                        with WorkspaceClient().workspace.download(done_file) as f:
                            is_done = f.read() == b"done"
                    except ResourceDoesNotExist:
                        time.sleep(5)

            # Terminate the Ray metrics monitor after the user's workload completes.
            if metrics_monitor:
                metrics_monitor.finish()

            return result

        return wrapper

    return decorator