"""Ray integration for distributed serverless GPU compute.
This module provides integration with Ray for distributed computing on serverless
GPU infrastructure. It includes:
- Ray cluster setup and management on distributed GPU nodes
- Integration with serverless GPU launcher for Ray workloads
- Utilities for Ray head node detection and connection management
- Support for Ray distributed training and inference patterns
The module enables users to run Ray-based distributed workloads on serverless
GPU compute resources seamlessly.
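
Example (illustrative sketch; the import path and the GPU type string are
assumptions, not values defined in this module):

    from serverless_gpu.ray import ray_launch

    @ray_launch(gpus=8, gpu_type="A10", remote=True)
    def run_ray_job():
        import ray
        # Ray is already connected to the launched cluster at this point;
        # submit Ray tasks/actors as usual.
        ...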
"""
import os
from typing import Optional, Union
import subprocess
import uuid
from serverless_gpu import runtime as rt
from serverless_gpu.launcher import distributed
from serverless_gpu.utils import get_default_pkl_dir
from serverless_gpu.compute import GPUType
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import ResourceDoesNotExist
import logging
import time
import socket
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger(__name__)
def _check_port(host, port, timeout=5):
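    """Return True if a TCP connection to (host, port) succeeds within `timeout` seconds."""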
try:
with socket.create_connection((host, port), timeout=timeout):
return True
except (socket.timeout, socket.error):
return False
def _wait_for_port_open(host, port, timeout=60, check_interval=2):
"""
Wait for Ray cluster to be ready by checking if we can connect.
"""
import ray # noqa: F401
start_time = time.time()
while time.time() - start_time < timeout:
if _check_port(host, port):
log.info(f"Ray service is available at {host}:{port}")
time.sleep(3)
return True
log.debug(f"Ray service not yet ready at {host}:{port}")
time.sleep(check_interval)
log.error(f"Ray service failed to start at {host}:{port} within {timeout} seconds")
return False
def _robust_ray_init(address, max_retries=3, retry_delay=5):
"""
Robustly initialize Ray with retry logic and proper error handling.
"""
import ray
for attempt in range(max_retries):
try:
log.info(f"Attempting to connect to Ray cluster at {address} (attempt {attempt + 1}/{max_retries})")
if ray.is_initialized():
ray.shutdown()
ray.init(address=address, ignore_reinit_error=True)
if ray.is_initialized():
log.info(f"Successfully connected to Ray cluster at {address}")
return True
else:
log.warning("Ray init succeeded but cluster not properly initialized")
except Exception as e:
log.warning(f"Ray init attempt {attempt + 1} failed: {e}")
try:
ray.shutdown()
except Exception:
pass
if attempt < max_retries - 1:
log.info(f"Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else:
log.error(f"Failed to initialize Ray after {max_retries} attempts")
return False
def _initialize_ray_cluster() -> Optional[str]:
"""
    Initialize a Ray cluster across all nodes, coordinating via the sgc_dist
    collective communication utilities (all_gather/barrier).
"""
from serverless_gpu.utils.sgc_dist import initialize_dist, all_gather_object, barrier
from serverless_gpu.consts import _RAY_METRICS_EXPORT_PORT
import torch
head_proc = worker_proc = None
head_failed = worker_failed = False
# Stop any existing Ray cluster first
try:
log.info("Stopping any existing Ray cluster...")
result = subprocess.run(
"ray stop --force", shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30
)
if result.returncode == 0:
log.info("Successfully stopped existing Ray cluster")
else:
log.warning(f"Ray stop returned code {result.returncode}")
# Wait for Ray to shut down cleanly
time.sleep(5)
except subprocess.TimeoutExpired:
log.warning("Ray stop command timed out")
except Exception as e:
log.warning(f"Error stopping Ray cluster: {e}")
initialize_dist()
    # Get this node's IP address (from eth0).
command = "ip addr show eth0 | grep 'inet ' | awk '{print $2}' | cut -d/ -f1"
try:
result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)
if result.returncode != 0:
raise RuntimeError(f"Failed to get IP address: {result.stderr}")
ip_address = result.stdout.strip()
if not ip_address:
raise RuntimeError("Empty IP address returned")
log.info(f"Got IP address: {ip_address}")
except Exception as e:
log.error(f"Failed to get IP address: {e}")
raise
    # The first element of the gathered list is global rank 0's IP, which hosts the Ray head.
    head_ip_address = all_gather_object(ip_address)[0]
    log.info(f"Head node IP: {head_ip_address}")
try:
if rt.get_local_rank() == 0 and rt.get_global_rank() == 0:
log.info("Starting Ray head node...")
head_proc = subprocess.Popen(
f"ray start --head --metrics-export-port={_RAY_METRICS_EXPORT_PORT} --block",
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
            # Wait for the Ray head's default port (6379) to accept connections.
            if not _wait_for_port_open(head_ip_address, port=6379, timeout=60):
                log.error("Head node failed to start within timeout")
                head_failed = True
# Need to set `RAY_ADDRESS` environment variable so that
# `ray.init()` can connect to the started Ray cluster.
os.environ["RAY_ADDRESS"] = f"{head_ip_address}:6379"
if not _robust_ray_init("auto"):
log.error("Failed to initialize Ray on head node")
head_failed = True
except Exception as e:
log.error(f"Exception while launching Ray head: {e}")
head_failed = True
head_status = not head_failed
head_status_all = all_gather_object(head_status)
if not all(head_status_all):
log.error("Ray head node failed detected on one or more workers!")
return None
barrier()
try:
if rt.get_local_rank() == 0 and rt.get_global_rank() != 0:
log.info(f"Starting Ray worker node, connecting to {head_ip_address}:6379...")
worker_proc = subprocess.Popen(
f"ray start --address {head_ip_address}:6379 --metrics-export-port={_RAY_METRICS_EXPORT_PORT} --block",
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
            # Check that the head node's port is reachable from this worker.
            if not _wait_for_port_open(head_ip_address, port=6379, timeout=60):
                log.error("Worker node failed to connect within timeout")
                worker_failed = True
if not _robust_ray_init(f"{head_ip_address}:6379"):
log.error("Failed to initialize Ray on worker node")
worker_failed = True
except Exception as e:
log.error(f"Exception while launching Ray worker: {e}")
worker_failed = True
worker_status = not worker_failed
worker_status_all = all_gather_object(worker_status)
if not all(worker_status_all):
log.error("Ray worker node failed on one or more nodes!")
return None
barrier()
    try:
        # Tear down the temporary process group used only for coordination;
        # the Ray cluster itself stays up for the user's workload.
        torch.distributed.destroy_process_group()
        log.info("Ray cluster initialization completed successfully")
return head_ip_address
except Exception as e:
log.error(f"Ray cluster initialization failed: {e}")
# Clean up processes
for proc in [head_proc, worker_proc]:
if proc:
try:
proc.terminate()
proc.wait(timeout=10)
except Exception:
proc.kill()
return None
def _check_ray_installed():
try:
import ray # noqa: F401
except ImportError:
raise ImportError("Ray is not installed. Please install ray to use this decorator")
def ray_launch(
gpus: int, gpu_type: Optional[Union[GPUType, str]] = None, remote: bool = False, run_async: bool = False
):
"""
Experimental decorator to launch function with a ray cluster.
Args:
gpus (int): Number of gpus to launch ray on.
gpu_type (Optional[Union[GPUType, str]]): The GPU type to use. Defaults to None. Required if remote is True.
remote (bool): Use remote gpus.
run_async (bool): Whether to run the function asynchronously. Defaults to False.
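
    Example (illustrative sketch; the GPU type string and the task body are
    placeholders, not values prescribed by this API):

        @ray_launch(gpus=8, gpu_type="A10", remote=True)
        def train():
            import ray

            @ray.remote(num_gpus=1)
            def square(x):
                return x * x

            # The decorated function runs as the Ray driver on global rank 0.
            return ray.get([square.remote(i) for i in range(8)])

        train()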
"""
_check_ray_installed()
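    # Ray system-metrics logging (scraped into the active MLflow run) is enabled by
    # default for remote launches; this internal env var can turn it off.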
enable_ray_metrics_logging = (
remote
and os.environ.get("DATABRICKS_INTERNAL_ENABLE_SKYRUN_RAY_SYSTEM_METRICS_LOGGING", "true").lower() == "true"
)
def decorator(func):
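        # A per-launch "done" marker file lets worker nodes detect when the rank-0
        # driver has finished running the user function (polled in `wrapper` below).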
        launch_id = str(uuid.uuid4())
        done_file = os.path.join(get_default_pkl_dir(remote), "ray", func.__name__, launch_id, "done")
os.makedirs(os.path.dirname(done_file), exist_ok=True)
@distributed(gpus=gpus, gpu_type=gpu_type, remote=remote, run_async=run_async) # always launch on all the gpus
def wrapper(*args, **kwargs):
from serverless_gpu.utils.ray_metrics_monitor import RayMetricsMonitor
# Initialize the ray cluster
head_ip_address = _initialize_ray_cluster()
if not head_ip_address:
raise RuntimeError("Failed to start a ray cluster. Please contact our support team.")
result = None
if rt.get_local_rank() == 0 and enable_ray_metrics_logging:
                # Every Skyrun node needs its own metrics monitor to scrape Ray metrics
                # from the local Ray node, so a `RayMetricsMonitor` is started on every
                # "local-rank-zero" process.
metrics_monitor = RayMetricsMonitor(
os.environ["MLFLOW_RUN_ID"], is_head_node=(rt.get_global_rank() == 0)
)
metrics_monitor.start()
else:
metrics_monitor = None
if rt.get_global_rank() == 0:
try:
                        # Only run the Ray driver program on global rank zero; it submits
                        # Ray tasks to the Ray cluster.
result = func(*args, **kwargs)
with open(done_file, "wb") as f:
f.write(b"done")
finally:
try:
import ray
ray.shutdown()
except Exception as e:
log.warning(f"Failed to shutdown ray: {e}")
if rt.get_local_rank() == 0 and rt.get_global_rank() != 0:
is_done = False
while not is_done:
try:
# We need to use the Workspace API because file system operations are cached.
# See https://databricks.slack.com/archives/C05U0QLV95Y/p1737583951631169
with WorkspaceClient().workspace.download(done_file) as f: # Proper file handling
is_done = f.read() == b"done"
except ResourceDoesNotExist:
time.sleep(5)
            # Terminate the Ray metrics monitor after the user's workload completes.
if metrics_monitor:
metrics_monitor.finish()
return result
return wrapper
return decorator