diff --git a/python/raydp/spark/ray_cluster.py b/python/raydp/spark/ray_cluster.py
index 10816d25..54aec3b9 100644
--- a/python/raydp/spark/ray_cluster.py
+++ b/python/raydp/spark/ray_cluster.py
@@ -17,20 +17,74 @@
 import glob
 import os
+import re
+import subprocess
 import sys
 import platform
 
 import pyspark
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import ray
 from pyspark.sql.session import SparkSession
 
 from raydp.services import Cluster
-from .ray_cluster_master import RAYDP_SPARK_MASTER_SUFFIX, SPARK_RAY_LOG4J_FACTORY_CLASS_KEY
-from .ray_cluster_master import SPARK_LOG4J_CONFIG_FILE_NAME, RAY_LOG4J_CONFIG_FILE_NAME
-from .ray_cluster_master import RayDPSparkMaster, SPARK_JAVAAGENT, SPARK_PREFER_CLASSPATH
+from .ray_cluster_master import (RAYDP_SPARK_MASTER_SUFFIX, SPARK_RAY_LOG4J_FACTORY_CLASS_KEY,
+                                 SPARK_LOG4J_CONFIG_FILE_NAME, RAY_LOG4J_CONFIG_FILE_NAME,
+                                 RayDPSparkMaster, SPARK_JAVAAGENT, SPARK_PREFER_CLASSPATH,
+                                 RAYDP_APPMASTER_EXTRA_JAVA_OPTIONS)
 from raydp import versions
 
+_JDK17_ADD_OPENS = " ".join([
+    "-XX:+IgnoreUnrecognizedVMOptions",
+    "--add-opens=java.base/java.lang=ALL-UNNAMED",
+    "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
+    "--add-opens=java.base/java.io=ALL-UNNAMED",
+    "--add-opens=java.base/java.net=ALL-UNNAMED",
+    "--add-opens=java.base/java.nio=ALL-UNNAMED",
+    "--add-opens=java.base/java.math=ALL-UNNAMED",
+    "--add-opens=java.base/java.text=ALL-UNNAMED",
+    "--add-opens=java.base/java.util=ALL-UNNAMED",
+    "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
+    "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
+    "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
+    "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
+    "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
+    "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
+])
+
+_cached_java_version: Optional[int] = None
+
+
+def _get_java_major_version() -> Optional[int]:
+    """Return the major version of the default ``java`` command, or None if
+    it cannot be determined."""
+    global _cached_java_version
+    if _cached_java_version is not None:
+        return _cached_java_version
+    try:
+        out = subprocess.check_output(
+            ["java", "-version"], stderr=subprocess.STDOUT, timeout=10
+        ).decode("utf-8", errors="replace")
+        match = re.search(r'"(\d+)[\._]', out)
+        if match:
+            major = int(match.group(1))
+            # Java 1.x style (e.g. "1.8.0_xxx") means major = 8
+            if major == 1:
+                m2 = re.search(r'"1\.(\d+)', out)
+                if m2:
+                    major = int(m2.group(1))
+            _cached_java_version = major
+            return major
+    except Exception:
+        pass
+    return None
+
+
+def _needs_add_opens() -> bool:
+    """Return True if the JVM is version 17+ and needs ``--add-opens`` flags."""
+    ver = _get_java_major_version()
+    return ver is not None and ver >= 17
+
 
 class SparkCluster(Cluster):
     def __init__(self,
@@ -176,13 +230,23 @@ def _prepare_spark_configs(self):
                   "If used with autoscaling, calculate it from max_workers.",
                   file=sys.stderr)
 
+        # On JDK 17+, inject --add-opens flags for driver and app master when
+        # the user has not explicitly provided them.
+        add_opens = _JDK17_ADD_OPENS if _needs_add_opens() else ""
+
+        if add_opens and "spark.driver.extraJavaOptions" not in self._configs:
+            self._configs["spark.driver.extraJavaOptions"] = add_opens
+
+        if add_opens and RAYDP_APPMASTER_EXTRA_JAVA_OPTIONS not in self._configs:
+            self._configs[RAYDP_APPMASTER_EXTRA_JAVA_OPTIONS] = add_opens
+
         # set spark.driver.extraJavaOptions for driver (spark-submit)
         java_opts = ["-javaagent:" + self._configs[SPARK_JAVAAGENT],
                      "-D" + SPARK_RAY_LOG4J_FACTORY_CLASS_KEY + "=" + versions.SPARK_LOG4J_VERSION,
                      "-D" + versions.SPARK_LOG4J_CONFIG_FILE_NAME_KEY + "=" +
                          self._configs[SPARK_LOG4J_CONFIG_FILE_NAME]
                     ]
-        # Append to existing driver options if they exist (e.g., JDK 17+ flags)
+        # Append to existing driver options (user-provided or auto-injected above)
         existing_driver_opts = self._configs.get("spark.driver.extraJavaOptions", "")
         if existing_driver_opts:
             all_opts = existing_driver_opts + " " + " ".join(java_opts)