Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions python/raydp/spark/ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,74 @@

import glob
import os
import re
import subprocess
import sys
import platform
import pyspark
from typing import Any, Dict
from typing import Any, Dict, Optional

import ray
from pyspark.sql.session import SparkSession

from raydp.services import Cluster
from .ray_cluster_master import RAYDP_SPARK_MASTER_SUFFIX, SPARK_RAY_LOG4J_FACTORY_CLASS_KEY
from .ray_cluster_master import SPARK_LOG4J_CONFIG_FILE_NAME, RAY_LOG4J_CONFIG_FILE_NAME
from .ray_cluster_master import RayDPSparkMaster, SPARK_JAVAAGENT, SPARK_PREFER_CLASSPATH
from .ray_cluster_master import (RAYDP_SPARK_MASTER_SUFFIX, SPARK_RAY_LOG4J_FACTORY_CLASS_KEY,
SPARK_LOG4J_CONFIG_FILE_NAME, RAY_LOG4J_CONFIG_FILE_NAME,
RayDPSparkMaster, SPARK_JAVAAGENT, SPARK_PREFER_CLASSPATH,
RAYDP_APPMASTER_EXTRA_JAVA_OPTIONS)
from raydp import versions

_JDK17_ADD_OPENS = " ".join([
"-XX:+IgnoreUnrecognizedVMOptions",
"--add-opens=java.base/java.lang=ALL-UNNAMED",
"--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
"--add-opens=java.base/java.io=ALL-UNNAMED",
"--add-opens=java.base/java.net=ALL-UNNAMED",
"--add-opens=java.base/java.nio=ALL-UNNAMED",
"--add-opens=java.base/java.math=ALL-UNNAMED",
"--add-opens=java.base/java.text=ALL-UNNAMED",
"--add-opens=java.base/java.util=ALL-UNNAMED",
"--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
"--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
"--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
"--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
"--add-opens=java.base/sun.security.action=ALL-UNNAMED",
"--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
])

_cached_java_version: Optional[int] = None


def _get_java_major_version() -> Optional[int]:
"""Return the major version of the default ``java`` command, or None if
it cannot be determined."""
global _cached_java_version
if _cached_java_version is not None:
return _cached_java_version
try:
out = subprocess.check_output(
["java", "-version"], stderr=subprocess.STDOUT, timeout=10
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check JAVA_HOME first?

).decode("utf-8", errors="replace")
match = re.search(r'"(\d+)[\._]', out)
if match:
major = int(match.group(1))
# Java 1.x style (e.g. "1.8.0_xxx") means major = 8
if major == 1:
m2 = re.search(r'"1\.(\d+)', out)
if m2:
major = int(m2.group(1))
_cached_java_version = major
return major
except Exception:
pass
return None


def _needs_add_opens() -> bool:
    """Return True if the JVM is version 17+ and needs ``--add-opens`` flags."""
    major = _get_java_major_version()
    if major is None:
        # Unknown JVM version: do not inject flags.
        return False
    return major >= 17


class SparkCluster(Cluster):
def __init__(self,
Expand Down Expand Up @@ -176,13 +230,23 @@ def _prepare_spark_configs(self):
"If used with autoscaling, calculate it from max_workers.",
file=sys.stderr)

# On JDK 17+, inject --add-opens flags for driver and app master when
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also cover executor?

# the user has not explicitly provided them.
add_opens = _JDK17_ADD_OPENS if _needs_add_opens() else ""

if add_opens and "spark.driver.extraJavaOptions" not in self._configs:
self._configs["spark.driver.extraJavaOptions"] = add_opens

if add_opens and RAYDP_APPMASTER_EXTRA_JAVA_OPTIONS not in self._configs:
self._configs[RAYDP_APPMASTER_EXTRA_JAVA_OPTIONS] = add_opens

# set spark.driver.extraJavaOptions for driver (spark-submit)
java_opts = ["-javaagent:" + self._configs[SPARK_JAVAAGENT],
"-D" + SPARK_RAY_LOG4J_FACTORY_CLASS_KEY + "=" + versions.SPARK_LOG4J_VERSION,
"-D" + versions.SPARK_LOG4J_CONFIG_FILE_NAME_KEY + "=" +
self._configs[SPARK_LOG4J_CONFIG_FILE_NAME]
]
# Append to existing driver options if they exist (e.g., JDK 17+ flags)
# Append to existing driver options (user-provided or auto-injected above)
existing_driver_opts = self._configs.get("spark.driver.extraJavaOptions", "")
if existing_driver_opts:
all_opts = existing_driver_opts + " " + " ".join(java_opts)
Expand Down