diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md index bea8f9fbe4..d6084c7a7c 100644 --- a/R/CRAN_RELEASE.md +++ b/R/CRAN_RELEASE.md @@ -7,7 +7,7 @@ To release SparkR as a package to CRAN, we would use the `devtools` package. Ple First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. -Note that while `check-cran.sh` is running `R CMD check`, it is doing so with `--no-manual --no-vignettes`, which skips a few vignettes or PDF checks - therefore it will be preferred to run `R CMD check` on the source package built manually before uploading a release. +Note that while `run-tests.sh` runs `check-cran.sh` (which runs `R CMD check`), it is doing so with `--no-manual --no-vignettes`, which skips a few vignettes or PDF checks - therefore it will be preferred to run `R CMD check` on the source package built manually before uploading a release. Also note that for CRAN checks for pdf vignettes to succeed, the `qpdf` tool must be installed (to install it, e.g. `yum -q -y install qpdf`). To upload a release, we would need to update the `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script along with comments on status of all `WARNING` (should not be any) or `NOTE`. As a part of `check-cran.sh` and the release process, the vignettes is build - make sure `SPARK_HOME` is set and Spark jars are accessible. diff --git a/R/check-cran.sh b/R/check-cran.sh index c5f042848c..1288e7fc9f 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -34,8 +34,9 @@ if [ ! 
-z "$R_HOME" ] fi R_SCRIPT_PATH="$(dirname $(which R))" fi -echo "USING R_HOME = $R_HOME" +echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}" +# Install the package (this is required for code in vignettes to run when building it later) # Build the latest docs, but not vignettes, which is built with the package next $FWDIR/create-docs.sh @@ -82,4 +83,20 @@ else # This will run tests and/or build vignettes, and require SPARK_HOME SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz fi + +# Install source package to get it to generate vignettes rds files, etc. +if [ -n "$CLEAN_INSTALL" ] +then + echo "Removing lib path and installing from source package" + LIB_DIR="$FWDIR/lib" + rm -rf $LIB_DIR + mkdir -p $LIB_DIR + "$R_SCRIPT_PATH/"R CMD INSTALL SparkR_"$VERSION".tar.gz --library=$LIB_DIR + + # Zip the SparkR package so that it can be distributed to worker nodes on YARN + pushd $LIB_DIR > /dev/null + jar cfM "$LIB_DIR/sparkr.zip" SparkR + popd > /dev/null +fi + popd > /dev/null diff --git a/R/install-dev.sh b/R/install-dev.sh index ada6303a72..0f881208bc 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -46,7 +46,7 @@ if [ ! 
-z "$R_HOME" ] fi R_SCRIPT_PATH="$(dirname $(which R))" fi -echo "USING R_HOME = $R_HOME" +echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}" # Generate Rd files if devtools is installed "$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index 544d203a6d..f12f8c275a 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -1,5 +1,8 @@ ^.*\.Rproj$ ^\.Rproj\.user$ ^\.lintr$ +^cran-comments\.md$ +^NEWS\.md$ +^README\.Rmd$ ^src-native$ ^html$ diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 981ae12464..0cb3a80a6e 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package -Title: R Frontend for Apache Spark Version: 2.1.0 -Date: 2016-11-06 +Title: R Frontend for Apache Spark +Description: The SparkR package provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -10,19 +10,18 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), person("Felix", "Cheung", role = "aut", email = "felixcheung@apache.org"), person(family = "The Apache Software Foundation", role = c("aut", "cph"))) +License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html Depends: R (>= 3.0), methods Suggests: + knitr, + rmarkdown, testthat, e1071, - survival, - knitr, - rmarkdown -Description: The SparkR package provides an R frontend for Apache Spark. 
-License: Apache License (== 2.0) + survival Collate: 'schema.R' 'generics.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index daee09de88..c3ec3f4fb1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -3,7 +3,7 @@ importFrom("methods", "setGeneric", "setMethod", "setOldClass") importFrom("methods", "is", "new", "signature", "show") importFrom("stats", "gaussian", "setNames") -importFrom("utils", "download.file", "object.size", "packageVersion", "untar") +importFrom("utils", "download.file", "object.size", "packageVersion", "tail", "untar") # Disable native libraries till we figure out how to package it # See SPARKR-7839 @@ -16,6 +16,7 @@ export("sparkR.stop") export("sparkR.session.stop") export("sparkR.conf") export("sparkR.version") +export("sparkR.uiWebUrl") export("print.jobj") export("sparkR.newJObject") diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 38d83c6e5c..6f48cd6639 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -634,7 +634,7 @@ tableNames <- function(x, ...) { cacheTable.default <- function(tableName) { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "cacheTable", tableName) + invisible(callJMethod(catalog, "cacheTable", tableName)) } cacheTable <- function(x, ...) { @@ -663,7 +663,7 @@ cacheTable <- function(x, ...) { uncacheTable.default <- function(tableName) { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "uncacheTable", tableName) + invisible(callJMethod(catalog, "uncacheTable", tableName)) } uncacheTable <- function(x, ...) { @@ -686,7 +686,7 @@ uncacheTable <- function(x, ...) { clearCache.default <- function() { sparkSession <- getSparkSession() catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "clearCache") + invisible(callJMethod(catalog, "clearCache")) } clearCache <- function() { @@ -730,6 +730,7 @@ dropTempTable <- function(x, ...) 
{ #' If the view has been cached before, then it will also be uncached. #' #' @param viewName the name of the view to be dropped. +#' @return TRUE if the view is dropped successfully, FALSE otherwise. #' @rdname dropTempView #' @name dropTempView #' @export diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 438d77a388..1138caf98e 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,8 +87,8 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' -#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function -#' will write it to disk and send the file name to JVM. Also to make sure each slice is not +#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function +#' will write it to disk and send the file name to JVM. Also to make sure each slice is not #' larger than that limit, number of slices may be increased. 
#' #' @param sc SparkContext to use @@ -379,5 +379,5 @@ spark.lapply <- function(list, func) { #' @note setLogLevel since 2.0.0 setLogLevel <- function(level) { sc <- getSparkContext() - callJMethod(sc, "setLogLevel", level) + invisible(callJMethod(sc, "setLogLevel", level)) } diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 69b0a523b8..097b7ad4be 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -79,19 +79,28 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, dir.create(localDir, recursive = TRUE) } - packageLocalDir <- file.path(localDir, packageName) - if (overwrite) { message(paste0("Overwrite = TRUE: download and overwrite the tar file", "and Spark package directory if they exist.")) } + releaseUrl <- Sys.getenv("SPARKR_RELEASE_DOWNLOAD_URL") + if (releaseUrl != "") { + packageName <- basenameSansExtFromUrl(releaseUrl) + } + + packageLocalDir <- file.path(localDir, packageName) + # can use dir.exists(packageLocalDir) under R 3.2.0 or later if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) { - fmt <- "%s for Hadoop %s found, with SPARK_HOME set to %s" - msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), - packageLocalDir) - message(msg) + if (releaseUrl != "") { + message(paste(packageName, "found, setting SPARK_HOME to", packageLocalDir)) + } else { + fmt <- "%s for Hadoop %s found, setting SPARK_HOME to %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageLocalDir) + message(msg) + } Sys.setenv(SPARK_HOME = packageLocalDir) return(invisible(packageLocalDir)) } else { @@ -104,7 +113,12 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, if (tarExists && !overwrite) { message("tar file found.") } else { - robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + if (releaseUrl != "") { + message("Downloading from alternate URL:\n- ", releaseUrl) + downloadUrl(releaseUrl, 
packageLocalPath, paste0("Fetch failed from ", releaseUrl)) + } else { + robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + } } message(sprintf("Installing to %s", localDir)) @@ -182,16 +196,18 @@ getPreferredMirror <- function(version, packageName) { } directDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { - packageRemotePath <- paste0( - file.path(mirrorUrl, version, packageName), ".tgz") + packageRemotePath <- paste0(file.path(mirrorUrl, version, packageName), ".tgz") fmt <- "Downloading %s for Hadoop %s from:\n- %s" msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), packageRemotePath) message(msg) + downloadUrl(packageRemotePath, packageLocalPath, paste0("Fetch failed from ", mirrorUrl)) +} - isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath), +downloadUrl <- function(remotePath, localPath, errorMessage) { + isFail <- tryCatch(download.file(remotePath, localPath), error = function(e) { - message(sprintf("Fetch failed from %s", mirrorUrl)) + message(errorMessage) print(e) TRUE }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index eed829356f..d736bbb5e9 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -191,7 +191,7 @@ predict_internal <- function(object, newData) { #' @param regParam regularization parameter for L2 regularization. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method -#' @return \code{spark.glm} returns a fitted generalized linear model +#' @return \code{spark.glm} returns a fitted generalized linear model. #' @rdname spark.glm #' @name spark.glm #' @export @@ -277,12 +277,12 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat # Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). #' @param object a fitted generalized linear model. 
-#' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including at least the coefficients matrix (which includes coefficients, standard error -#' of coefficients, t value and p value), null/residual deviance, null/residual degrees of -#' freedom, AIC and number of iterations IRLS takes. If there are collinear columns -#' in you data, the coefficients matrix only provides coefficients. -#' +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes +#' coefficients, standard error of coefficients, t value and p value), +#' \code{null.deviance} (null deviance), residual deviance, null/residual degrees of freedom, \code{aic} (AIC) +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, +#' the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -328,7 +328,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), # Prints the summary of GeneralizedLinearRegressionModel #' @rdname spark.glm -#' @param x summary object of fitted generalized linear model returned by \code{summary} function +#' @param x summary object of fitted generalized linear model returned by \code{summary} function. #' @export #' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { @@ -361,7 +361,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction" +#' "prediction". 
#' @rdname spark.glm #' @export #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 @@ -375,7 +375,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" +#' "prediction". #' @rdname spark.naiveBayes #' @export #' @note predict(NaiveBayesModel) since 2.0.0 @@ -387,8 +387,9 @@ setMethod("predict", signature(object = "NaiveBayesModel"), # Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} #' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. -#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and -#' \code{tables}, conditional probabilities given the target label. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{apriori} (the label distribution) and +#' \code{tables} (conditional probabilities given the target label). #' @rdname spark.naiveBayes #' @export #' @note summary(NaiveBayesModel) since 2.0.0 @@ -409,9 +410,9 @@ setMethod("summary", signature(object = "NaiveBayesModel"), # Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() -#' @param newData A SparkDataFrame for testing +#' @param newData A SparkDataFrame for testing. #' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities -#' vectors named "topicDistribution" +#' vectors named "topicDistribution". #' @rdname spark.lda #' @aliases spark.posterior,LDAModel,SparkDataFrame-method #' @export @@ -425,7 +426,8 @@ setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkData #' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. #' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. 
-#' @return \code{summary} returns a list containing +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes #' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for #' the prior placed on documents distributions over topics \code{theta}} #' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or @@ -476,7 +478,7 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFr # Saves the Latent Dirichlet Allocation model to the input path. -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -495,16 +497,16 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). #' Users can print, make predictions on the produced model and save the model to the input path. #' -#' @param data SparkDataFrame for training +#' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or -#' antitonic/decreasing (FALSE) +#' antitonic/decreasing (FALSE). #' @param featureIndex The index of the feature if \code{featuresCol} is a vector column -#' (default: 0), no effect otherwise +#' (default: 0), no effect otherwise. #' @param weightCol The weight column name. #' @param ... additional arguments passed to the method. -#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model +#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model. 
#' @rdname spark.isoreg #' @aliases spark.isoreg,SparkDataFrame,formula-method #' @name spark.isoreg @@ -550,9 +552,9 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" # Predicted values based on an isotonicRegression model -#' @param object a fitted IsotonicRegressionModel -#' @param newData SparkDataFrame for testing -#' @return \code{predict} returns a SparkDataFrame containing predicted values +#' @param object a fitted IsotonicRegressionModel. +#' @param newData SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. #' @rdname spark.isoreg #' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method #' @export @@ -564,7 +566,9 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"), # Get the summary of an IsotonicRegressionModel model -#' @return \code{summary} returns the model's boundaries and prediction as lists +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes model's \code{boundaries} (boundaries in increasing order) +#' and \code{predictions} (predictions associated with the boundaries at the same index). #' @rdname spark.isoreg #' @aliases summary,IsotonicRegressionModel-method #' @export @@ -661,7 +665,11 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's features, coefficients, k, size and cluster. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{k} (number of cluster centers), +#' \code{coefficients} (model cluster centers), +#' \code{size} (number of data points in each cluster), and \code{cluster} +#' (cluster centers of the transformed data). 
#' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -681,7 +689,7 @@ setMethod("summary", signature(object = "KMeansModel"), } else { dataFrame(callJMethod(jobj, "cluster")) } - list(coefficients = coefficients, size = size, + list(k = k, coefficients = coefficients, size = size, cluster = cluster, is.loaded = is.loaded) }) @@ -703,7 +711,7 @@ setMethod("predict", signature(object = "KMeansModel"), #' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. #' Users can print, make predictions on the produced model and save the model to the input path. #' -#' @param data SparkDataFrame for training +#' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param regParam the regularization parameter. @@ -733,11 +741,8 @@ setMethod("predict", signature(object = "KMeansModel"), #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p #' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. -#' @param probabilityCol column name for predicted class conditional probabilities. #' @param ... additional arguments passed to the method. -#' @return \code{spark.logit} returns a fitted logistic regression model +#' @return \code{spark.logit} returns a fitted logistic regression model. 
#' @rdname spark.logit #' @aliases spark.logit,SparkDataFrame,formula-method #' @name spark.logit @@ -746,45 +751,35 @@ setMethod("predict", signature(object = "KMeansModel"), #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' label <- c(0.0, 0.0, 0.0, 1.0, 1.0) -#' features <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) -#' binary_data <- as.data.frame(cbind(label, features)) -#' binary_df <- createDataFrame(binary_data) -#' blr_model <- spark.logit(binary_df, label ~ features, thresholds = 1.0) -#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) -#' -#' # summary of binary logistic regression -#' blr_summary <- summary(blr_model) -#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) +#' df <- createDataFrame(iris) +#' training <- df[df$Species %in% c("versicolor", "virginica"), ] +#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' summary <- summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, training) +#' #' # save fitted model to input path #' path <- "path/to/model" -#' write.ml(blr_model, path) +#' write.ml(model, path) #' #' # can also read back the saved model and predict #' # Note that summary deos not work on loaded model #' savedModel <- read.ml(path) -#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) +#' summary(savedModel) #' #' # multinomial logistic regression #' -#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) -#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) -#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) -#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) -#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) -#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) -#' df <- createDataFrame(data) +#' df <- createDataFrame(iris) +#' model <- spark.logit(df, Species ~ ., regParam = 0.5) 
+#' summary <- summary(model) #' -#' # Note that summary of multinomial logistic regression is not implemented yet -#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) -#' predict1 <- collect(select(predict(model, df), "prediction")) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, - probabilityCol = "probability") { + thresholds = 0.5, weightCol = NULL) { formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { @@ -796,8 +791,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - as.character(weightCol), as.integer(aggregationDepth), - as.character(probabilityCol)) + as.character(weightCol)) new("LogisticRegressionModel", jobj = jobj) }) @@ -816,11 +810,9 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel -#' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as list, -#' including roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, -#' recallByThreshold, totalIterations, objectiveHistory. Note that Multinomial logistic -#' regression summary is not available now. +#' @param object an LogisticRegressionModel fitted by \code{spark.logit}. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{coefficients} (coefficients matrix of the fitted model). 
#' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -828,33 +820,21 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - - if (is.loaded) { - stop("Loaded model doesn't have training summary.") + features <- callJMethod(jobj, "rFeatures") + labels <- callJMethod(jobj, "labels") + coefficients <- callJMethod(jobj, "rCoefficients") + nCol <- length(coefficients) / length(features) + coefficients <- matrix(coefficients, ncol = nCol) + # If nCol == 1, means this is a binomial logistic regression model with pivoting. + # Otherwise, it's a multinomial logistic regression model without pivoting. + if (nCol == 1) { + colnames(coefficients) <- c("Estimate") + } else { + colnames(coefficients) <- unlist(labels) } + rownames(coefficients) <- unlist(features) - roc <- dataFrame(callJMethod(jobj, "roc")) - - areaUnderROC <- callJMethod(jobj, "areaUnderROC") - - pr <- dataFrame(callJMethod(jobj, "pr")) - - fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) - - precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) - - recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) - - totalIterations <- callJMethod(jobj, "totalIterations") - - objectiveHistory <- callJMethod(jobj, "objectiveHistory") - - list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, - fMeasureByThreshold = fMeasureByThreshold, - precisionByThreshold = precisionByThreshold, - recallByThreshold = recallByThreshold, - totalIterations = totalIterations, objectiveHistory = objectiveHistory) + list(coefficients = coefficients) }) #' Multilayer Perceptron Classification Model @@ -871,7 +851,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @param formula a symbolic description of the model to be fitted. 
Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param blockSize blockSize parameter. -#' @param layers integer vector containing the number of nodes for each layer +#' @param layers integer vector containing the number of nodes for each layer. #' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. @@ -949,10 +929,12 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel # Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} #' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns a list containing \code{numOfInputs}, \code{numOfOutputs}, -#' \code{layers}, and \code{weights}. For \code{weights}, it is a numeric vector with -#' length equal to the expected given the architecture (i.e., for 8-10-2 network, -#' 112 connection weights). +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{numOfInputs} (number of inputs), \code{numOfOutputs} +#' (number of outputs), \code{layers} (array of layer sizes including input +#' and output layers), and \code{weights} (the weights of layers). +#' For \code{weights}, it is a numeric vector with length equal to the expected +#' given the architecture (i.e., for 8-10-2 network, 112 connection weights). #' @rdname spark.mlp #' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method @@ -1017,7 +999,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "form # Saves the Bernoulli naive Bayes model to the input path. -#' @param path the directory where the model is saved +#' @param path the directory where the model is saved. #' @param overwrite overwrites or not if the output path already exists. 
Default is FALSE #' which means throw exception if the output path exists. #' @@ -1091,7 +1073,7 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode # Save fitted IsotonicRegressionModel to the input path -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1106,7 +1088,7 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char # Save fitted LogisticRegressionModel to the input path -#' @param path The directory where the model is saved +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1233,7 +1215,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' -#' @param data A SparkDataFrame for training +#' @param data A SparkDataFrame for training. #' @param features Features column name. Either libSVM-format column or character-format column is #' valid. #' @param k Number of topics. @@ -1253,7 +1235,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' parameter if libSVM-format column is used as the features column. #' @param maxVocabSize maximum vocabulary size, default 1 << 18 #' @param ... additional argument(s) passed to the method. -#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model +#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model. 
#' @rdname spark.lda #' @aliases spark.lda,SparkDataFrame-method #' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} @@ -1301,8 +1283,9 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's features, coefficients, -#' intercept and log(scale) +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{coefficients} (features, coefficients, +#' intercept and log(scale)). #' @rdname spark.survreg #' @export #' @note summary(AFTSurvivalRegressionModel) since 2.0.0 @@ -1322,7 +1305,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted values -#' on the original scale of the data (mean predicted value at scale = 1.0). +#' on the original scale of the data (mean predicted value at scale = 1.0). #' @rdname spark.survreg #' @export #' @note predict(AFTSurvivalRegressionModel) since 2.0.0 @@ -1389,7 +1372,9 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma, k, dim and posterior. +#' @return \code{summary} returns summary of the fitted model, which is a list. +#' The list includes the model's \code{lambda} (lambda), \code{mu} (mu), +#' \code{sigma} (sigma), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1453,7 +1438,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. 
#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. #' @param rank rank of the matrix factorization (> 0). -#' @param reg regularization parameter (>= 0). +#' @param regParam regularization parameter (>= 0). #' @param maxIter maximum number of iterations (>= 0). #' @param nonnegative logical value indicating whether to apply nonnegativity constraints. #' @param implicitPrefs logical value indicating whether to use implicit preference. @@ -1463,7 +1448,7 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' @param numItemBlocks number of item blocks used to parallelize computation (> 0). #' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). #' @param ... additional argument(s) passed to the method. -#' @return \code{spark.als} returns a fitted ALS model +#' @return \code{spark.als} returns a fitted ALS model. #' @rdname spark.als #' @aliases spark.als,SparkDataFrame-method #' @name spark.als @@ -1492,21 +1477,21 @@ setMethod("predict", signature(object = "GaussianMixtureModel"), #' #' # set other arguments #' modelS <- spark.als(df, "rating", "user", "item", rank = 20, -#' reg = 0.1, nonnegative = TRUE) +#' regParam = 0.1, nonnegative = TRUE) #' statsS <- summary(modelS) #' } #' @note spark.als since 2.1.0 setMethod("spark.als", signature(data = "SparkDataFrame"), function(data, ratingCol = "rating", userCol = "user", itemCol = "item", - rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE, + rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, checkpointInterval = 10, seed = 0) { if (!is.numeric(rank) || rank <= 0) { stop("rank should be a positive number.") } - if (!is.numeric(reg) || reg < 0) { - stop("reg should be a nonnegative number.") + if (!is.numeric(regParam) || regParam < 0) { + stop("regParam should be a nonnegative number.") } if (!is.numeric(maxIter) || maxIter <= 0) { 
stop("maxIter should be a positive number.") @@ -1514,7 +1499,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), - reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative, as.integer(numUserBlocks), as.integer(numItemBlocks), as.integer(checkpointInterval), as.integer(seed)) new("ALSModel", jobj = jobj) @@ -1523,9 +1508,11 @@ setMethod("spark.als", signature(data = "SparkDataFrame"), # Returns a summary of the ALS model produced by spark.als. #' @param object a fitted ALS model. -#' @return \code{summary} returns a list containing the names of the user column, -#' the item column and the rating column, the estimated user and item factors, -#' rank, regularization parameter and maximum number of iterations used in training. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{user} (the names of the user column), +#' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors} +#' (the estimated user factors), \code{itemFactors} (the estimated item factors), +#' and \code{rank} (rank of the matrix factorization model). 
#' @rdname spark.als #' @aliases summary,ALSModel-method #' @export @@ -1608,14 +1595,14 @@ setMethod("write.ml", signature(object = "ALSModel", path = "character"), #' \dontrun{ #' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) #' df <- createDataFrame(data) -#' test <- spark.ktest(df, "test", "norm", c(0, 1)) +#' test <- spark.kstest(df, "test", "norm", c(0, 1)) #' #' # get a summary of the test result #' testSummary <- summary(test) #' testSummary #' #' # print out the summary in an organized way -#' print.summary.KSTest(test) +#' print.summary.KSTest(testSummary) #' } #' @note spark.kstest since 2.1.0 setMethod("spark.kstest", signature(data = "SparkDataFrame"), @@ -1638,9 +1625,10 @@ setMethod("spark.kstest", signature(data = "SparkDataFrame"), # Get the summary of Kolmogorov-Smirnov (KS) Test. #' @param object test result object of KSTest by \code{spark.kstest}. -#' @return \code{summary} returns a list containing the p-value, test statistic computed for the -#' test, the null hypothesis with its parameters tested against -#' and degrees of freedom of the test. +#' @return \code{summary} returns summary information of KSTest object, which is a list. +#' The list includes the \code{p.value} (p-value), \code{statistic} (test statistic +#' computed for the test), \code{nullHypothesis} (the null hypothesis with its +#' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). #' @rdname spark.kstest #' @aliases summary,KSTest-method #' @export @@ -1712,8 +1700,6 @@ print.summary.KSTest <- function(x, ...) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param probabilityCol column name for predicted class conditional probabilities, only for -#' classification. #' @param ... additional arguments passed to the method. 
#' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -1748,7 +1734,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") { + maxMemoryInMB = 256, cacheNodeIds = FALSE) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -1777,7 +1763,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo impurity, as.integer(minInstancesPerNode), as.numeric(minInfoGain), as.integer(checkpointInterval), as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), as.character(probabilityCol), + as.numeric(subsamplingRate), as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) new("RandomForestClassificationModel", jobj = jobj) } @@ -1788,7 +1774,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" +#' "prediction". #' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method #' @export @@ -1809,8 +1795,8 @@ setMethod("predict", signature(object = "RandomForestClassificationModel"), # Save the Random Forest Regression or Classification model to the input path. -#' @param object A fitted Random Forest regression model or classification model -#' @param path The directory where the model is saved +#' @param object A fitted Random Forest regression model or classification model. +#' @param path The directory where the model is saved. 
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' @@ -1852,9 +1838,11 @@ summary.treeEnsemble <- function(model) { # Get the summary of a Random Forest Regression Model -#' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including formula, number of features, list of features, feature importances, number of -#' trees, and tree weights +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export @@ -2031,7 +2019,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), #' @param newData a SparkDataFrame for testing. #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" +#' "prediction". #' @rdname spark.gbt #' @aliases predict,GBTRegressionModel-method #' @export @@ -2052,8 +2040,8 @@ setMethod("predict", signature(object = "GBTClassificationModel"), # Save the Gradient Boosted Tree Regression or Classification model to the input path. -#' @param object A fitted Gradient Boosted Tree regression model or classification model -#' @param path The directory where the model is saved +#' @param object A fitted Gradient Boosted Tree regression model or classification model. +#' @param path The directory where the model is saved. #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. 
#' @aliases write.ml,GBTRegressionModel,character-method @@ -2076,9 +2064,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara # Get the summary of a Gradient Boosted Tree Regression Model -#' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including formula, number of features, list of features, feature importances, number of -#' trees, and tree weights +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{numTrees} (number of trees), +#' and \code{treeWeights} (tree weights). #' @rdname spark.gbt #' @aliases summary,GBTRegressionModel-method #' @export diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index a7152b4313..e9d42c1e0a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -322,6 +322,9 @@ sparkRHive.init <- function(jsc = NULL) { #' SparkSession or initializes a new SparkSession. #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. +#' When called in an interactive session, this checks for the Spark installation, and, if not +#' found, it will be downloaded and cached automatically. Alternatively, \code{install.spark} can +#' be called manually. #' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. @@ -407,6 +410,30 @@ sparkR.session <- function( sparkSession } +#' Get the URL of the SparkUI instance for the current active SparkSession +#' +#' Get the URL of the SparkUI instance for the current active SparkSession. +#' +#' @return the SparkUI URL, or NA if it is disabled, or not started. 
+#' @rdname sparkR.uiWebUrl +#' @name sparkR.uiWebUrl +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' url <- sparkR.uiWebUrl() +#' } +#' @note sparkR.uiWebUrl since 2.2.0 +sparkR.uiWebUrl <- function() { + sc <- sparkR.callJMethod(getSparkContext(), "sc") + u <- callJMethod(sc, "uiWebUrl") + if (callJMethod(u, "isDefined")) { + callJMethod(u, "get") + } else { + NA + } +} + #' Assigns a group ID to all the jobs started by this thread until the group ID is set to a #' different value or cleared. #' @@ -424,7 +451,7 @@ sparkR.session <- function( #' @method setJobGroup default setJobGroup.default <- function(groupId, description, interruptOnCancel) { sc <- getSparkContext() - callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) + invisible(callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)) } setJobGroup <- function(sc, groupId, description, interruptOnCancel) { @@ -454,7 +481,7 @@ setJobGroup <- function(sc, groupId, description, interruptOnCancel) { #' @method clearJobGroup default clearJobGroup.default <- function() { sc <- getSparkContext() - callJMethod(sc, "clearJobGroup") + invisible(callJMethod(sc, "clearJobGroup")) } clearJobGroup <- function(sc) { @@ -481,7 +508,7 @@ clearJobGroup <- function(sc) { #' @method cancelJobGroup default cancelJobGroup.default <- function(groupId) { sc <- getSparkContext() - callJMethod(sc, "cancelJobGroup", groupId) + invisible(callJMethod(sc, "cancelJobGroup", groupId)) } cancelJobGroup <- function(sc, groupId) { @@ -565,7 +592,7 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { message(msg) NULL } else { - if (isMasterLocal(master)) { + if (interactive() || isMasterLocal(master)) { msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome) message(msg) packageLocalDir <- install.spark() diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 098c0e3e31..1283449f35 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -841,7 +841,7 @@ captureJVMException <- 
function(e, method) { # # @param inputData a list of rows, with each row a list # @return data.frame with raw columns as lists -rbindRaws <- function(inputData){ +rbindRaws <- function(inputData) { row1 <- inputData[[1]] rawcolumns <- ("raw" == sapply(row1, class)) @@ -851,3 +851,15 @@ rbindRaws <- function(inputData){ out[!rawcolumns] <- lapply(out[!rawcolumns], unlist) out } + +# Get basename without extension from URL +basenameSansExtFromUrl <- function(url) { + # split by '/' + splits <- unlist(strsplit(url, "^.+/")) + last <- tail(splits, 1) + # this is from file_path_sans_ext + # first, remove any compression extension + filename <- sub("[.](gz|bz2|xz)$", "", last) + # then, strip extension by the last '.' + sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename) +} diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 0553e704bd..0f0d831c6f 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -64,16 +64,6 @@ test_that("spark.glm and predict", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - # binomial family - binomialTraining <- training[training$Species %in% c("versicolor", "virginica"), ] - model <- spark.glm(binomialTraining, Species ~ Sepal_Length + Sepal_Width, - family = binomial(link = "logit")) - prediction <- predict(model, binomialTraining) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") - expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica", - "versicolor", "virginica", "versicolor", "virginica", "versicolor") - expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) - # poisson family model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, family = poisson(link = identity)) @@ -138,10 +128,10 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # Test 
spark.glm works with weighted dataset - a1 <- c(0, 1, 2, 3, 4) - a2 <- c(5, 2, 1, 3, 2) - w <- c(1, 2, 3, 4, 5) - b <- c(1, 0, 1, 0, 0) + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + w <- c(1, 2, 3, 4) + b <- c(1, 0, 1, 0) data <- as.data.frame(cbind(a1, a2, w, b)) df <- createDataFrame(data) @@ -168,7 +158,7 @@ test_that("spark.glm summary", { data <- as.data.frame(cbind(a1, a2, b)) df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) - expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result + expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result # Test spark.glm works on collinear data A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) @@ -360,6 +350,8 @@ test_that("spark.kmeans", { # Test summary works on KMeans summary.model <- summary(model) cluster <- summary.model$cluster + k <- summary.model$k + expect_equal(k, 2) expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) # Test model save/load @@ -645,68 +637,141 @@ test_that("spark.isotonicRegression", { }) test_that("spark.logit", { - # test binary logistic regression - label <- c(0.0, 0.0, 0.0, 1.0, 1.0) - feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) - binary_data <- as.data.frame(cbind(label, feature)) - binary_df <- createDataFrame(binary_data) - - blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) - blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) - expect_equal(blr_predict$prediction, c("0.0", "0.0", "0.0", "0.0", "0.0")) - blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) - blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction")) - expect_equal(blr_predict1$prediction, c("1.0", "1.0", "1.0", "1.0", "1.0")) - - # test summary of binary logistic regression - blr_summary <- summary(blr_model) - blr_fmeasure <- 
collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) - expect_equal(blr_fmeasure$threshold, c(0.6565513, 0.6214563, 0.3325291, 0.2115995, 0.1778653), - tolerance = 1e-4) - expect_equal(blr_fmeasure$"F-Measure", c(0.6666667, 0.5000000, 0.8000000, 0.6666667, 0.5714286), - tolerance = 1e-4) - blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) - expect_equal(blr_precision$precision, c(1.0000000, 0.5000000, 0.6666667, 0.5000000, 0.4000000), - tolerance = 1e-4) - blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) - expect_equal(blr_recall$recall, c(0.5000000, 0.5000000, 1.0000000, 1.0000000, 1.0000000), - tolerance = 1e-4) + # R code to reproduce the result. + # nolint start + #' library(glmnet) + #' iris.x = as.matrix(iris[, 1:4]) + #' iris.y = as.factor(as.character(iris[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $setosa + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.0981324 + # Sepal.Length -0.2909860 + # Sepal.Width 0.5510907 + # Petal.Length -0.1915217 + # Petal.Width -0.4211946 + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.520061e+00 + # Sepal.Length 2.524501e-02 + # Sepal.Width -5.310313e-01 + # Petal.Length 3.656543e-02 + # Petal.Width -3.144464e-05 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -2.61819385 + # Sepal.Length 0.26574097 + # Sepal.Width -0.02005932 + # Petal.Length 0.15495629 + # Petal.Width 0.42122607 + # nolint end - # test model save and read - modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") - write.ml(blr_model, modelPath) - expect_error(write.ml(blr_model, modelPath)) - write.ml(blr_model, modelPath, overwrite = TRUE) - blr_model2 <- read.ml(modelPath) - blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) - expect_equal(blr_predict$prediction, 
blr_predict2$prediction) - expect_error(summary(blr_model2)) + # Test multinomial logistic regression against three classes + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.logit(df, Species ~ ., regParam = 0.5) + summary <- summary(model) + versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00) + virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42) + setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + setosaCoefs <- unlist(summary$coefficients[, "setosa"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) unlink(modelPath) - # test prediction label as text - training <- suppressWarnings(createDataFrame(iris)) - binomial_training <- training[training$Species %in% c("versicolor", "virginica"), ] - binomial_model <- spark.logit(binomial_training, Species ~ Sepal_Length + Sepal_Width) - prediction <- predict(binomial_model, binomial_training) + # R code to reproduce the result. 
+ # nolint start + #' library(glmnet) + #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ] + #' iris.x = as.matrix(iris2[, 1:4]) + #' iris.y = as.factor(as.character(iris2[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 3.93844796 + # Sepal.Length -0.13538675 + # Sepal.Width -0.02386443 + # Petal.Length -0.35076451 + # Petal.Width -0.77971954 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -3.93844796 + # Sepal.Length 0.13538675 + # Sepal.Width 0.02386443 + # Petal.Length 0.35076451 + # Petal.Width 0.77971954 + # + #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # (Intercept) -6.0824412 + # Sepal.Length 0.2458260 + # Sepal.Width 0.1642093 + # Petal.Length 0.4759487 + # Petal.Width 1.0383948 + # + # nolint end + + # Test multinomial logistic regression against two classes + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") + summary <- summary(model) + versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78) + virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test binomial logistic regression against two classes + model <- spark.logit(training, Species ~ ., regParam = 0.5) + summary <- summary(model) + coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) + coefs <- unlist(summary$coefficients[, "Estimate"]) + expect_true(all(abs(coefsR - coefs) < 0.1)) + + # Test prediction with string label + 
prediction <- predict(model, training) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") - expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica", - "versicolor", "virginica", "versicolor", "virginica", "versicolor") + expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor", + "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) - # test multinomial logistic regression - label <- c(0.0, 1.0, 2.0, 0.0, 0.0) - feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) - feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) - feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) - feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) - data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) df <- createDataFrame(data) - - model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) - predict1 <- collect(select(predict(model, df), "prediction")) - expect_equal(predict1$prediction, c("0.0", "0.0", "0.0", "0.0", "0.0")) - # Summary of multinomial logistic regression is not implemented yet - expect_error(summary(model)) + model <- spark.logit(df, label ~ feature) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) }) test_that("spark.gaussianMixture", { @@ -863,10 +928,10 @@ test_that("spark.posterior and spark.perplexity", { test_that("spark.als", { data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), - list(2, 1, 1.0), list(2, 2, 5.0)) + list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(data, 
c("user", "item", "score")) model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", - rank = 10, maxIter = 5, seed = 0, reg = 0.1) + rank = 10, maxIter = 5, seed = 0, regParam = 0.1) stats <- summary(model) expect_equal(stats$rank, 10) test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) @@ -921,6 +986,12 @@ test_that("spark.kstest", { expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + # Test print.summary.KSTest + printStats <- capture.output(print.summary.KSTest(stats)) + expect_match(printStats[1], "Kolmogorov-Smirnov test summary:") + expect_match(printStats[5], + "Low presumption against null hypothesis: Sample follows theoretical distribution. ") }) test_that("spark.randomForest", { @@ -944,10 +1015,11 @@ test_that("spark.randomForest", { model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 20, seed = 123) predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, - 63.736, 64.296, 64.868, 64.300, - 66.709, 67.697, 67.966, 67.252, - 68.866, 69.593, 69.195, 69.658), + expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, + 63.53160, 64.05470, 65.12710, 64.30450, + 66.70910, 67.86125, 68.08700, 67.21865, + 68.89275, 69.53180, 69.39640, 69.68250), + tolerance = 1e-4) stats <- summary(model) expect_equal(stats$numTrees, 20) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index c669c2e2e2..4490f31cd8 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -205,6 +205,7 @@ test_that("create DataFrame from RDD", { c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) + sql("DROP 
TABLE people") unsetHiveContext() }) @@ -576,7 +577,7 @@ test_that("test tableNames and tables", { tables <- tables() expect_equal(count(tables), 2) suppressWarnings(dropTempTable("table1")) - dropTempView("table2") + expect_true(dropTempView("table2")) tables <- tables() expect_equal(count(tables), 0) @@ -589,7 +590,7 @@ test_that( newdf <- sql("SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(df, "dfView") sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1")) @@ -600,7 +601,7 @@ test_that( expect_equal(ncol(sqlCast), 1) expect_equal(out[1], " x") expect_equal(out[2], "1 2") - dropTempView("dfView") + expect_true(dropTempView("dfView")) }) test_that("test cache, uncache and clearCache", { @@ -609,7 +610,7 @@ test_that("test cache, uncache and clearCache", { cacheTable("table1") uncacheTable("table1") clearCache() - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("insertInto() on a registered table", { @@ -630,13 +631,13 @@ test_that("insertInto() on a registered table", { insertInto(dfParquet2, "table1") expect_equal(count(sql("select * from table1")), 5) expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) expect_equal(count(sql("select * from table1")), 2) expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - dropTempView("table1") + expect_true(dropTempView("table1")) unlink(jsonPath2) unlink(parquetPath2) @@ -650,7 +651,7 @@ test_that("tableToDF() returns a new DataFrame", { expect_equal(count(tabledf), 3) tabledf2 <- tableToDF("table1") expect_equal(count(tabledf2), 3) - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("toRDD() returns 
an RRDD", { @@ -2612,7 +2613,7 @@ test_that("randomSplit", { expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) }) -test_that("Setting and getting config on SparkSession", { +test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiWebUrl()", { # first, set it to a random but known value conf <- callJMethod(sparkSession, "conf") property <- paste0("spark.testing.", as.character(runif(1))) @@ -2636,6 +2637,9 @@ test_that("Setting and getting config on SparkSession", { expect_equal(appNameValue, "sparkSession test") expect_equal(testValue, value) expect_error(sparkR.conf("completely.dummy"), "Config 'completely.dummy' is not set") + + url <- sparkR.uiWebUrl() + expect_equal(substr(url, 1, 7), "http://") }) test_that("enableHiveSupport on SparkSession", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 607c407f04..c875248428 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -228,4 +228,15 @@ test_that("varargsToStrEnv", { expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) +test_that("basenameSansExtFromUrl", { + x <- paste0("http://people.apache.org/~pwendell/spark-nightly/spark-branch-2.1-bin/spark-2.1.1-", + "SNAPSHOT-2016_12_09_11_08-eb2d9bf-bin/spark-2.1.1-SNAPSHOT-bin-hadoop2.7.tgz") + y <- paste0("http://people.apache.org/~pwendell/spark-releases/spark-2.1.0-rc2-bin/spark-2.1.0-", + "bin-hadoop2.4-without-hive.tgz") + expect_equal(basenameSansExtFromUrl(x), "spark-2.1.1-SNAPSHOT-bin-hadoop2.7") + expect_equal(basenameSansExtFromUrl(y), "spark-2.1.0-bin-hadoop2.4-without-hive") + z <- "http://people.apache.org/~pwendell/spark-releases/spark-2.1.0--hive.tar.gz" + expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive") +}) + sparkR.session.stop() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 
73a5e26a3b..6f11c5c516 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -94,13 +94,13 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() ``` -If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is. +If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the existing Spark installation is. ```{r, eval=FALSE} sparkR.session(sparkHome = "/HOME/spark") @@ -447,25 +447,43 @@ head(teenagers) SparkR supports the following machine learning models and algorithms. 
-* Generalized Linear Model (GLM) +#### Classification -* Naive Bayes Model +* Logistic Regression -* $k$-means Clustering +* Multilayer Perceptron (MLP) + +* Naive Bayes + +#### Regression * Accelerated Failure Time (AFT) Survival Model +* Generalized Linear Model (GLM) + +* Isotonic Regression + +#### Tree - Classification and Regression + +* Gradient-Boosted Trees (GBT) + +* Random Forest + +#### Clustering + * Gaussian Mixture Model (GMM) +* $k$-means Clustering + * Latent Dirichlet Allocation (LDA) -* Multilayer Perceptron Model +#### Collaborative Filtering -* Collaborative Filtering with Alternating Least Squares (ALS) +* Alternating Least Squares (ALS) -* Isotonic Regression Model +#### Statistics -More will be added in the future. +* Kolmogorov-Smirnov Test ### R Formula @@ -490,9 +508,115 @@ count(carsDF_test) head(carsDF_test) ``` - ### Models and Algorithms +#### Logistic Regression + +[Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) is a widely-used model when the response is categorical. It can be seen as a special case of the [Generalized Linear Predictive Model](https://en.wikipedia.org/wiki/Generalized_linear_model). +We provide `spark.logit` on top of `spark.glm` to support logistic regression with advanced hyper-parameters. +It supports both binary and multiclass classification with elastic-net regularization and feature standardization, similar to `glmnet`. + +We use a simple example to demonstrate `spark.logit` usage. In general, there are three steps of using `spark.logit`: +1). Create a dataframe from a proper data source; 2). Fit a logistic regression model using `spark.logit` with a proper parameter setting; +and 3). Obtain the coefficient matrix of the fitted model using `summary` and use the model for prediction with `predict`. 
+ +Binomial logistic regression +```{r, warning=FALSE} +df <- createDataFrame(iris) +# Create a DataFrame containing two classes +training <- df[df$Species %in% c("versicolor", "virginica"), ] +model <- spark.logit(training, Species ~ ., regParam = 0.00042) +summary(model) +``` + +Predict values on training data +```{r} +fitted <- predict(model, training) +``` + +Multinomial logistic regression against three classes +```{r, warning=FALSE} +df <- createDataFrame(iris) +# Note in this case, Spark infers it is multinomial logistic regression, so family = "multinomial" is optional. +model <- spark.logit(df, Species ~ ., regParam = 0.056) +summary(model) +``` + +#### Multilayer Perceptron + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. This can be written in matrix form for MLPC with $K+1$ layers as follows: +$$ +y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). +$$ + +Nodes in intermediate layers use sigmoid (logistic) function: +$$ +f(z_i) = \frac{1}{1+e^{-z_i}}. +$$ + +Nodes in the output layer use softmax function: +$$ +f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. +$$ + +The number of nodes $N$ in the output layer corresponds to the number of classes. + +MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. + +`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. 
+ +We use the iris data set to show how to use `spark.mlp` in classification. +```{r, warning=FALSE} +df <- createDataFrame(iris) +# fit a Multilayer Perceptron Classification Model +model <- spark.mlp(df, Species ~ ., blockSize = 128, layers = c(4, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) +``` + +To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. +```{r, include=FALSE} +ops <- options() +options(max.print=5) +``` +```{r} +# check the summary of the fitted model +summary(model) +``` +```{r, include=FALSE} +options(ops) +``` +```{r} +# make predictions using the fitted model +predictions <- predict(model, df) +head(select(predictions, predictions$prediction)) +``` + +#### Naive Bayes + +Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. + +```{r} +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) +summary(naiveBayesModel) +naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) +head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +``` + +#### Accelerated Failure Time Survival Model + +Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. 
+ +Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. +```{r, warning=FALSE} +library(survival) +ovarianDF <- createDataFrame(ovarian) +aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) +summary(aftModel) +aftPredictions <- predict(aftModel, ovarianDF) +head(aftPredictions) +``` + #### Generalized Linear Model The main function is `spark.glm`. The following families and link functions are supported. The default is gaussian. @@ -526,46 +650,78 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` -#### Naive Bayes Model +#### Isotonic Regression -Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. +`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate regression problem under a complete order constraint. 
Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize +$$ +\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. +$$ + +There are a few more arguments that may be useful. + +* `weightCol`: a character string specifying the weight column. + +* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). + +* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0), no effect otherwise. + +We use an artificial example to show the use. ```{r} -titanic <- as.data.frame(Titanic) -titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) -naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) -summary(naiveBayesModel) -naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) -head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +y <- c(3.0, 6.0, 8.0, 5.0, 7.0) +x <- c(1.0, 2.0, 3.5, 3.0, 4.0) +w <- rep(1.0, 5) +data <- data.frame(y = y, x = x, w = w) +df <- createDataFrame(data) +isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") +isoregFitted <- predict(isoregModel, df) +head(select(isoregFitted, "x", "y", "prediction")) ``` -#### k-Means Clustering +In the prediction stage, based on the fitted monotone piecewise function, the rules are: -`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. +* If the prediction input exactly matches a training feature then associated prediction is returned. 
In case there are multiple predictions with the same feature then one of them is returned. Which one is undefined. + +* If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. In case there are multiple predictions with the same feature then the lowest or highest is returned respectively. + +* If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. + +For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, then predicted value would be a linear interpolation between the predicted values at $3.0$ and $3.5$. ```{r} -kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) -summary(kmeansModel) -kmeansPredictions <- predict(kmeansModel, carsDF) -head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) +head(predict(isoregModel, newDF)) ``` -#### AFT Survival Model -Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. +#### Gradient-Boosted Trees + +`spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. 
+ +Similar to the random forest example below, we use the `longley` dataset to train a gradient-boosted tree and make predictions: -Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. ```{r, warning=FALSE} -library(survival) -ovarianDF <- createDataFrame(ovarian) -aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) -summary(aftModel) -aftPredictions <- predict(aftModel, ovarianDF) -head(aftPredictions) +df <- createDataFrame(longley) +gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2) +summary(gbtModel) +predictions <- predict(gbtModel, df) ``` -#### Gaussian Mixture Model +#### Random Forest -(Coming in 2.1.0) +`spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. 
+ +In the following example, we use the `longley` dataset to train a random forest and make predictions: + +```{r, warning=FALSE} +df <- createDataFrame(longley) +rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2) +summary(rfModel) +predictions <- predict(rfModel, df) +``` + +#### Gaussian Mixture Model `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. @@ -581,10 +737,18 @@ gmmFitted <- predict(gmmModel, df) head(select(gmmFitted, "V1", "V2", "prediction")) ``` +#### k-Means Clustering -#### Latent Dirichlet Allocation +`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. -(Coming in 2.1.0) +```{r} +kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) +summary(kmeansModel) +kmeansPredictions <- predict(kmeansModel, carsDF) +head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +``` + +#### Latent Dirichlet Allocation `spark.lda` fits a [Latent Dirichlet Allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) model on a `SparkDataFrame`. It is often used in topic modeling in which topics are inferred from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: @@ -600,22 +764,6 @@ To use LDA, we need to specify a `features` column in `data` where each entry re * libSVM: Each entry is a collection of words and will be processed directly. -There are several parameters LDA takes for fitting the model. 
- -* `k`: number of topics (default 10). - -* `maxIter`: maximum iterations (default 20). - -* `optimizer`: optimizer to train an LDA model, "online" (default) uses [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf). "em" uses [expectation-maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm). - -* `subsamplingRate`: For `optimizer = "online"`. Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1] (default 0.05). - -* `topicConcentration`: concentration parameter (commonly named beta or eta) for the prior placed on topic distributions over terms, default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective topicConcentration. Only 1-size numeric is accepted. - -* `docConcentration`: concentration parameter (commonly named alpha) for the prior placed on documents distributions over topics (theta), default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective docConcentration. Only 1-size or k-size numeric is accepted. - -* `maxVocabSize`: maximum vocabulary size, default 1 << 18. - Two more functions are provided for the fitted model. * `spark.posterior` returns a `SparkDataFrame` containing a column of posterior probabilities vectors named "topicDistribution". @@ -654,47 +802,7 @@ perplexity <- spark.perplexity(model, corpusDF) perplexity ``` - -#### Multilayer Perceptron - -(Coming in 2.1.0) - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. 
This can be written in matrix form for MLPC with $K+1$ layers as follows: -$$ -y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). -$$ - -Nodes in intermediate layers use sigmoid (logistic) function: -$$ -f(z_i) = \frac{1}{1+e^{-z_i}}. -$$ - -Nodes in the output layer use softmax function: -$$ -f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. -$$ - -The number of nodes $N$ in the output layer corresponds to the number of classes. - -MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. - -`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. According to the description above, there are several additional parameters that can be set: - -* `layers`: integer vector containing the number of nodes for each layer. - -* `solver`: solver parameter, supported options: `"gd"` (minibatch gradient descent) or `"l-bfgs"`. - -* `maxIter`: maximum iteration number. - -* `tol`: convergence tolerance of iterations. - -* `stepSize`: step size for `"gd"`. - -* `seed`: seed parameter for weights initialization. - -#### Collaborative Filtering - -(Coming in 2.1.0) +#### Alternating Least Squares `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). @@ -723,53 +831,28 @@ predicted <- predict(model, df) head(predicted) ``` -#### Isotonic Regression Model - -(Coming in 2.1.0) - -`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. 
Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize -$$ -\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. -$$ - -There are a few more arguments that may be useful. - -* `weightCol`: a character string specifying the weight column. - -* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). - -* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0), no effect otherwise. - -We use an artificial example to show the use. - -```{r} -y <- c(3.0, 6.0, 8.0, 5.0, 7.0) -x <- c(1.0, 2.0, 3.5, 3.0, 4.0) -w <- rep(1.0, 5) -data <- data.frame(y = y, x = x, w = w) -df <- createDataFrame(data) -isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") -isoregFitted <- predict(isoregModel, df) -head(select(isoregFitted, "x", "y", "prediction")) -``` - -In the prediction stage, based on the fitted monotone piecewise function, the rules are: +#### Kolmogorov-Smirnov Test -* If the prediction input exactly matches a training feature then associated prediction is returned. In case there are multiple predictions with the same feature then one of them is returned. Which one is undefined. +`spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). +Given a `SparkDataFrame`, the test compares continuous data in a given column `testCol` with the theoretical distribution +specified by parameter `nullHypothesis`. +Users can call `summary` to get a summary of the test results. -* If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. 
In case there are multiple predictions with the same feature then the lowest or highest is returned respectively. +In the following example, we test whether the `longley` dataset's `Armed_Forces` column +follows a normal distribution. We set the parameters of the normal distribution using +the mean and standard deviation of the sample. -* If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. - -For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, then predicted value would be a linear interpolation between the predicted values at $3.0$ and $3.5$. +```{r, warning=FALSE} +df <- createDataFrame(longley) +afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces))) +afMean <- afStats[1] +afStd <- afStats[2] -```{r} -newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) -head(predict(isoregModel, newDF)) +test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd)) +testSummary <- summary(test) +testSummary ``` -#### What's More? -We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in the next version 2.1.0. ### Model Persistence The following example shows how to save/load an ML model by SparkR. diff --git a/README.md b/README.md index 853f7f5ded..d0eca1ddea 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,7 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html) -and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). +guide, on the [project web page](http://spark.apache.org/documentation.html). This README file only contains basic setup instructions. 
## Building Spark @@ -30,8 +29,7 @@ You can build Spark using more than one thread by using the -T option with Maven More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -For general development tips, including info on developing Spark using an IDE, see -[http://spark.apache.org/developer-tools.html](the Useful Developer Tools page). +For general development tips, including info on developing Spark using an IDE, see ["Useful Developer Tools"](http://spark.apache.org/developer-tools.html). ## Interactive Scala Shell diff --git a/assembly/pom.xml b/assembly/pom.xml index ec243eaeba..53f18796e6 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml diff --git a/build/mvn b/build/mvn index c3ab62da36..866bad892c 100755 --- a/build/mvn +++ b/build/mvn @@ -91,13 +91,13 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.9/bin/zinc" + local zinc_path="zinc-0.3.11/bin/zinc" [ ! 
-f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} install_app \ - "${TYPESAFE_MIRROR}/zinc/0.3.9" \ - "zinc-0.3.9.tgz" \ + "${TYPESAFE_MIRROR}/zinc/0.3.11" \ + "zinc-0.3.11.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index ca99fa89eb..8657af744c 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -91,6 +91,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.mockito mockito-core diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java index 78034a69f7..340986a63b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java @@ -60,7 +60,7 @@ public class AesCipher { private final Properties properties; public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { - this.properties = CryptoStreamUtils.toCryptoConf(conf); + this.properties = conf.cryptoConf(); this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); this.inIvSpec = new IvParameterSpec(configMessage.inIv); this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); @@ -105,7 +105,7 @@ public void addToChannel(Channel ch) throws IOException { */ public static AesConfigMessage createConfigMessage(TransportConf conf) { int keySize = conf.aesCipherKeySize(); - Properties properties = CryptoStreamUtils.toCryptoConf(conf); + Properties properties = conf.cryptoConf(); try { int paramLen = 
CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) @@ -128,19 +128,6 @@ public static AesConfigMessage createConfigMessage(TransportConf conf) { } } - /** - * CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config. - */ - private static class CryptoStreamUtils { - public static Properties toCryptoConf(TransportConf conf) { - Properties props = new Properties(); - if (conf.aesCipherClass() != null) { - props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass()); - } - return props; - } - } - private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java index d944d9da1c..f6aef499b2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -17,6 +17,7 @@ package org.apache.spark.network.util; +import java.util.Map; import java.util.NoSuchElementException; /** @@ -26,6 +27,9 @@ public abstract class ConfigProvider { /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ public abstract String get(String name); + /** Returns all the config values in the provider. 
*/ + public abstract Iterable> getAll(); + public String get(String name, String defaultValue) { try { return get(name); @@ -49,4 +53,5 @@ public double getDouble(String name, double defaultValue) { public boolean getBoolean(String name, boolean defaultValue) { return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/CryptoUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/CryptoUtils.java new file mode 100644 index 0000000000..a6d8358ee9 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/CryptoUtils.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.Map; +import java.util.Properties; + +/** + * Utility methods related to the commons-crypto library. + */ +public class CryptoUtils { + + // The prefix for the configurations passing to Apache Commons Crypto library. + public static final String COMMONS_CRYPTO_CONFIG_PREFIX = "commons.crypto."; + + /** + * Extract the commons-crypto configuration embedded in a list of config values. 
+ * + * @param prefix Prefix in the given configuration that identifies the commons-crypto configs. + * @param conf List of configuration values. + */ + public static Properties toCryptoConf(String prefix, Iterable> conf) { + Properties props = new Properties(); + for (Map.Entry e : conf) { + String key = e.getKey(); + if (key.startsWith(prefix)) { + props.setProperty(COMMONS_CRYPTO_CONFIG_PREFIX + key.substring(prefix.length()), + e.getValue()); + } + } + return props; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java index 668d2356b9..b6667998b5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -19,11 +19,16 @@ import com.google.common.collect.Maps; +import java.util.Collections; import java.util.Map; import java.util.NoSuchElementException; /** ConfigProvider based on a Map (copied in the constructor). 
*/ public class MapConfigProvider extends ConfigProvider { + + public static final MapConfigProvider EMPTY = new MapConfigProvider( + Collections.emptyMap()); + private final Map config; public MapConfigProvider(Map config) { @@ -38,4 +43,10 @@ public String get(String name) { } return value; } + + @Override + public Iterable> getAll() { + return config.entrySet(); + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 012bb098f6..223d6d88de 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import java.util.Properties; + import com.google.common.primitives.Ints; /** @@ -24,11 +26,6 @@ */ public class TransportConf { - static { - // Set this due to Netty PR #5661 for Netty 4.0.37+ to work - System.setProperty("io.netty.maxDirectMemory", "0"); - } - private final String SPARK_NETWORK_IO_MODE_KEY; private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; @@ -179,21 +176,22 @@ public boolean saslServerAlwaysEncrypt() { * The trigger for enabling AES encryption. */ public boolean aesEncryptionEnabled() { - return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false); + return conf.getBoolean("spark.network.aes.enabled", false); } /** - * The implementation class for crypto cipher + * The key size to use when AES cipher is enabled. Notice that the length should be 16, 24 or 32 + * bytes. 
*/ - public String aesCipherClass() { - return conf.get("spark.authenticate.encryption.aes.cipher.class", null); + public int aesCipherKeySize() { + return conf.getInt("spark.network.aes.keySize", 16); } /** - * The bytes of AES cipher key which is effective when AES cipher is enabled. Notice that - * the length should be 16, 24 or 32 bytes. + * The commons-crypto configuration for the module. */ - public int aesCipherKeySize() { - return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16); + public Properties cryptoConf() { + return CryptoUtils.toCryptoConf("spark.network.aes.config.", conf.getAll()); } + } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 6d62eaf35d..5bb8819132 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -48,7 +48,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { @@ -87,7 +87,7 @@ public static void setUp() throws Exception { Closeables.close(fp, shouldSuppressIOException); } - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java 
b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index a7a99f3bfc..8ff737b129 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -42,7 +42,7 @@ import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { @@ -53,7 +53,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); rpcHandler = new RpcHandler() { @Override public void receive( diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index 9c49556927..f253a07e64 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -47,7 +47,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class StreamSuite { @@ -91,7 +91,7 @@ public static void setUp() throws Exception { fp.close(); } - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + final TransportConf conf = 
new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 44d16d5422..f54a64cb0f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -40,9 +40,8 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -53,7 +52,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -199,6 +198,11 @@ public String get(String name) { } return value; } + + @Override + public Iterable> getAll() { + throw new UnsupportedOperationException(); + } }); TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); TransportClientFactory factory = context.createClientFactory(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 
ef2ab34b22..e27301f49e 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -24,7 +24,9 @@ import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeoutException; @@ -32,6 +34,7 @@ import java.util.concurrent.atomic.AtomicReference; import javax.security.sasl.SaslException; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.io.ByteStreams; import com.google.common.io.Files; @@ -60,7 +63,7 @@ import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; /** @@ -225,7 +228,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -253,14 +256,14 @@ public void testEncryptedMessageChunking() throws Exception { @Test public void testFileRegionEncryption() throws Exception { - final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; - System.setProperty(blockSizeConf, "1k"); + final Map testConf = ImmutableMap.of( + "spark.network.sasl.maxEncryptedBlockSize", "1k"); final AtomicReference response = new 
AtomicReference<>(); final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); StreamManager sm = mock(StreamManager.class); when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { @Override @@ -276,7 +279,7 @@ public ManagedBuffer answer(InvocationOnMock invocation) { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false, false); + ctx = new SaslTestCtx(rpcHandler, true, false, false, testConf); final CountDownLatch lock = new CountDownLatch(1); @@ -307,18 +310,15 @@ public Void answer(InvocationOnMock invocation) { if (response.get() != null) { response.get().release(); } - System.clearProperty(blockSizeConf); } } @Test public void testServerAlwaysEncrypt() throws Exception { - final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt"; - System.setProperty(alwaysEncryptConfName, "true"); - SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false, + ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true")); fail("Should have failed to connect without encryption."); } catch (Exception e) { assertTrue(e.getCause() instanceof SaslException); @@ -326,7 +326,6 @@ public void testServerAlwaysEncrypt() throws Exception { if (ctx != null) { ctx.close(); } - System.clearProperty(alwaysEncryptConfName); } } @@ -381,7 +380,7 @@ public void testAesEncryption() throws Exception { final File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY); final TransportConf spyConf = 
spy(conf); doReturn(true).when(spyConf).aesEncryptionEnabled(); @@ -454,7 +453,19 @@ private static class SaslTestCtx { boolean aesEnable) throws Exception { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + this(rpcHandler, encrypt, disableClientEncryption, aesEnable, + Collections.emptyMap()); + } + + SaslTestCtx( + RpcHandler rpcHandler, + boolean encrypt, + boolean disableClientEncryption, + boolean aesEnable, + Map testConf) + throws Exception { + + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); if (aesEnable) { conf = spy(conf); diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/CryptoUtilsSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/CryptoUtilsSuite.java new file mode 100644 index 0000000000..2b45d1e397 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/util/CryptoUtilsSuite.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.util.Map; +import java.util.Properties; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; +import static org.junit.Assert.*; + +public class CryptoUtilsSuite { + + @Test + public void testConfConversion() { + String prefix = "my.prefix.commons.config."; + + String confKey1 = prefix + "a.b.c"; + String confVal1 = "val1"; + String cryptoKey1 = CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX + "a.b.c"; + + String confKey2 = prefix.substring(0, prefix.length() - 1) + "A.b.c"; + String confVal2 = "val2"; + String cryptoKey2 = CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX + "A.b.c"; + + Map conf = ImmutableMap.of( + confKey1, confVal1, + confKey2, confVal2); + + Properties cryptoConf = CryptoUtils.toCryptoConf(prefix, conf.entrySet()); + + assertEquals(confVal1, cryptoConf.getProperty(cryptoKey1)); + assertFalse(cryptoConf.containsKey(cryptoKey2)); + } + +} diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 511e1f29de..24c10fb1dd 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -70,6 +70,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + log4j log4j diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 6ba937dddb..298a487ebb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -55,7 +55,7 @@ import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; import 
org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { @@ -73,7 +73,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 35d6346474..bc97594903 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -25,7 +25,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.junit.AfterClass; @@ -42,7 +42,7 @@ public class ExternalShuffleBlockResolverSuite { private static TestShuffleDataContext dataContext; private static final TransportConf conf = - new TransportConf("shuffle", new SystemPropertyConfigProvider()); + new TransportConf("shuffle", MapConfigProvider.EMPTY); @BeforeClass public static void beforeAll() throws IOException { diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index bdd218db69..7757500b41 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -29,14 +29,14 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - private TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 552b5366c5..8dd97b29eb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -28,6 +28,7 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.junit.After; @@ -43,7 +44,7 @@ import 
org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ExternalShuffleIntegrationSuite { @@ -84,7 +85,7 @@ public static void beforeAll() throws IOException { dataContext0.create(); dataContext0.insertSortShuffleData(0, 0, exec0Blocks); - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); @@ -115,12 +116,16 @@ public void releaseBuffers() { // Fetch a set of blocks from a pre-registered executor. private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { - return fetchBlocks(execId, blockIds, server.getPort()); + return fetchBlocks(execId, blockIds, conf, server.getPort()); } // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, // to allow connecting to invalid servers. 
- private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + private FetchResult fetchBlocks( + String execId, + String[] blockIds, + TransportConf clientConf, + int port) throws Exception { final FetchResult res = new FetchResult(); res.successBlocks = Collections.synchronizedSet(new HashSet()); res.failedBlocks = Collections.synchronizedSet(new HashSet()); @@ -128,7 +133,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); + ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -227,16 +232,13 @@ public void testFetchUnregisteredExecutor() throws Exception { @Test public void testFetchNoServer() throws Exception { - System.setProperty("spark.shuffle.io.maxRetries", "0"); - try { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); - } finally { - System.clearProperty("spark.shuffle.io.maxRetries"); - } + TransportConf clientConf = new TransportConf("shuffle", + new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0"))); + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); } private void registerExecutor(String executorId, 
ExecutorShuffleInfo executorInfo) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index a0f69ca29a..aed25a161e 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -33,12 +33,12 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); TransportServer server; @Before diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 91882e3b3b..a2509f5f34 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -27,8 +27,6 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -39,7 +37,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import 
org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; @@ -53,18 +51,6 @@ public class RetryingBlockFetcherSuite { ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); - @Before - public void beforeEach() { - System.setProperty("spark.shuffle.io.maxRetries", "2"); - System.setProperty("spark.shuffle.io.retryWait", "0"); - } - - @After - public void afterEach() { - System.clearProperty("spark.shuffle.io.maxRetries"); - System.clearProperty("spark.shuffle.io.retryWait"); - } - @Test public void testNoFailures() throws IOException { BlockFetchingListener listener = mock(BlockFetchingListener.class); @@ -254,7 +240,10 @@ private static void performInteractions(List> inte BlockFetchingListener listener) throws IOException { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of( + "spark.shuffle.io.maxRetries", "2", + "spark.shuffle.io.retryWait", "0")); + TransportConf conf = new TransportConf("shuffle", provider); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 606ad15739..5e5a80bd44 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -50,6 +50,17 @@ spark-tags_${scala.binary.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.hadoop diff --git 
a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java index 884861752e..62a6cca4ed 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -17,6 +17,7 @@ package org.apache.spark.network.yarn.util; +import java.util.Map; import java.util.NoSuchElementException; import org.apache.hadoop.conf.Configuration; @@ -39,4 +40,10 @@ public String get(String name) { } return value; } + + @Override + public Iterable> getAll() { + return conf; + } + } diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 626f023a5b..bcd26d4352 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 1c60d510e5..09f6fa12b9 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -34,14 +34,6 @@ tags - - - org.scalatest - scalatest_${scala.binary.version} - compile - - - target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/test/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/test/java/org/apache/spark/tags/DockerTest.java diff --git a/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java 
b/common/tags/src/test/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/test/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/test/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 45af98d94e..dc19f4ad5f 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + com.twitter chill_${scala.binary.version} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index a7b0e6f80c..fd6e95c3e0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,7 +252,7 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; - public final long milliseconds() { + public long milliseconds() { return this.microseconds / MICROS_PER_MILLI; } diff --git a/core/pom.xml b/core/pom.xml index eac99ab82a..97a463abbe 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml @@ -337,6 +337,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + 
org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.commons commons-crypto diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d2fcdea4f2..44120e591f 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -170,6 +170,8 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; + private final int initialCapacity; + private final BlockManager blockManager; private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; @@ -202,6 +204,7 @@ public BytesToBytesMap( throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } + this.initialCapacity = initialCapacity; allocate(initialCapacity); } @@ -902,12 +905,12 @@ public LongArray getArray() { public void reset() { numKeys = 0; numValues = 0; - longArray.zeroOut(); - + freeArray(longArray); while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); } + allocate(initialCapacity); currentPage = null; pageCursor = 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 252a35ec6b..5b42843717 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -22,6 +22,8 @@ import org.apache.avro.reflect.Nullable; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import 
org.apache.spark.unsafe.Platform; @@ -253,6 +255,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea private long keyPrefix; private int recordLength; private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; @@ -283,6 +286,14 @@ public boolean hasNext() { @Override public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null && taskContext.isInterrupted()) { + throw new TaskKilledException(); + } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index a658e5eb47..b6323c624b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,6 +23,8 @@ import com.google.common.io.Closeables; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskKilledException; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends 
UnsafeSorterIterator implemen private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; + private final TaskContext taskContext = TaskContext.get(); public UnsafeSorterSpillReader( SerializerManager serializerManager, @@ -94,6 +97,14 @@ public boolean hasNext() { @Override public void loadNext() throws IOException { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null && taskContext.isInterrupted()) { + throw new TaskKilledException(); + } recordLength = din.readInt(); keyPrefix = din.readLong(); if (recordLength > arr.length) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 1df67337ea..fe5db6aa26 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -411,10 +411,6 @@ $(document).ready(function () { } ], "columnDefs": [ - { - "targets": [ 15 ], - "visible": logsExist(response) - }, { "targets": [ 16 ], "visible": getThreadDumpEnabled() @@ -423,7 +419,8 @@ $(document).ready(function () { "order": [[0, "asc"]] }; - $(selector).DataTable(conf); + var dt = $(selector).DataTable(conf); + dt.column(15).visible(logsExist(response)); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 8fd91865b0..54810edaf1 
100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -78,6 +78,12 @@ jQuery.extend( jQuery.fn.dataTableExt.oSort, { } } ); +jQuery.extend( jQuery.fn.dataTableExt.ofnSearch, { + "appid-numeric": function ( a ) { + return a.replace(/[\r\n]/g, " ").replace(/<.*?>/g, ""); + } +} ); + $(document).ajaxStop($.unblockUI); $(document).ajaxStart(function(){ $.blockUI({ message: '

Loading history summary...

'}); diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7f8f0f5131..6f5c31d7ab 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -322,7 +322,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, if (minSizeForBroadcast > maxRpcMessageSize) { val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + - "message that is to large." + "message that is too large." logError(msg) throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b8414b5d09..efb5f9d501 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -382,6 +382,9 @@ class SparkContext(config: SparkConf) extends Logging { throw new SparkException("An application name must be set in your configuration") } + // log out spark.app.name in the Spark driver logs + logInfo(s"Submitted application: $appName") + // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster if (master == "yarn" && deployMode == "cluster" && !_conf.contains("spark.yarn.app.id")) { throw new SparkException("Detected yarn cluster mode, but isn't running on a cluster. " + @@ -670,10 +673,10 @@ class SparkContext(config: SparkConf) extends Logging { * sc.cancelJobGroup("some_job_to_cancel") * }}} * - * If interruptOnCancel is set to true for the job group, then job cancellation will result - * in Thread.interrupt() being called on the job's executor threads. 
This is useful to help ensure - * that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208, - * where HDFS may respond to Thread.interrupt() by marking nodes as dead. + * @param interruptOnCancel If true, then job cancellation will result in `Thread.interrupt()` + * being called on the job's executor threads. This is useful to help ensure that the tasks + * are actually stopped in a timely manner, but is off by default due to HDFS-1208, where HDFS + * may respond to Thread.interrupt() by marking nodes as dead. */ def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description) @@ -709,6 +712,9 @@ class SparkContext(config: SparkConf) extends Logging { * modified collection. Pass a copy of the argument to avoid this. * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions. + * @param seq Scala collection to distribute + * @param numSlices number of partitions to divide the collection into + * @return RDD representing distributed collection */ def parallelize[T: ClassTag]( seq: Seq[T], @@ -726,8 +732,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param start the start value. * @param end the end value. * @param step the incremental step - * @param numSlices the partition number of the new RDD. - * @return + * @param numSlices number of partitions to divide the collection into + * @return RDD representing distributed range */ def range( start: Long, @@ -792,6 +798,9 @@ class SparkContext(config: SparkConf) extends Logging { /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. 
+ * @param seq Scala collection to distribute + * @param numSlices number of partitions to divide the collection into + * @return RDD representing distributed collection */ def makeRDD[T: ClassTag]( seq: Seq[T], @@ -803,6 +812,8 @@ class SparkContext(config: SparkConf) extends Logging { * Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. + * @param seq list of tuples of data and location preferences (hostnames of Spark nodes) + * @return RDD representing data partitioned according to location preferences */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() @@ -813,6 +824,9 @@ class SparkContext(config: SparkConf) extends Logging { /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. + * @param path path to the text file on a supported file system + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of lines of the text file */ def textFile( path: String, @@ -848,10 +862,13 @@ class SparkContext(config: SparkConf) extends Logging { * @note Small files are preferred, large file is also allowable, but may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * @note Partitioning is determined by data locality. This may result in too few partitions + * by default. * * @param path Directory to the input data files, the path can be comma separated paths as the * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. 
+ * @return RDD representing tuples of file path and the corresponding file content */ def wholeTextFiles( path: String, @@ -897,10 +914,13 @@ class SparkContext(config: SparkConf) extends Logging { * @note Small files are preferred; very large files may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * @note Partitioning is determined by data locality. This may result in too few partitions + * by default. * * @param path Directory to the input data files, the path can be comma separated paths as the * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. + * @return RDD representing tuples of file path and corresponding file content */ def binaryFiles( path: String, @@ -961,10 +981,11 @@ class SparkContext(config: SparkConf) extends Logging { * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make * sure you won't modify the conf. A safe approach is always creating a new conf for * a new RDD. - * @param inputFormatClass Class of the InputFormat - * @param keyClass Class of the keys - * @param valueClass Class of the values + * @param inputFormatClass storage format of the data to be read + * @param keyClass `Class` of the key associated with the `inputFormatClass` parameter + * @param valueClass `Class` of the value associated with the `inputFormatClass` parameter * @param minPartitions Minimum number of Hadoop Splits to generate. + * @return RDD of tuples of key and corresponding value * * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle @@ -996,6 +1017,13 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param inputFormatClass storage format of the data to be read + * @param keyClass `Class` of the key associated with the `inputFormatClass` parameter + * @param valueClass `Class` of the value associated with the `inputFormatClass` parameter + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def hadoopFile[K, V]( path: String, @@ -1035,6 +1063,10 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def hadoopFile[K, V, F <: InputFormat[K, V]] (path: String, minPartitions: Int) @@ -1059,13 +1091,32 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths as + * a list of inputs + * @return RDD of tuples of key and corresponding value */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { hadoopFile[K, V, F](path, defaultMinPartitions) } - /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. 
*/ + /** + * Smarter version of `newApiHadoopFile` that uses class tags to figure out the classes of keys, + * values and the `org.apache.hadoop.mapreduce.InputFormat` (new MapReduce API) so that user + * don't need to pass them directly. Instead, callers can just write, for example: + * ``` + * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) + * ``` + * + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @return RDD of tuples of key and corresponding value + */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]] (path: String) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { @@ -1085,6 +1136,13 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. 
+ * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param fClass storage format of the data to be read + * @param kClass `Class` of the key associated with the `fClass` parameter + * @param vClass `Class` of the value associated with the `fClass` parameter + * @param conf Hadoop configuration + * @return RDD of tuples of key and corresponding value */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( path: String, @@ -1116,9 +1174,9 @@ class SparkContext(config: SparkConf) extends Logging { * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make * sure you won't modify the conf. A safe approach is always creating a new conf for * a new RDD. - * @param fClass Class of the InputFormat - * @param kClass Class of the keys - * @param vClass Class of the values + * @param fClass storage format of the data to be read + * @param kClass `Class` of the key associated with the `fClass` parameter + * @param vClass `Class` of the value associated with the `fClass` parameter * * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle @@ -1151,6 +1209,12 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. 
+ * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param keyClass `Class` of the key associated with `SequenceFileInputFormat` + * @param valueClass `Class` of the value associated with `SequenceFileInputFormat` + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def sequenceFile[K, V](path: String, keyClass: Class[K], @@ -1170,6 +1234,11 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param keyClass `Class` of the key associated with `SequenceFileInputFormat` + * @param valueClass `Class` of the value associated with `SequenceFileInputFormat` + * @return RDD of tuples of key and corresponding value */ def sequenceFile[K, V]( path: String, @@ -1200,6 +1269,10 @@ class SparkContext(config: SparkConf) extends Logging { * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def sequenceFile[K, V] (path: String, minPartitions: Int = defaultMinPartitions) @@ -1224,6 +1297,11 @@ class SparkContext(config: SparkConf) extends Logging { * be pretty slow if you use the default serializer (Java serialization), * though the nice thing about it is that there's very little effort required to save arbitrary * objects. 
+ * + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD representing deserialized data from the file(s) */ def objectFile[T: ClassTag]( path: String, @@ -1403,6 +1481,9 @@ class SparkContext(config: SparkConf) extends Logging { * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. + * + * @param value value to broadcast to the Spark nodes + * @return `Broadcast` object, a read-only variable cached on each machine */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { assertNotStopped() @@ -1417,8 +1498,9 @@ class SparkContext(config: SparkConf) extends Logging { /** * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * + * @param path can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ def addFile(path: String): Unit = { @@ -1432,12 +1514,12 @@ class SparkContext(config: SparkConf) extends Logging { /** * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(fileName)` to find its download location. * - * A directory can be given if the recursive option is set to true. Currently directories are only - * supported for Hadoop-supported filesystems. 
+ * @param path can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * @param recursive if true, a directory can be given in `path`. Currently directories are + * only supported for Hadoop-supported filesystems. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri @@ -1489,6 +1571,15 @@ class SparkContext(config: SparkConf) extends Logging { listenerBus.addListener(listener) } + /** + * :: DeveloperApi :: + * Deregister the listener from Spark's listener bus. + */ + @DeveloperApi + def removeSparkListener(listener: SparkListenerInterface): Unit = { + listenerBus.removeListener(listener) + } + private[spark] def getExecutorIds(): Seq[String] = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1708,9 +1799,9 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future. + * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), + * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. */ def addJar(path: String) { if (path == null) { @@ -1757,25 +1848,30 @@ class SparkContext(config: SparkConf) extends Logging { def listJars(): Seq[String] = addedJars.keySet.toSeq /** - * Shut down the SparkContext. + * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark + * may wait for some internal threads to finish. It's better to use this method to stop + * SparkContext instead. 
*/ - def stop(): Unit = { - if (env.rpcEnv.isInRPCThread) { - // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread. - // We should launch a new thread to call `stop` to avoid dead-lock. - new Thread("stop-spark-context") { - setDaemon(true) - - override def run(): Unit = { - _stop() + private[spark] def stopInNewThread(): Unit = { + new Thread("stop-spark-context") { + setDaemon(true) + + override def run(): Unit = { + try { + SparkContext.this.stop() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e } - }.start() - } else { - _stop() - } + } + }.start() } - private def _stop() { + /** + * Shut down the SparkContext. + */ + def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") @@ -1895,6 +1991,12 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a function on a given set of partitions in an RDD and pass the results to the given * handler function. This is the main entry point for all actions in Spark. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like `first()` + * @param resultHandler callback to pass each result to */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1917,6 +2019,14 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a function on a given set of partitions in an RDD and return the results as an array. + * The function that is run against each partition additionally takes `TaskContext` argument. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. 
for operations like `first()` + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1928,8 +2038,14 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Run a job on a given set of partitions of an RDD, but take a function of type - * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like `first()` + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1940,7 +2056,13 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Run a job on all partitions in an RDD and return the results in an array. + * Run a job on all partitions in an RDD and return the results in an array. The function + * that is run against each partition additionally takes `TaskContext` argument. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.length) @@ -1948,13 +2070,23 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a job on all partitions in an RDD and return the results in an array. 
+ * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.length) } /** - * Run a job on all partitions in an RDD and pass the results to a handler function. + * Run a job on all partitions in an RDD and pass the results to a handler function. The function + * that is run against each partition additionally takes `TaskContext` argument. + * + * @param rdd target RDD to run tasks on + * @param processPartition a function to run on each partition of the RDD + * @param resultHandler callback to pass each result to */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1966,6 +2098,10 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a job on all partitions in an RDD and pass the results to a handler function. + * + * @param rdd target RDD to run tasks on + * @param processPartition a function to run on each partition of the RDD + * @param resultHandler callback to pass each result to */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1979,6 +2115,13 @@ class SparkContext(config: SparkConf) extends Logging { /** * :: DeveloperApi :: * Run a job that can return approximate results. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator `ApproximateEvaluator` to receive the partial results + * @param timeout maximum time to wait for the job, in milliseconds + * @return partial result (how partial depends on whether the job was finished before or + * after timeout) */ @DeveloperApi def runApproximateJob[T, U, R]( @@ -2000,6 +2143,13 @@ class SparkContext(config: SparkConf) extends Logging { /** * Submit a job for execution and return a FutureJob holding the result. 
+ * + * @param rdd target RDD to run tasks on + * @param processPartition a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like `first()` + * @param resultHandler callback to pass each result to + * @param resultFunc function to be executed when the result is ready */ def submitJob[T, U, R]( rdd: RDD[T], @@ -2084,6 +2234,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param checkSerializable whether or not to immediately check f for serializability * @throws SparkException if checkSerializable is set but f is not * serializable + * @return the cleaned closure */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { ClosureCleaner.clean(f, checkSerializable) @@ -2091,8 +2242,9 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. + * Set the directory under which RDDs are going to be checkpointed. + * @param directory path to the directory where checkpoint files will be stored + * (must be HDFS path if running in cluster) */ def setCheckpointDir(directory: String) { @@ -2299,6 +2451,8 @@ object SparkContext extends Logging { * * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. 
+ * @param config `SparkConfig` that will be used for initialisation of the `SparkContext` + * @return current `SparkContext` (or a new one if it wasn't created before the function call) */ def getOrCreate(config: SparkConf): SparkContext = { // Synchronize to ensure that multiple create requests don't trigger an exception @@ -2324,6 +2478,7 @@ object SparkContext extends Logging { * * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. + * @return current `SparkContext` (or a new one if wasn't created before the function call) */ def getOrCreate(): SparkContext = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { @@ -2404,6 +2559,9 @@ object SparkContext extends Logging { /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to SparkContext. + * + * @param cls class that should be inside of the jar + * @return jar that contains the Class, `None` if not found */ def jarOfClass(cls: Class[_]): Option[String] = { val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") @@ -2425,6 +2583,9 @@ object SparkContext extends Logging { * Find the JAR that contains the class of a particular object, to make it easy for users * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in * your driver program. 
+ * + * @param obj reference to an instance which class should be inside of the jar + * @return jar that contains the class of the instance, `None` if not found */ def jarOfObject(obj: AnyRef): Option[String] = jarOfClass(obj.getClass) @@ -2562,8 +2723,8 @@ object SparkContext extends Logging { val serviceLoaders = ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url)) if (serviceLoaders.size > 1) { - throw new SparkException(s"Multiple Cluster Managers ($serviceLoaders) registered " + - s"for the url $url:") + throw new SparkException( + s"Multiple external cluster managers registered for the url $url: $serviceLoaders") } serviceLoaders.headOption } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 2909191bd6..b5b201409a 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -93,7 +93,10 @@ private[spark] object TestUtils { val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) + // The `name` for the argument in `JarEntry` should use / for its separator. This is + // ZIP specification. 
+ val prefix = directoryPrefix.map(d => s"$d/").getOrElse("") + val jarEntry = new JarEntry(prefix + file.getName) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0ca91b9bf8..04ae97ed3c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -275,6 +275,11 @@ private[spark] class PythonRunner( dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala new file mode 100644 index 0000000000..3432700f11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ConcurrentHashMap + +/** JVM object ID wrapper */ +private[r] case class JVMObjectId(id: String) { + require(id != null, "Object ID cannot be null.") +} + +/** + * Counter that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] class JVMObjectTracker { + + private[this] val objMap = new ConcurrentHashMap[JVMObjectId, Object]() + private[this] val objCounter = new AtomicInteger() + + /** + * Returns the JVM object associated with the input key or None if not found. + */ + final def get(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.get(id)) + } else { + None + } + } + + /** + * Returns the JVM object associated with the input key or throws an exception if not found. + */ + @throws[NoSuchElementException]("if key does not exist.") + final def apply(id: JVMObjectId): Object = { + get(id).getOrElse( + throw new NoSuchElementException(s"$id does not exist.") + ) + } + + /** + * Adds a JVM object to track and returns assigned ID, which is unique within this tracker. + */ + final def addAndGetId(obj: Object): JVMObjectId = { + val id = JVMObjectId(objCounter.getAndIncrement().toString) + objMap.put(id, obj) + id + } + + /** + * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. + */ + final def remove(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.remove(id)) + } else { + None + } + } + + /** + * Number of JVM objects being tracked. + */ + final def size: Int = objMap.size() + + /** + * Clears the tracker. 
+ */ + final def clear(): Unit = objMap.clear() +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 550746c552..2d1152a036 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,7 +22,7 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel @@ -42,6 +42,9 @@ private[spark] class RBackend { private[this] var bootstrap: ServerBootstrap = null private[this] var bossGroup: EventLoopGroup = null + /** Tracks JVM objects returned to R for this RBackend instance. 
*/ + private[r] val jvmObjectTracker = new JVMObjectTracker + def init(): Int = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( @@ -94,6 +97,7 @@ private[spark] class RBackend { bootstrap.childGroup().shutdownGracefully() } bootstrap = null + jvmObjectTracker.clear() } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 9f5afa29d6..cfd37ac54b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,7 +20,6 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.TimeUnit -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} @@ -62,7 +61,7 @@ private[r] class RBackendHandler(server: RBackend) assert(numArgs == 1) writeInt(dos, 0) - writeObject(dos, args(0)) + writeObject(dos, args(0), server.jvmObjectTracker) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") @@ -72,9 +71,9 @@ private[r] class RBackendHandler(server: RBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + server.jvmObjectTracker.remove(JVMObjectId(objToRemove)) writeInt(dos, 0) - writeObject(dos, null) + writeObject(dos, null, server.jvmObjectTracker) } catch { case e: Exception => logError(s"Removing $objId failed", e) @@ -143,12 +142,8 @@ private[r] class RBackendHandler(server: RBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } + obj = server.jvmObjectTracker(JVMObjectId(objId)) + obj.getClass } val args = 
readArgs(numArgs, dis) @@ -173,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) // Write status bit writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + writeObject(dos, ret.asInstanceOf[AnyRef], server.jvmObjectTracker) } else if (methodName == "") { // methodName should be "" for constructor val ctors = cls.getConstructors @@ -193,7 +188,7 @@ private[r] class RBackendHandler(server: RBackend) val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker) } else { throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) } @@ -210,7 +205,7 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { _ => - readObject(dis) + readObject(dis, server.jvmObjectTracker) }.toArray } @@ -286,37 +281,4 @@ private[r] class RBackendHandler(server: RBackend) } } -/** - * Helper singleton that tracks JVM objects returned to R. - * This is useful for referencing these objects in RPC calls. - */ -private[r] object JVMObjectTracker { - - // TODO: This map should be thread-safe if we want to support multiple - // connections at the same time - private[this] val objMap = new HashMap[String, Object] - - // TODO: We support only one connection now, so an integer is fine. - // Investigate using use atomic integer in the future. 
- private[this] var objCounter: Int = 0 - - def getObject(id: String): Object = { - objMap(id) - } - - def get(id: String): Option[Object] = { - objMap.get(id) - } - - def put(obj: Object): String = { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - def remove(id: String): Option[Object] = { - objMap.remove(id) - } - -} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 7ef64723d9..29e21b3b1a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -152,7 +152,7 @@ private[spark] class RRunner[U]( dataOut.writeInt(mode) if (isDataFrame) { - SerDe.writeObject(dataOut, colNames) + SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) } if (!iter.hasNext) { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 550e075a95..dad928cdcf 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -28,13 +28,20 @@ import scala.collection.mutable.WrappedArray * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { - type ReadObject = (DataInputStream, Char) => Object - type WriteObject = (DataOutputStream, Object) => Boolean + type SQLReadObject = (DataInputStream, Char) => Object + type SQLWriteObject = (DataOutputStream, Object) => Boolean - var sqlSerDe: (ReadObject, WriteObject) = _ + private[this] var sqlReadObject: SQLReadObject = _ + private[this] var sqlWriteObject: SQLWriteObject = _ - def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { - this.sqlSerDe = sqlSerDe + def setSQLReadObject(value: SQLReadObject): this.type = { + sqlReadObject = value + this + } + + def setSQLWriteObject(value: SQLWriteObject): this.type = { + sqlWriteObject = 
value + this } // Type mapping from R to Java @@ -56,32 +63,33 @@ private[spark] object SerDe { dis.readByte().toChar } - def readObject(dis: DataInputStream): Object = { + def readObject(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Object = { val dataType = readObjectType(dis) - readTypedObject(dis, dataType) + readTypedObject(dis, dataType, jvmObjectTracker) } def readTypedObject( dis: DataInputStream, - dataType: Char): Object = { + dataType: Char, + jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) case 'c' => readString(dis) - case 'e' => readMap(dis) + case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) - case 'a' => readArray(dis) - case 'l' => readList(dis) + case 'a' => readArray(dis, jvmObjectTracker) + case 'l' => readList(dis, jvmObjectTracker) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => jvmObjectTracker(JVMObjectId(readString(dis))) case _ => - if (sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { - val obj = (sqlSerDe._1)(dis, dataType) + val obj = sqlReadObject(dis, dataType) if (obj == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { @@ -181,28 +189,28 @@ private[spark] object SerDe { } // All elements of an array must be of the same type - def readArray(dis: DataInputStream): Array[_] = { + def readArray(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) case 'c' => readStringArr(dis) case 'd' => readDoubleArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => 
readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x))) case 'r' => readBytesArr(dis) case 'a' => val len = readInt(dis) - (0 until len).map(_ => readArray(dis)).toArray + (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray case 'l' => val len = readInt(dis) - (0 until len).map(_ => readList(dis)).toArray + (0 until len).map(_ => readList(dis, jvmObjectTracker)).toArray case _ => - if (sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { val len = readInt(dis) (0 until len).map { _ => - val obj = (sqlSerDe._1)(dis, arrType) + val obj = sqlReadObject(dis, arrType) if (obj == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { @@ -215,17 +223,19 @@ private[spark] object SerDe { // Each element of a list can be of different type. They are all represented // as Object on JVM side - def readList(dis: DataInputStream): Array[Object] = { + def readList(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[Object] = { val len = readInt(dis) - (0 until len).map(_ => readObject(dis)).toArray + (0 until len).map(_ => readObject(dis, jvmObjectTracker)).toArray } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + def readMap( + in: DataInputStream, + jvmObjectTracker: JVMObjectTracker): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { // Keys is an array of String - val keys = readArray(in).asInstanceOf[Array[Object]] - val values = readList(in) + val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]] + val values = readList(in, jvmObjectTracker) keys.zip(values).toMap.asJava } else { @@ -272,7 +282,11 @@ private[spark] object SerDe { } } - private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + private def writeKeyValue( + dos: DataOutputStream, + key: Object, + value: Object, + jvmObjectTracker: JVMObjectTracker): Unit = { if 
(key == null) { throw new IllegalArgumentException("Key in map can't be null.") } else if (!key.isInstanceOf[String]) { @@ -280,10 +294,10 @@ private[spark] object SerDe { } writeString(dos, key.asInstanceOf[String]) - writeObject(dos, value) + writeObject(dos, value, jvmObjectTracker) } - def writeObject(dos: DataOutputStream, obj: Object): Unit = { + def writeObject(dos: DataOutputStream, obj: Object, jvmObjectTracker: JVMObjectTracker): Unit = { if (obj == null) { writeType(dos, "void") } else { @@ -373,14 +387,14 @@ private[spark] object SerDe { case v: Array[Object] => writeType(dos, "list") writeInt(dos, v.length) - v.foreach(elem => writeObject(dos, elem)) + v.foreach(elem => writeObject(dos, elem, jvmObjectTracker)) // Handle Properties // This must be above the case java.util.Map below. // (Properties implements Map and will be serialized as map otherwise) case v: java.util.Properties => writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) // Handle map case v: java.util.Map[_, _] => @@ -392,19 +406,21 @@ private[spark] object SerDe { val key = entry.getKey val value = entry.getValue - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + writeKeyValue( + dos, key.asInstanceOf[Object], value.asInstanceOf[Object], jvmObjectTracker) } case v: scala.collection.Map[_, _] => writeType(dos, "map") writeInt(dos, v.size) - v.foreach { case (key, value) => - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + v.foreach { case (k1, v1) => + writeKeyValue(dos, k1.asInstanceOf[Object], v1.asInstanceOf[Object], jvmObjectTracker) } case _ => - if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { + val sqlWriteSucceeded = sqlWriteObject != null && sqlWriteObject(dos, value) + if (!sqlWriteSucceeded) { writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) } } } @@ -447,9 +463,9 @@ private[spark] object SerDe { out.write(value) } - def 
writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) - writeString(out, objId) + def writeJObj(out: DataOutputStream, value: Object, jvmObjectTracker: JVMObjectTracker): Unit = { + val JVMObjectId(id) = jvmObjectTracker.addAndGetId(value) + writeString(out, id) } def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index fd7b4fc88b..ece4ae6ab0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -24,9 +24,8 @@ import org.apache.spark.SparkConf /** * An interface for all the broadcast implementations in Spark (to allow - * multiple broadcast implementations). SparkContext uses a user-specified - * BroadcastFactory implementation to instantiate a particular broadcast for the - * entire Spark job. + * multiple broadcast implementations). SparkContext uses a BroadcastFactory + * implementation to instantiate a particular broadcast for the entire Spark job. 
*/ private[spark] trait BroadcastFactory { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index f350784378..22d01c47e6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -207,11 +207,15 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId).map(_.data.next()) match { - case Some(x) => - releaseLock(broadcastId) - x.asInstanceOf[T] - + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") + } case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 3d2cabcdfd..050778a895 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -176,26 +176,31 @@ private[deploy] object RPackageUtils extends Logging { val file = new File(Utils.resolveURI(jarPath)) if (file.exists()) { val jar = new JarFile(file) - if (checkManifestForR(jar)) { - print(s"$file contains R source code. 
Now installing package.", printStream, Level.INFO) - val rSource = extractRFolder(jar, printStream, verbose) - if (RUtils.rPackages.isEmpty) { - RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) - } - try { - if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { - print(s"ERROR: Failed to build R package in $file.", printStream) - print(RJarDoc, printStream) + Utils.tryWithSafeFinally { + if (checkManifestForR(jar)) { + print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) + val rSource = extractRFolder(jar, printStream, verbose) + if (RUtils.rPackages.isEmpty) { + RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) } - } finally { // clean up - if (!rSource.delete()) { - logWarning(s"Error deleting ${rSource.getPath()}") + try { + if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { + print(s"ERROR: Failed to build R package in $file.", printStream) + print(RJarDoc, printStream) + } + } finally { + // clean up + if (!rSource.delete()) { + logWarning(s"Error deleting ${rSource.getPath()}") + } + } + } else { + if (verbose) { + print(s"$file doesn't contain R source code, skipping...", printStream) } } - } else { - if (verbose) { - print(s"$file doesn't contain R source code, skipping...", printStream) - } + } { + jar.close() } } else { print(s"WARN: $file resolved as dependency, but not found.", printStream, Level.WARNING) @@ -231,8 +236,12 @@ private[deploy] object RPackageUtils extends Logging { val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false)) try { filesToBundle.foreach { file => - // get the relative paths for proper naming in the zip file - val relPath = file.getAbsolutePath.replaceFirst(dir.getAbsolutePath, "") + // Get the relative paths for proper naming in the ZIP file. 
Note that + // we convert dir to URI to force / and then remove the trailing / that shows up for + // directories, because the separator should always be / according to the ZIP + // specification; therefore `relPath` here should be, for example, + // "/packageTest/def.R" or "/test.R". + val relPath = file.toURI.toString.replaceFirst(dir.toURI.toString.stripSuffix("/"), "") val fis = new FileInputStream(file) val zipEntry = new ZipEntry(relPath) zipOutputStream.putNextEntry(zipEntry) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 8ef69b142c..3011ed0f95 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -446,9 +446,13 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } val logPath = fileStatus.getPath() - val appCompleted = isApplicationCompleted(fileStatus) + // Use loading time as lastUpdated since some filesystems don't update modifiedTime + // each time file is updated. However use modifiedTime for completed jobs so lastUpdated + // won't change whenever HistoryServer restarts and reloads the file.
+ val lastUpdated = if (appCompleted) fileStatus.getModificationTime else clock.getTimeMillis() + val appListener = replay(fileStatus, appCompleted, new ReplayListenerBus(), eventsFilter) // Without an app ID, new logs will render incorrectly in the listing page, so do not list or @@ -461,7 +465,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.appAttemptId, appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - fileStatus.getModificationTime(), + lastUpdated, appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted, fileStatus.getLen() @@ -546,7 +550,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def shouldClean(attempt: FsApplicationAttemptInfo): Boolean = { - now - attempt.lastUpdated > maxAge && attempt.completed + now - attempt.lastUpdated > maxAge } // Scan all logs from the log directory. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 7e21fa681a..2b00a4a6b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -269,7 +269,7 @@ object HistoryServer extends Logging { Utils.initDaemon(log) new HistoryServerArguments(conf, argStrings) initSecurity() - val securityManager = new SecurityManager(conf) + val securityManager = createSecurityManager(conf) val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) @@ -289,6 +289,21 @@ object HistoryServer extends Logging { while(true) { Thread.sleep(Int.MaxValue) } } + /** + * Create a security manager. + * This turns off security in the SecurityManager, so that the History Server can start + * in a Spark cluster where security is enabled.
+ * @param config configuration for the SecurityManager constructor + * @return the security manager for use in constructing the History Server. + */ + private[history] def createSecurityManager(config: SparkConf): SecurityManager = { + if (config.getBoolean(SecurityManager.SPARK_AUTH_CONF, false)) { + logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") + config.set(SecurityManager.SPARK_AUTH_CONF, "false") + } + new SecurityManager(config) + } + def initSecurity() { // If we are accessing HDFS and it has security enabled (Kerberos), we have to login // from a keytab file so that we can access HDFS beyond the kerberos ticket expiration. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 2eddb5ff54..080ba12c2f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Command-line parser for the master. + * Command-line parser for the [[HistoryServer]]. 
*/ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 3fb860582c..ebbbbd3b71 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -176,8 +176,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def workerRow(worker: WorkerInfo): Seq[Node] = { - {worker.id} + { + if (worker.isAlive()) { + + {worker.id} + + } else { + worker.id + } + } {worker.host}:{worker.port} {worker.state} @@ -247,10 +254,13 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {driver.id} {killLink} {driver.submitDate} {driver.worker.map(w => - - {w.id.toString} - ).getOrElse("None")} + if (w.isAlive()) { + + {w.id.toString} + + } else { + w.id.toString + }).getOrElse("None")} {driver.state} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 8b1c6bf2e5..0940f3c558 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -187,8 +187,7 @@ private[deploy] class Worker( webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() - val scheme = if (webUi.sslOptions.enabled) "https" else "http" - workerWebUiUrl = s"$scheme://$publicAddress:${webUi.boundPort}" + workerWebUiUrl = s"http://$publicAddress:${webUi.boundPort}" registerWithMaster() metricsSystem.registerSource(workerSource) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9501dd9cd8..3346f6dd1f 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -84,6 +84,16 @@ private[spark] class Executor( // Start worker thread pool private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") private val executorSource = new ExecutorSource(threadPool, executorId) + // Pool used for threads that supervise task killing / cancellation + private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") + // For tasks which are in the process of being killed, this map holds the most recently created + // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't + // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding + // the integrity of the map's internal state). The purpose of this map is to prevent the creation + // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to + // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise + // create. The map key is a task id. 
+ private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]() if (!isLocal) { env.metricsSystem.registerSource(executorSource) @@ -93,6 +103,9 @@ private[spark] class Executor( // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) + // Whether to monitor killed / interrupted tasks + private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false) + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -148,9 +161,27 @@ private[spark] class Executor( } def killTask(taskId: Long, interruptThread: Boolean): Unit = { - val tr = runningTasks.get(taskId) - if (tr != null) { - tr.kill(interruptThread) + val taskRunner = runningTasks.get(taskId) + if (taskRunner != null) { + if (taskReaperEnabled) { + val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized { + val shouldCreateReaper = taskReaperForTask.get(taskId) match { + case None => true + case Some(existingReaper) => interruptThread && !existingReaper.interruptThread + } + if (shouldCreateReaper) { + val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) + taskReaperForTask(taskId) = taskReaper + Some(taskReaper) + } else { + None + } + } + // Execute the TaskReaper from outside of the synchronized block. 
+ maybeNewTaskReaper.foreach(taskReaperPool.execute) + } else { + taskRunner.kill(interruptThread = interruptThread) + } } } @@ -161,12 +192,7 @@ private[spark] class Executor( * @param interruptThread whether to interrupt the task thread */ def killAllTasks(interruptThread: Boolean) : Unit = { - // kill all the running tasks - for (taskRunner <- runningTasks.values().asScala) { - if (taskRunner != null) { - taskRunner.kill(interruptThread) - } - } + runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) } def stop(): Unit = { @@ -192,13 +218,21 @@ private[spark] class Executor( serializedTask: ByteBuffer) extends Runnable { + val threadName = s"Executor task launch worker for task $taskId" + /** Whether this task has been killed. */ @volatile private var killed = false + @volatile private var threadId: Long = -1 + + def getThreadId: Long = threadId + /** Whether this task has been finished. */ @GuardedBy("TaskRunner.this") private var finished = false + def isFinished: Boolean = synchronized { finished } + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ @@ -229,9 +263,15 @@ private[spark] class Executor( // ClosedByInterruptException during execBackend.statusUpdate which causes // Executor to crash Thread.interrupted() + // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there + // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False) + // is followed by cancel(interrupt=True). 
Thus we use notifyAll() to avoid a lost wakeup: + notifyAll() } override def run(): Unit = { + threadId = Thread.currentThread.getId + Thread.currentThread.setName(threadName) val threadMXBean = ManagementFactory.getThreadMXBean val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() @@ -431,6 +471,117 @@ private[spark] class Executor( } } + /** + * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally + * sending a Thread.interrupt(), and monitoring the task until it finishes. + * + * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks + * may not be interruptable or may not respond to their "killed" flags being set. If a significant + * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but + * remain running then this can lead to a situation where new jobs and tasks are starved of + * resources that are being used by these zombie tasks. + * + * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie + * tasks. For backwards-compatibility / backportability this component is disabled by default + * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`. + * + * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically + * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers + * in case kill is called twice with different values for the `interrupt` parameter. + * + * Once created, a TaskReaper will run until its supervised task has finished running. If the + * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if + * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely + * if the supervised task never exits. 
+ */ + private class TaskReaper( + taskRunner: TaskRunner, + val interruptThread: Boolean) + extends Runnable { + + private[this] val taskId: Long = taskRunner.taskId + + private[this] val killPollingIntervalMs: Long = + conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s") + + private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1") + + private[this] val takeThreadDump: Boolean = + conf.getBoolean("spark.task.reaper.threadDump", true) + + override def run(): Unit = { + val startTimeMs = System.currentTimeMillis() + def elapsedTimeMs = System.currentTimeMillis() - startTimeMs + def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs + try { + // Only attempt to kill the task once. If interruptThread = false then a second kill + // attempt would be a no-op and if interruptThread = true then it may not be safe or + // effective to interrupt multiple times: + taskRunner.kill(interruptThread = interruptThread) + // Monitor the killed task until it exits. The synchronization logic here is complicated + // because we don't want to synchronize on the taskRunner while possibly taking a thread + // dump, but we also need to be careful to avoid races between checking whether the task + // has finished and wait()ing for it to finish. + var finished: Boolean = false + while (!finished && !timeoutExceeded()) { + taskRunner.synchronized { + // We need to synchronize on the TaskRunner while checking whether the task has + // finished in order to avoid a race where the task is marked as finished right after + // we check and before we call wait(). 
+ if (taskRunner.isFinished) { + finished = true + } else { + taskRunner.wait(killPollingIntervalMs) + } + } + if (taskRunner.isFinished) { + finished = true + } else { + logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms") + if (takeThreadDump) { + try { + Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread => + if (thread.threadName == taskRunner.threadName) { + logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}") + } + } + } catch { + case NonFatal(e) => + logWarning("Exception thrown while obtaining thread dump: ", e) + } + } + } + } + + if (!taskRunner.isFinished && timeoutExceeded()) { + if (isLocal) { + logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " + + "not killing JVM because we are running in local mode.") + } else { + // In non-local-mode, the exception thrown here will bubble up to the uncaught exception + // handler and cause the executor JVM to exit. + throw new SparkException( + s"Killing executor JVM because killed task $taskId could not be stopped within " + + s"$killTimeoutMs ms.") + } + } + } finally { + // Clean up entries in the taskReaperForTask map. + taskReaperForTask.synchronized { + taskReaperForTask.get(taskId).foreach { taskReaperInMap => + if (taskReaperInMap eq this) { + taskReaperForTask.remove(taskId) + } else { + // This must have been a TaskReaper where interruptThread == false where a subsequent + // killTask() call for the same task had interruptThread == true and overwrote the + // map entry. 
+ } + } + } + } + } + } + /** * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index a69a2b5645..aba429bcdc 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -114,11 +114,21 @@ package object config { .intConf .createWithDefault(2) + private[spark] val MAX_FAILURES_PER_EXEC = + ConfigBuilder("spark.blacklist.application.maxFailedTasksPerExecutor") + .intConf + .createWithDefault(2) + private[spark] val MAX_FAILURES_PER_EXEC_STAGE = ConfigBuilder("spark.blacklist.stage.maxFailedTasksPerExecutor") .intConf .createWithDefault(2) + private[spark] val MAX_FAILED_EXEC_PER_NODE = + ConfigBuilder("spark.blacklist.application.maxFailedExecutorsPerNode") + .intConf + .createWithDefault(2) + private[spark] val MAX_FAILED_EXEC_PER_NODE_STAGE = ConfigBuilder("spark.blacklist.stage.maxFailedExecutorsPerNode") .intConf @@ -198,12 +208,13 @@ package object config { .createWithDefault(0) private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port") - .doc("Port to use for the block managed on the driver.") + .doc("Port to use for the block manager on the driver.") .fallbackConf(BLOCK_MANAGER_PORT) private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. 
If true, the Spark jobs will continue to run when " + - "encountering corrupt files and contents that have been read will still be returned.") + "encountering corrupted or non-existing files and contents that have been read will still " + + "be returned.") .booleanConf .createWithDefault(false) diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala index aaeb3d0038..6de1fc0685 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -146,7 +146,7 @@ object SparkHadoopMapReduceWriter extends Logging { case c: Configurable => c.setConf(hadoopConf) case _ => () } - val writer = taskFormat.getRecordWriter(taskContext) + var writer = taskFormat.getRecordWriter(taskContext) .asInstanceOf[RecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L @@ -154,6 +154,7 @@ object SparkHadoopMapReduceWriter extends Logging { // Write all rows in RDD partition. try { val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + // Write rows out, release resource and commit the task. while (iterator.hasNext) { val pair = iterator.next() writer.write(pair._1, pair._2) @@ -163,12 +164,23 @@ object SparkHadoopMapReduceWriter extends Logging { outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } - + if (writer != null) { + writer.close(taskContext) + writer = null + } committer.commitTask(taskContext) }(catchBlock = { - committer.abortTask(taskContext) - logError(s"Task ${taskContext.getTaskAttemptID} aborted.") - }, finallyBlock = writer.close(taskContext)) + // If there is an error, release resource and then abort the task. 
+ try { + if (writer != null) { + writer.close(taskContext) + writer = null + } + } finally { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + } + }) outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 86874e2067..df520f804b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import scala.collection.JavaConverters._ + import org.apache.spark.SparkConf import org.apache.spark.network.util.{ConfigProvider, TransportConf} @@ -58,6 +60,10 @@ object SparkTransportConf { new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) + + override def getAll(): java.lang.Iterable[java.util.Map.Entry[String, String]] = { + conf.getAll.toMap.asJava.entrySet() + } }) } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index ae4320d458..a83e139c13 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -131,9 +131,9 @@ class HadoopRDD[K, V]( minPartitions) } - protected val jobConfCacheKey = "rdd_%d_job_conf".format(id) + protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id) - protected val inputFormatCacheKey = "rdd_%d_input_format".format(id) + protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id) // used to build JobTracker ID private val createTime = new Date() @@ -210,22 +210,24 @@ class HadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new 
NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] + private val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - val jobConf = getJobConf() + private val jobConf = getJobConf() - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead - // Sets the thread local variable for the file's name + // Sets InputFileBlockHolder for the file block's information split.inputSplit.value match { - case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) - case _ => InputFileNameHolder.unsetInputFileName() + case fs: FileSplit => + InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) + case _ => + InputFileBlockHolder.unset() } // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { case _: FileSplit | _: CombineFileSplit => SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None @@ -235,29 +237,39 @@ class HadoopRDD[K, V]( // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - var reader: RecordReader[K, V] = null - val inputFormat = getInputFormat(jobConf) + private var reader: RecordReader[K, V] = null + private val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration( new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + reader = + try { + inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() + private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true } if (!finished) { inputMetrics.incRecordsRead(1) @@ -270,7 +282,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() // Close the reader and release it. 
Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala new file mode 100644 index 0000000000..9ba476d2ba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. + */ +private[spark] object InputFileBlockHolder { + /** + * A wrapper around some input file information. + * + * @param filePath path of the file read, or empty string if not available. + * @param startOffset starting offset, in bytes, or -1 if not available. + * @param length size of the block, in bytes, or -1 if not available. 
+ */ + private class FileBlock(val filePath: UTF8String, val startOffset: Long, val length: Long) { + def this() { + this(UTF8String.fromString(""), -1, -1) + } + } + + /** + * The thread variable for the block of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputBlock: ThreadLocal[FileBlock] = new ThreadLocal[FileBlock] { + override protected def initialValue(): FileBlock = new FileBlock + } + + /** + * Returns the path of the file being read, or empty string if it is unknown. + */ + def getInputFilePath: UTF8String = inputBlock.get().filePath + + /** + * Returns the starting offset of the block currently being read, or -1 if it is unknown. + */ + def getStartOffset: Long = inputBlock.get().startOffset + + /** + * Returns the length of the block being read, or -1 if it is unknown. + */ + def getLength: Long = inputBlock.get().length + + /** + * Sets the thread-local input block. + */ + def set(filePath: String, startOffset: Long, length: Long): Unit = { + require(filePath != null, "filePath cannot be null") + require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative") + require(length >= 0, s"length ($length) cannot be negative") + inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) + } + + /** + * Clears the input file block to default value. + */ + def unset(): Unit = inputBlock.remove() +} diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala deleted file mode 100644 index 960c91a154..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.unsafe.types.UTF8String - -/** - * This holds file names of the current Spark task. This is used in HadoopRDD, - * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. - * - * The returned value should never be null but empty string if it is unknown. - */ -private[spark] object InputFileNameHolder { - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. - */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - /** - * Returns the holding file name or empty string if it is unknown. 
- */ - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = { - require(file != null, "The input file name cannot be null") - inputFileName.set(UTF8String.fromString(file)) - } - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() - -} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index c783e13752..733e85f305 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -132,61 +132,79 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] + private val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf + private val conf = getConf - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead - // Sets the thread local variable for the file's name + // Sets InputFileBlockHolder for the file block's information split.serializableHadoopSplit.value match { - case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) - case _ => InputFileNameHolder.unsetInputFileName() + case fs: FileSplit => + InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) + case _ => + InputFileBlockHolder.unset() } // Find a function that will return the FileSystem bytes read by this thread. 
Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + private val getBytesReadCallback: Option[() => Long] = + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). - def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - val format = inputFormatClass.newInstance + private val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + private var finished = false + private var reader = + try { + val _reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader + } 
catch { + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + private var havePair = false + private var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { try { finished = !reader.nextKeyValue } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true } if (finished) { // Close and release the reader here; close() will also be called when the task @@ -215,7 +233,7 @@ class NewHadoopRDD[K, V]( private def close() { if (reader != null) { - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d285e917b8..374abccf6a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1746,7 +1746,7 @@ abstract class RDD[T: ClassTag]( /** * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * to the original parent RDDs are removed to enable the parent RDDs to be garbage * collected. 
Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 29d5d74650..26eaa9aa3d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -25,10 +25,6 @@ import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index: Int = idx - - override def hashCode(): Int = index - - override def equals(other: Any): Boolean = super.equals(other) } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index f527ec86ab..117f51c5b8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index bbc4163814..530743c036 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -146,11 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * @param uri URI with location of the file. */ def openChannel(uri: String): ReadableByteChannel - - /** - * Return if the current thread is a RPC thread. 
- */ - def isInRPCThread: Boolean } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2761d39e37..efd26486ab 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -24,7 +24,7 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. @@ -72,15 +72,9 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S * is still not ready */ def awaitResult[T](future: Future[T]): T = { - val wrapAndRethrow: PartialFunction[Throwable, T] = { - case NonFatal(t) => - throw new SparkException("Exception thrown in awaitResult", t) - } try { - // scalastyle:off awaitresult - Await.result(future, duration) - // scalastyle:on awaitresult - } catch addMessageIfTimeout.orElse(wrapAndRethrow) + ThreadUtils.awaitResult(future, duration) + } catch addMessageIfTimeout } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 67baabd2cb..a02cf30a5d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -201,7 +201,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** Message loop used for dispatching messages. 
*/ private class MessageLoop extends Runnable { override def run(): Unit = { - NettyRpcEnv.rpcThreadFlag.value = true try { while (true) { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 0b8cd144a2..e56943da13 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -407,14 +407,9 @@ private[netty] class NettyRpcEnv( } } - - override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value } private[netty] object NettyRpcEnv extends Logging { - - private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false) - /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. * Use `currentEnv` to wrap the deserialization codes. E.g., diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index fca4c6d37e..bf7a62ea33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -17,10 +17,274 @@ package org.apache.spark.scheduler +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * BlacklistTracker is designed to track problematic executors and nodes. It supports blacklisting + * executors and nodes across an entire application (with a periodic expiry). TaskSetManagers add + * additional blacklisting of executors and nodes for individual tasks and stages which works in + * concert with the blacklisting here. 
+ * + * The tracker needs to deal with a variety of workloads, e.g.: + * + * * bad user code -- this may lead to many task failures, but that should not count against + * individual executors + * * many small stages -- this may prevent a bad executor from having many failures within one + * stage, but still many failures over the entire application + * * "flaky" executors -- they don't fail every task, but are still faulty enough to merit + * blacklisting + * + * See the design doc on SPARK-8425 for a more in-depth discussion. + * + * THREADING: As with most helpers of TaskSchedulerImpl, this is not thread-safe. Though it is + * called by multiple threads, callers must already have a lock on the TaskSchedulerImpl. The + * one exception is [[nodeBlacklist()]], which can be called without holding a lock. + */ +private[scheduler] class BlacklistTracker ( + conf: SparkConf, + clock: Clock = new SystemClock()) extends Logging { + + BlacklistTracker.validateBlacklistConfs(conf) + private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) + private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) + val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + + /** + * A map from executorId to information on task failures. Tracks the time of each task failure, + * so that we can avoid blacklisting executors due to failures that are very far apart. We do not + * actively remove from this as soon as tasks hit their timeouts, to avoid the time it would take + * to do so. But it will not grow too large, because as soon as an executor gets too many + * failures, we blacklist the executor and remove its entry here. + */ + private val executorIdToFailureList = new HashMap[String, ExecutorFailureList]() + val executorIdToBlacklistStatus = new HashMap[String, BlacklistedExecutor]() + val nodeIdToBlacklistExpiryTime = new HashMap[String, Long]() + /** + * An immutable copy of the set of nodes that are currently blacklisted.
Kept in an + * AtomicReference to make [[nodeBlacklist()]] thread-safe. + */ + private val _nodeBlacklist = new AtomicReference[Set[String]](Set()) + /** + * Time when the next blacklist will expire. Used as a + * shortcut to avoid iterating over all entries in the blacklist when none will have expired. + */ + var nextExpiryTime: Long = Long.MaxValue + /** + * Mapping from nodes to all of the executors that have been blacklisted on that node. We do *not* + * remove from this when executors are removed from spark, so we can track when we get multiple + * successive blacklisted executors on one node. Nonetheless, it will not grow too large because + * there cannot be many blacklisted executors on one node, before we stop requesting more + * executors on that node, and we clean up the list of blacklisted executors once an executor has + * been blacklisted for BLACKLIST_TIMEOUT_MILLIS. + */ + val nodeToBlacklistedExecs = new HashMap[String, HashSet[String]]() + + /** + * Un-blacklists executors and nodes that have been blacklisted for at least + * BLACKLIST_TIMEOUT_MILLIS + */ + def applyBlacklistTimeout(): Unit = { + val now = clock.getTimeMillis() + // quickly check if we've got anything to expire from blacklist -- if not, avoid doing any work + if (now > nextExpiryTime) { + // Apply the timeout to blacklisted nodes and executors + val execsToUnblacklist = executorIdToBlacklistStatus.filter(_._2.expiryTime < now).keys + if (execsToUnblacklist.nonEmpty) { + // Un-blacklist any executors that have been blacklisted longer than the blacklist timeout. 
+ logInfo(s"Removing executors $execsToUnblacklist from blacklist because the blacklist " + + s"for those executors has timed out") + execsToUnblacklist.foreach { exec => + val status = executorIdToBlacklistStatus.remove(exec).get + val failedExecsOnNode = nodeToBlacklistedExecs(status.node) + failedExecsOnNode.remove(exec) + if (failedExecsOnNode.isEmpty) { + nodeToBlacklistedExecs.remove(status.node) + } + } + } + val nodesToUnblacklist = nodeIdToBlacklistExpiryTime.filter(_._2 < now).keys + if (nodesToUnblacklist.nonEmpty) { + // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout. + logInfo(s"Removing nodes $nodesToUnblacklist from blacklist because the blacklist " + + s"has timed out") + nodeIdToBlacklistExpiryTime --= nodesToUnblacklist + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + } + updateNextExpiryTime() + } + } + + private def updateNextExpiryTime(): Unit = { + val execMinExpiry = if (executorIdToBlacklistStatus.nonEmpty) { + executorIdToBlacklistStatus.map{_._2.expiryTime}.min + } else { + Long.MaxValue + } + val nodeMinExpiry = if (nodeIdToBlacklistExpiryTime.nonEmpty) { + nodeIdToBlacklistExpiryTime.values.min + } else { + Long.MaxValue + } + nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) + } + + + def updateBlacklistForSuccessfulTaskSet( + stageId: Int, + stageAttemptId: Int, + failuresByExec: HashMap[String, ExecutorFailuresInTaskSet]): Unit = { + // if any tasks failed, we count them towards the overall failure count for the executor at + // this point. 
+ val now = clock.getTimeMillis() + failuresByExec.foreach { case (exec, failuresInTaskSet) => + val appFailuresOnExecutor = + executorIdToFailureList.getOrElseUpdate(exec, new ExecutorFailureList) + appFailuresOnExecutor.addFailures(stageId, stageAttemptId, failuresInTaskSet) + appFailuresOnExecutor.dropFailuresWithTimeoutBefore(now) + val newTotal = appFailuresOnExecutor.numUniqueTaskFailures + + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + // If this pushes the total number of failures over the threshold, blacklist the executor. + // If it's already blacklisted, we avoid "re-blacklisting" (which can happen if there were + // other tasks already running in another taskset when it got blacklisted), because it makes + // some of the logic around expiry times a little more confusing. But it also wouldn't be a + // problem to re-blacklist, with a later expiry time. + if (newTotal >= MAX_FAILURES_PER_EXEC && !executorIdToBlacklistStatus.contains(exec)) { + logInfo(s"Blacklisting executor id: $exec because it has $newTotal" + + s" task failures in successful task sets") + val node = failuresInTaskSet.node + executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(node, expiryTimeForNewBlacklists)) + updateNextExpiryTime() + + // In addition to blacklisting the executor, we also update the data for failures on the + // node, and potentially put the entire node into a blacklist as well. + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]()) + blacklistedExecsOnNode += exec + // If the node is already in the blacklist, we avoid adding it again with a later expiry + // time.
+ if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE && + !nodeIdToBlacklistExpiryTime.contains(node)) { + logInfo(s"Blacklisting node $node because it has ${blacklistedExecsOnNode.size} " + + s"executors blacklisted: ${blacklistedExecsOnNode}") + nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + } + } + } + } + + def isExecutorBlacklisted(executorId: String): Boolean = { + executorIdToBlacklistStatus.contains(executorId) + } + + /** + * Get the full set of nodes that are blacklisted. Unlike other methods in this class, this *IS* + * thread-safe -- no lock required on a taskScheduler. + */ + def nodeBlacklist(): Set[String] = { + _nodeBlacklist.get() + } + + def isNodeBlacklisted(node: String): Boolean = { + nodeIdToBlacklistExpiryTime.contains(node) + } + + def handleRemovedExecutor(executorId: String): Unit = { + // We intentionally do not clean up executors that are already blacklisted in + // nodeToBlacklistedExecs, so that if another executor on the same node gets blacklisted, we can + // blacklist the entire node. We also can't clean up executorIdToBlacklistStatus, so we can + // eventually remove the executor after the timeout. Despite not clearing those structures + // here, we don't expect they will grow too big since you won't get too many executors on one + // node, and the timeout will clear it up periodically in any case. + executorIdToFailureList -= executorId + } + + + /** + * Tracks all failures for one executor (that have not passed the timeout). + * + * In general we actually expect this to be extremely small, since it won't contain more than the + * maximum number of task failures before an executor is failed (default 2). + */ + private[scheduler] final class ExecutorFailureList extends Logging { + + private case class TaskId(stage: Int, stageAttempt: Int, taskIndex: Int) + + /** + * All failures on this executor in successful task sets. 
+ */ + private var failuresAndExpiryTimes = ArrayBuffer[(TaskId, Long)]() + /** + * As an optimization, we track the min expiry time over all entries in failuresAndExpiryTimes + * so it's quick to tell if there are any failures with expiry before the current time. + */ + private var minExpiryTime = Long.MaxValue + + def addFailures( + stage: Int, + stageAttempt: Int, + failuresInTaskSet: ExecutorFailuresInTaskSet): Unit = { + failuresInTaskSet.taskToFailureCountAndFailureTime.foreach { + case (taskIdx, (_, failureTime)) => + val expiryTime = failureTime + BLACKLIST_TIMEOUT_MILLIS + failuresAndExpiryTimes += ((TaskId(stage, stageAttempt, taskIdx), expiryTime)) + if (expiryTime < minExpiryTime) { + minExpiryTime = expiryTime + } + } + } + + /** + * The number of unique tasks that failed on this executor. Only counts failures within the + * timeout, and in successful tasksets. + */ + def numUniqueTaskFailures: Int = failuresAndExpiryTimes.size + + def isEmpty: Boolean = failuresAndExpiryTimes.isEmpty + + /** + * Apply the timeout to individual tasks. This is to prevent one-off failures that are very + * spread out in time (and likely have nothing to do with problems on the executor) from + * triggering blacklisting. However, note that we do *not* remove executors and nodes from + * the blacklist as we expire individual task failures -- each has its own timeout. E.g., + * suppose: + * * timeout = 10, maxFailuresPerExec = 2 + * * Task 1 fails on exec 1 at time 0 + * * Task 2 fails on exec 1 at time 5 + * --> exec 1 is blacklisted from time 5 - 15. + * This is to simplify the implementation, as well as keep the behavior easier to understand + * for the end user. 
+ */ + def dropFailuresWithTimeoutBefore(dropBefore: Long): Unit = { + if (minExpiryTime < dropBefore) { + var newMinExpiry = Long.MaxValue + val newFailures = new ArrayBuffer[(TaskId, Long)] + failuresAndExpiryTimes.foreach { case (task, expiryTime) => + if (expiryTime >= dropBefore) { + newFailures += ((task, expiryTime)) + if (expiryTime < newMinExpiry) { + newMinExpiry = expiryTime + } + } + } + failuresAndExpiryTimes = newFailures + minExpiryTime = newMinExpiry + } + } + + override def toString(): String = { + s"failures = $failuresAndExpiryTimes" + } + } + +} private[scheduler] object BlacklistTracker extends Logging { @@ -80,7 +344,9 @@ private[scheduler] object BlacklistTracker extends Logging { config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, config.MAX_TASK_ATTEMPTS_PER_NODE, config.MAX_FAILURES_PER_EXEC_STAGE, - config.MAX_FAILED_EXEC_PER_NODE_STAGE + config.MAX_FAILED_EXEC_PER_NODE_STAGE, + config.MAX_FAILURES_PER_EXEC, + config.MAX_FAILED_EXEC_PER_NODE ).foreach { config => val v = conf.get(config) if (v <= 0) { @@ -112,3 +378,5 @@ private[scheduler] object BlacklistTracker extends Logging { } } } + +private final case class BlacklistedExecutor(node: String, expiryTime: Long) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7fde34d897..6177bafc11 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1009,13 +1009,14 @@ class DAGScheduler( } val tasks: Seq[Task[_]] = try { + val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array() stage match { case stage: ShuffleMapStage => partitionsToCompute.map { id => val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.latestInfo.taskMetrics, properties, Option(jobId), + taskBinary, 
part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1025,7 +1026,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics, + taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } } @@ -1256,27 +1257,46 @@ class DAGScheduler( s"longer running") } - if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config", - None) - } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) { - abortStage(failedStage, s"$failedStage (${failedStage.name}) " + - s"has failed the maximum allowable number of " + - s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + - s"Most recent failure reason: ${failureMessage}", None) - } else { - if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + val shouldAbortStage = + failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId) || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Fetch failure will not retry stage due to testing config" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. 
+ |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ") } + abortStage(failedStage, abortMessage, None) + } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued + // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064 + val noResubmitEnqueued = !failedStages.contains(failedStage) failedStages += failedStage failedStages += mapStage + if (noResubmitEnqueued) { + // We expect one executor failure to trigger many FetchFailures in rapid succession, + // but all of those task failures can typically be handled by a single resubmission of + // the failed stage. We avoid flooding the scheduler's event queue with resubmit + // messages by checking whether a resubmit is already in the event queue for the + // failed stage. If there is already a resubmit enqueued for a different failed + // stage, that event would also be sufficient to handle the current failed stage, but + // producing a resubmit for each failed stage makes debugging and logging a little + // simpler while not producing an overwhelming number of scheduler events. 
+ logInfo( + s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure" + ) + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, + DAGScheduler.RESUBMIT_TIMEOUT, + TimeUnit.MILLISECONDS + ) + } } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { @@ -1661,7 +1681,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } - dagScheduler.sc.stop() + dagScheduler.sc.stopInNewThread() } override def onStop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala index 20ab27d127..70553d8be2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala @@ -25,26 +25,30 @@ import scala.collection.mutable.HashMap private[scheduler] class ExecutorFailuresInTaskSet(val node: String) { /** * Mapping from index of the tasks in the taskset, to the number of times it has failed on this - * executor. + * executor and the most recent failure time. 
*/ - val taskToFailureCount = HashMap[Int, Int]() + val taskToFailureCountAndFailureTime = HashMap[Int, (Int, Long)]() - def updateWithFailure(taskIndex: Int): Unit = { - val prevFailureCount = taskToFailureCount.getOrElse(taskIndex, 0) - taskToFailureCount(taskIndex) = prevFailureCount + 1 + def updateWithFailure(taskIndex: Int, failureTime: Long): Unit = { + val (prevFailureCount, prevFailureTime) = + taskToFailureCountAndFailureTime.getOrElse(taskIndex, (0, -1L)) + // these times always come from the driver, so we don't need to worry about skew, but might + // as well still be defensive in case there is non-monotonicity in the clock + val newFailureTime = math.max(prevFailureTime, failureTime) + taskToFailureCountAndFailureTime(taskIndex) = (prevFailureCount + 1, newFailureTime) } - def numUniqueTasksWithFailures: Int = taskToFailureCount.size + def numUniqueTasksWithFailures: Int = taskToFailureCountAndFailureTime.size /** * Return the number of times this executor has failed on the given task index. */ def getNumTaskFailures(index: Int): Int = { - taskToFailureCount.getOrElse(index, 0) + taskToFailureCountAndFailureTime.getOrElse(index, (0, 0))._1 } override def toString(): String = { s"numUniqueTasksWithFailures = $numUniqueTasksWithFailures; " + - s"tasksToFailureCount = $taskToFailureCount" + s"tasksToFailureCount = $taskToFailureCountAndFailureTime" } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index d19353f2a9..6abdf0fd53 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -42,7 +42,8 @@ import org.apache.spark.rdd.RDD * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. 
- * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. + * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side + * and sent to executor side. * * The parameters below are optional: * @param jobId id of the job this task belongs to @@ -57,12 +58,12 @@ private[spark] class ResultTask[T, U]( locs: Seq[TaskLocation], val outputId: Int, localProperties: Properties, - metrics: TaskMetrics, + serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, appAttemptId: Option[String] = None) - extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, - appId, appAttemptId) + extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, + jobId, appId, appAttemptId) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 31011de85b..994b81e062 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -42,8 +42,9 @@ import org.apache.spark.shuffle.ShuffleWriter * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. + * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side + * and sent to executor side. 
* * The parameters below are optional: * @param jobId id of the job this task belongs to @@ -56,18 +57,18 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - metrics: TaskMetrics, localProperties: Properties, + serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, appAttemptId: Option[String] = None) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, - appId, appAttemptId) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, + serializedTaskMetrics, jobId, appId, appAttemptId) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, new Properties, null) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1554200aea..5becca6c06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -48,6 +48,8 @@ import org.apache.spark.util._ * @param partitionId index of the number in the RDD * @param metrics a `TaskMetrics` that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. + * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side + * and sent to executor side. 
* * The parameters below are optional: * @param jobId id of the job this task belongs to @@ -58,13 +60,17 @@ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - // The default value is only used in tests. - val metrics: TaskMetrics = TaskMetrics.registered, @transient var localProperties: Properties = new Properties, + // The default value is only used in tests. + serializedTaskMetrics: Array[Byte] = + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, val appAttemptId: Option[String] = None) extends Serializable { + @transient lazy val metrics: TaskMetrics = + SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) + /** * Called by [[org.apache.spark.executor.Executor]] to run this task. * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b03cfe4f0d..9a8e313f9e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -51,13 +51,28 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. 
*/ -private[spark] class TaskSchedulerImpl( +private[spark] class TaskSchedulerImpl private[scheduler]( val sc: SparkContext, val maxTaskFailures: Int, + blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) extends TaskScheduler with Logging { - def this(sc: SparkContext) = this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) + + def this(sc: SparkContext) = { + this( + sc, + sc.conf.get(config.MAX_TASK_FAILURES), + TaskSchedulerImpl.maybeCreateBlacklistTracker(sc.conf)) + } + + def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = { + this( + sc, + maxTaskFailures, + TaskSchedulerImpl.maybeCreateBlacklistTracker(sc.conf), + isLocal = isLocal) + } val conf = sc.conf @@ -209,7 +224,7 @@ private[spark] class TaskSchedulerImpl( private[scheduler] def createTaskSetManager( taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = { - new TaskSetManager(this, taskSet, maxTaskFailures) + new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt) } override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { @@ -256,6 +271,8 @@ private[spark] class TaskSchedulerImpl( availableCpus: Array[Int], tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { var launchedTask = false + // nodes and executors that are blacklisted for the entire application have already been + // filtered out by this point for (i <- 0 until shuffledOffers.size) { val execId = shuffledOffers(i).executorId val host = shuffledOffers(i).host @@ -308,8 +325,20 @@ private[spark] class TaskSchedulerImpl( } } + // Before making any offers, remove any nodes from the blacklist whose blacklist has expired. Do + // this here to avoid a separate thread and added synchronization overhead, and also because + // updating the blacklist is only relevant when task offers are being made. 
+ blacklistTrackerOpt.foreach(_.applyBlacklistTimeout()) + + val filteredOffers = blacklistTrackerOpt.map { blacklistTracker => + offers.filter { offer => + !blacklistTracker.isNodeBlacklisted(offer.host) && + !blacklistTracker.isExecutorBlacklisted(offer.executorId) + } + }.getOrElse(offers) + // Randomly shuffle offers to avoid always placing tasks on the same set of workers. - val shuffledOffers = Random.shuffle(offers) + val shuffledOffers = Random.shuffle(filteredOffers) // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = shuffledOffers.map(o => o.cores).toArray @@ -574,6 +603,7 @@ private[spark] class TaskSchedulerImpl( executorIdToHost -= executorId rootPool.executorLost(executorId, host, reason) } + blacklistTrackerOpt.foreach(_.handleRemovedExecutor(executorId)) } def executorAdded(execId: String, host: String) { @@ -600,6 +630,14 @@ private[spark] class TaskSchedulerImpl( executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty) } + /** + * Get a snapshot of the currently blacklisted nodes for the entire application. This is + * thread-safe -- it can be called without a lock on the TaskScheduler. 
+ */ + def nodeBlacklist(): scala.collection.immutable.Set[String] = { + blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(scala.collection.immutable.Set()) + } + // By default, rack is unknown def getRackForHost(value: String): Option[String] = None @@ -678,4 +716,13 @@ private[spark] object TaskSchedulerImpl { retval.toList } + + private def maybeCreateBlacklistTracker(conf: SparkConf): Option[BlacklistTracker] = { + if (BlacklistTracker.isBlacklistEnabled(conf)) { + Some(new BlacklistTracker(conf)) + } else { + None + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index f4b0f55b76..e815b7e0cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -28,6 +28,10 @@ import org.apache.spark.util.Clock * (task, executor) / (task, nodes) pairs, and also completely blacklisting executors and nodes * for the entire taskset. * + * It also must store sufficient information in task failures for application level blacklisting, + * which is handled by [[BlacklistTracker]]. Note that BlacklistTracker does not know anything + * about task failures until a taskset completes successfully. + * * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. @@ -41,7 +45,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private val MAX_FAILED_EXEC_PER_NODE_STAGE = conf.get(config.MAX_FAILED_EXEC_PER_NODE_STAGE) /** - * A map from each executor to the task failures on that executor. + * A map from each executor to the task failures on that executor. 
This is used for blacklisting + * within this taskset, and it is also relayed onto [[BlacklistTracker]] for app-level + * blacklisting if this taskset completes successfully. */ val execToFailures = new HashMap[String, ExecutorFailuresInTaskSet]() @@ -57,9 +63,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, /** * Return true if this executor is blacklisted for the given task. This does *not* - * need to return true if the executor is blacklisted for the entire stage. - * That is to keep this method as fast as possible in the inner-loop of the - * scheduler, where those filters will have already been applied. + * need to return true if the executor is blacklisted for the entire stage, or blacklisted + * for the entire application. That is to keep this method as fast as possible in the inner-loop + * of the scheduler, where those filters will have already been applied. */ def isExecutorBlacklistedForTask(executorId: String, index: Int): Boolean = { execToFailures.get(executorId).exists { execFailures => @@ -72,10 +78,10 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, } /** - * Return true if this executor is blacklisted for the given stage. Completely ignores - * anything to do with the node the executor is on. That - * is to keep this method as fast as possible in the inner-loop of the scheduler, where those - * filters will already have been applied. + * Return true if this executor is blacklisted for the given stage. Completely ignores whether + * the executor is blacklisted for the entire application (or anything to do with the node the + * executor is on). That is to keep this method as fast as possible in the inner-loop of the + * scheduler, where those filters will already have been applied. 
*/ def isExecutorBlacklistedForTaskSet(executorId: String): Boolean = { blacklistedExecs.contains(executorId) @@ -90,7 +96,7 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, exec: String, index: Int): Unit = { val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) - execFailures.updateWithFailure(index) + execFailures.updateWithFailure(index, clock.getTimeMillis()) // check if this task has also failed on other executors on the same host -- if its gone // over the limit, blacklist this task from the entire host. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index f2a432cad3..3756c216f5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -51,6 +51,7 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, + blacklistTracker: Option[BlacklistTracker] = None, clock: Clock = new SystemClock()) extends Schedulable with Logging { private val conf = sched.sc.conf @@ -85,10 +86,8 @@ private[spark] class TaskSetManager( var calculatedTasks = 0 private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { - if (BlacklistTracker.isBlacklistEnabled(conf)) { - Some(new TaskSetBlacklist(conf, stageId, clock)) - } else { - None + blacklistTracker.map { _ => + new TaskSetBlacklist(conf, stageId, clock) } } @@ -487,6 +486,12 @@ private[spark] class TaskSetManager( private def maybeFinishTaskSet() { if (isZombie && runningTasks == 0) { sched.taskSetFinished(this) + if (tasksSuccessful == numTasks) { + blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet( + taskSet.stageId, + taskSet.stageAttemptId, + taskSetBlacklistHelperOpt.get.execToFailures)) + } } } @@ -589,6 +594,7 @@ private[spark] class TaskSetManager( 
private[scheduler] def abortIfCompletelyBlacklisted( hostToExecutors: HashMap[String, HashSet[String]]): Unit = { taskSetBlacklistHelperOpt.foreach { taskSetBlacklist => + val appBlacklist = blacklistTracker.get // Only look for unschedulable tasks when at least one executor has registered. Otherwise, // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. if (hostToExecutors.nonEmpty) { @@ -615,13 +621,15 @@ private[spark] class TaskSetManager( val blacklistedEverywhere = hostToExecutors.forall { case (host, execsOnHost) => // Check if the task can run on the node val nodeBlacklisted = - taskSetBlacklist.isNodeBlacklistedForTaskSet(host) || - taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet) + appBlacklist.isNodeBlacklisted(host) || + taskSetBlacklist.isNodeBlacklistedForTaskSet(host) || + taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet) if (nodeBlacklisted) { true } else { // Check if the task can run on any of the executors execsOnHost.forall { exec => + appBlacklist.isExecutorBlacklisted(exec) || taskSetBlacklist.isExecutorBlacklistedForTaskSet(exec) || taskSetBlacklist.isExecutorBlacklistedForTask(exec, indexInTaskSet) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 0a4f19d760..0280359809 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -99,7 +99,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RequestExecutors( requestedTotal: Int, localityAwareTasks: Int, - hostToLocalTaskCount: Map[String, Int]) + hostToLocalTaskCount: Map[String, Int], + nodeBlacklist: Set[String]) extends CoarseGrainedClusterMessage // Check if an executor was force-killed but for a reason unrelated to the 
running tasks. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 368cd30a2e..7befdb0c1f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -139,7 +139,7 @@ private[spark] class StandaloneSchedulerBackend( scheduler.error(reason) } finally { // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + sc.stopInNewThread() } } } diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 8e3436f134..cdd3b8d851 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -21,12 +21,15 @@ import java.util.Properties import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} +import scala.collection.JavaConverters._ + import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.network.util.CryptoUtils /** * A util class for manipulating IO encryption and decryption streams. @@ -37,8 +40,6 @@ private[spark] object CryptoStreamUtils extends Logging { val IV_LENGTH_IN_BYTES = 16 // The prefix of IO encryption related configurations in Spark configuration. val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config." - // The prefix for the configurations passing to Apache Commons Crypto library. - val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption. 
@@ -70,18 +71,9 @@ private[spark] object CryptoStreamUtils extends Logging { new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) } - /** - * Get Commons-crypto configurations from Spark configurations identified by prefix. - */ def toCryptoConf(conf: SparkConf): Properties = { - val props = new Properties() - conf.getAll.foreach { case (k, v) => - if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) { - props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring( - SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v) - } - } - props + CryptoUtils.toCryptoConf(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, + conf.getAll.toMap.asJava.entrySet()) } /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b9d83495d2..8b2e26cdd9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockFetcherItr = new ShuffleBlockFetcherIterator( + val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) - - // Wrap the streams for compression and encryption based on configuration - val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - serializerManager.wrapStream(blockId, inputStream) - } + 
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { wrappedStream => + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 269c12d6da..b720aaee7c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,19 +17,21 @@ package org.apache.spark.storage -import java.io.InputStream +import java.io.{InputStream, IOException} +import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.control.NonFatal import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * An iterator that fetches multiple blocks. 
For local blocks, it fetches from the local block @@ -47,8 +49,10 @@ import org.apache.spark.util.Utils * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. + * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param detectCorrupt whether to detect any corruption in fetched blocks. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -56,8 +60,10 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, - maxReqsInFlight: Int) + maxReqsInFlight: Int, + detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ - @volatile private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: SuccessFetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that @@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator( /** Current number of requests in flight */ private[this] var reqsInFlight = 0 + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry + * at most once for those corrupted blocks. 
+ */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() /** @@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator( // The currentResult is set to null to prevent releasing the buffer again on cleanup() private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary - currentResult match { - case SuccessFetchResult(_, _, _, buf, _) => buf.release() - case _ => + if (currentResult != null) { + currentResult.buf.release() } currentResult = null } @@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator( */ override def next(): (BlockId, InputStream) = { numBlocksProcessed += 1 - val startFetchWait = System.currentTimeMillis() - currentResult = results.take() - val result = currentResult - val stopFetchWait = System.currentTimeMillis() - shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) - - result match { - case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => - if (address != blockManager.blockManagerId) { - shuffleMetrics.incRemoteBytesRead(buf.size) - shuffleMetrics.incRemoteBlocksFetched(1) - } - bytesInFlight -= size - if (isNetworkReqDone) { - reqsInFlight -= 1 - logDebug("Number of requests in flight " + reqsInFlight) - } - case _ => - } - // Send fetch requests up to maxBytesInFlight - fetchUpToMaxBytes() - result match { - case FailureFetchResult(blockId, address, e) => - throwFetchFailedException(blockId, address, e) + var result: FetchResult = null + var input: InputStream = null + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage could be retried. + // For local shuffle block, throw FailureFetchResult for the first IOException. 
+ while (result == null) { + val startFetchWait = System.currentTimeMillis() + result = results.take() + val stopFetchWait = System.currentTimeMillis() + shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) - case SuccessFetchResult(blockId, address, _, buf, _) => - try { - (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) - } catch { - case NonFatal(t) => - throwFetchFailedException(blockId, address, t) - } + result match { + case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + } + bytesInFlight -= size + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = try { + buf.createInputStream() + } catch { + // The exception could only be thrown by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + logError("Failed to create input stream from local block", e) + buf.release() + throwFetchFailedException(blockId, address, e) + } + + input = streamWrapper(blockId, in) + // Only copy the stream if it's wrapped by compression or encryption, also the size of + // block is small (the decompressed block is smaller than maxBytesInFlight) + if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { + val originalInput = input + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + try { + // Decompress the whole block at once to detect any corruption, which could increase + // the memory usage and potentially increase the chance of OOM. + // TODO: manage the memory used here, and spill it into disk in case of OOM.
+ Utils.copyStream(input, out) + out.close() + input = out.toChunkedByteBuffer.toInputStream(dispose = true) + } catch { + case e: IOException => + buf.release() + if (buf.isInstanceOf[FileSegmentManagedBuffer] + || corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, address, e) + } else { + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest(address, Array((blockId, size))) + result = null + } + } finally { + // TODO: release the buf here to free memory earlier + originalInput.close() + in.close() + } + } + + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() } + + currentResult = result.asInstanceOf[SuccessFetchResult] + (currentResult.blockId, new BufferReleasingInputStream(input, this)) } private def fetchUpToMaxBytes(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index fff21218b1..fb54dd66a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -331,7 +331,15 @@ private[spark] class MemoryStore( var unrollMemoryUsedByThisBlock = 0L // Underlying buffer for unrolling the block val redirectableStream = new RedirectableOutputStream - val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) + val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { + logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + + s"is too large to be set as chunk size. 
Chunk size has been capped to " + + s"${Utils.bytesToString(Int.MaxValue)}") + Int.MaxValue + } else { + initialMemoryThreshold.toInt + } + val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { val autoPick = !blockId.isInstanceOf[StreamBlockId] @@ -694,7 +702,7 @@ private[storage] class PartiallyUnrolledIterator[T]( } override def next(): T = { - if (unrolled == null) { + if (unrolled == null || !unrolled.hasNext) { rest.next() } else { unrolled.next() diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b828532aba..7d31ac54a7 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,6 +60,8 @@ private[spark] class SparkUI private ( var appId: String = _ + private var streamingJobProgressListener: Option[SparkListener] = None + /** Initialize all components of the server. */ def initialize() { val jobsTab = new JobsTab(this) @@ -124,6 +126,12 @@ private[spark] class SparkUI private ( def getApplicationInfo(appId: String): Option[ApplicationInfo] = { getApplicationInfoList.find(_.id == appId) } + + def getStreamingJobProgressListener: Option[SparkListener] = streamingJobProgressListener + + def setStreamingJobProgressListener(sparkListener: SparkListener): Unit = { + streamingJobProgressListener = Option(sparkListener) + } } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 8c80155867..b8604c52e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -147,10 +147,7 @@ private[spark] abstract class WebUI( } /** Return the url of web interface. Only valid after bind(). 
*/ - def webUrl: String = { - val protocol = if (sslOptions.enabled) "https" else "http" - s"$protocol://$publicHostName:$boundPort" - } + def webUrl: String = s"http://$publicHostName:$boundPort" /** Return the actual port to which this server is bound. Only valid after bind(). */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 60a6e82c6f..1aa4456ed0 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.util.concurrent._ -import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} import scala.concurrent.duration.Duration import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal @@ -180,39 +180,30 @@ private[spark] object ThreadUtils { // scalastyle:off awaitresult /** - * Preferred alternative to `Await.result()`. This method wraps and re-throws any exceptions - * thrown by the underlying `Await` call, ensuring that this thread's stack trace appears in - * logs. - */ - @throws(classOf[SparkException]) - def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { - try { - Await.result(awaitable, atMost) - // scalastyle:on awaitresult - } catch { - case NonFatal(t) => - throw new SparkException("Exception thrown in awaitResult: ", t) - } - } - - /** - * Calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps - * and re-throws any exceptions with nice stack track. + * Preferred alternative to `Await.result()`. 
+ * + * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring + * that this thread's stack trace appears in logs. * - * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent - * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method - * basically prevents ForkJoinPool from running other tasks in the current waiting thread. + * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s + * `BlockingContext`. Code running in the user's thread may be in a thread of Scala ForkJoinPool. + * As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this + * method basically prevents ForkJoinPool from running other tasks in the current waiting thread. + * In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's + * hard to debug when [[ThreadLocal]]s leak to other tasks. */ @throws(classOf[SparkException]) - def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = { + def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { try { // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. // See SPARK-13747. val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - awaitable.result(Duration.Inf)(awaitPermission) + awaitable.result(atMost)(awaitPermission) } catch { - case NonFatal(t) => + // TimeoutException is thrown in the current thread, so no need to wrap the exception.
+ case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) } } + // scalastyle:on awaitresult } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 91f5606127..078cc3d5b4 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -1249,7 +1249,7 @@ private[spark] object Utils extends Logging { val currentThreadName = Thread.currentThread().getName if (sc != null) { logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t) - sc.stop() + sc.stopInNewThread() } if (!NonFatal(t)) { logError(s"throw uncaught fatal error in thread $currentThreadName", t) @@ -2131,28 +2131,46 @@ private[spark] object Utils extends Logging { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. 
val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap - val stackTrace = threadInfo.getStackTrace.map { frame => - monitors.get(frame) match { - case Some(monitor) => - monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" - case None => - frame.toString - } - }.mkString("\n") - - // use a set to dedup re-entrant locks that are held at multiple places - val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) - ++ threadInfo.getLockedMonitors.map(_.lockString) - ).toSet + threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + } - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, - stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), - Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) + def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { + if (threadId <= 0) { + None + } else { + // The Int.MaxValue here requests the entire untruncated stack trace of the thread: + val threadInfo = + Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue)) + threadInfo.map(threadInfoToThreadStackTrace) } } + private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = + (threadInfo.getLockedSynchronizers ++ 
threadInfo.getLockedMonitors).map(_.lockString).toSet + + ThreadStackTrace( + threadId = threadInfo.getThreadId, + threadName = threadInfo.getThreadName, + threadState = threadInfo.getThreadState, + stackTrace = stackTrace, + blockedByThreadId = + if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), + holdingLocks = heldLocks.toSeq) + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index da08661d13..7572cac393 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { * @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream * in order to close any memory-mapped files which back the buffer. */ -private class ChunkedByteBufferInputStream( +private[spark] class ChunkedByteBufferInputStream( var chunkedByteBuffer: ChunkedByteBuffer, dispose: Boolean) extends InputStream { diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index 297524c943..a7e0075deb 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -56,11 +56,14 @@ private[spark] object SamplingUtils { val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() + l += 1 + // There are k elements in the reservoir, and the l-th element has been + // consumed. It should be chosen with probability k/l. 
The expression + // below is a random long chosen uniformly from [0,l) val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { reservoir(replacementIndex.toInt) = item } - l += 1 } (reservoir, l) } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 682d98867b..0c77123740 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -27,8 +27,10 @@ import org.slf4j.LoggerFactory; import org.slf4j.bridge.SLF4JBridgeHandler; import static org.junit.Assert.*; +import static org.junit.Assume.*; import org.apache.spark.internal.config.package$; +import org.apache.spark.util.Utils; /** * These tests require the Spark assembly to be built before they can be run. @@ -155,6 +157,10 @@ public void testRedirectToLog() throws Exception { @Test public void testChildProcLauncher() throws Exception { + // This test is failed on Windows due to the failure of initiating executors + // by the path length limitation. See SPARK-18718. 
+ assumeTrue(!Utils.isWindows()); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap<>(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 89f0b1cb5b..6538507d40 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -22,6 +22,7 @@ import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{FileAlreadyExistsException, FileSplit, JobConf, TextInputFormat, TextOutputFormat} @@ -255,7 +256,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { val (infile: String, indata: PortableDataStream) = inRdd.collect.head // Make sure the name and array match - assert(infile.contains(outFileName)) // a prefix may get added + assert(infile.contains(outFile.toURI.getPath)) // a prefix may get added assert(indata.toArray === testOutput) } @@ -532,7 +533,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { .mapPartitionsWithInputSplit { (split, part) => Iterator(split.asInstanceOf[FileSplit].getPath.toUri.getPath) }.collect() - assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + val outPathOne = new Path(outDir, "part-00000").toUri.getPath + val outPathTwo = new Path(outDir, "part-00001").toUri.getPath + assert(inputPaths.toSet === Set(outPathOne, outPathTwo)) } test("Get input files via new Hadoop API") { @@ -546,7 +549,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { .mapPartitionsWithInputSplit { (split, part) => Iterator(split.asInstanceOf[NewFileSplit].getPath.toUri.getPath) }.collect() - assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + val outPathOne = new Path(outDir, "part-00000").toUri.getPath + 
val outPathTwo = new Path(outDir, "part-00001").toUri.getPath + assert(inputPaths.toSet === Set(outPathOne, outPathTwo)) } test("spark.files.ignoreCorruptFiles should work both HadoopRDD and NewHadoopRDD") { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 915d7a1b8b..7b6a2313f9 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { @@ -291,7 +291,7 @@ private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoin def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal, _, _) => + case RequestExecutors(requestedTotal, _, _, _) => targetNumExecutors = requestedTotal context.reply(true) case KillExecutors(executorIds) => diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index a3490fc79e..99150a1430 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft assert(jobB.get() === 100) } + test("task reaper kills JVM if killed tasks keep running for too long") { + val conf = new SparkConf() + 
.set("spark.task.reaper.enabled", "true") + .set("spark.task.reaper.killTimeout", "5s") + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = Future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 10000, 2).map { i => + while (true) { } + }.count() + } + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + // Small delay to ensure tasks actually start executing the task body + Thread.sleep(1000) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100) + } + + test("task reaper will not kill JVM if spark.task.killTimeout == -1") { + val conf = new SparkConf() + .set("spark.task.reaper.enabled", "true") + .set("spark.task.reaper.killTimeout", "-1") + .set("spark.task.reaper.PollingInterval", "1s") + .set("spark.deploy.maxExecutorRetries", "1") + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. 
+ val jobA = Future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 2, 2).map { i => + val startTime = System.currentTimeMillis() + while (System.currentTimeMillis() < startTime + 10000) { } + }.count() + } + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + // Small delay to ensure tasks actually start executing the task body + Thread.sleep(1000) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100) + } + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // twoJobsSharingStageSemaphore: diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index a854f5bb9b..e626ed3621 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} -import org.apache.spark.util.MutablePair +import org.apache.spark.util.{MutablePair, Utils} abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index c451c596b0..8fba82de54 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ 
-31,6 +31,7 @@ import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ +import org.apache.spark.scheduler.SparkListener import org.apache.spark.util.Utils class SparkContextSuite extends SparkFunSuite with LocalSparkContext { @@ -451,4 +452,19 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("register and deregister Spark listener from SparkContext") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + try { + val sparkListener1 = new SparkListener { } + val sparkListener2 = new SparkListener { } + sc.addSparkListener(sparkListener1) + sc.addSparkListener(sparkListener2) + assert(sc.listenerBus.listeners.contains(sparkListener1)) + assert(sc.listenerBus.listeners.contains(sparkListener2)) + sc.removeSparkListener(sparkListener1) + assert(!sc.listenerBus.listeners.contains(sparkListener1)) + assert(sc.listenerBus.listeners.contains(sparkListener2)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala b/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala new file mode 100644 index 0000000000..6a979aefe6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import org.apache.spark.SparkFunSuite + +class JVMObjectTrackerSuite extends SparkFunSuite { + test("JVMObjectId does not take null IDs") { + intercept[IllegalArgumentException] { + JVMObjectId(null) + } + } + + test("JVMObjectTracker") { + val tracker = new JVMObjectTracker + assert(tracker.size === 0) + withClue("an empty tracker can be cleared") { + tracker.clear() + } + val none = JVMObjectId("none") + assert(tracker.get(none) === None) + intercept[NoSuchElementException] { + tracker(JVMObjectId("none")) + } + + val obj1 = new Object + val id1 = tracker.addAndGetId(obj1) + assert(id1 != null) + assert(tracker.size === 1) + assert(tracker.get(id1).get.eq(obj1)) + assert(tracker(id1).eq(obj1)) + + val obj2 = new Object + val id2 = tracker.addAndGetId(obj2) + assert(id1 !== id2) + assert(tracker.size === 2) + assert(tracker(id2).eq(obj2)) + + val Some(obj1Removed) = tracker.remove(id1) + assert(obj1Removed.eq(obj1)) + assert(tracker.get(id1) === None) + assert(tracker.size === 1) + assert(tracker(id2).eq(obj2)) + + val obj3 = new Object + val id3 = tracker.addAndGetId(obj3) + assert(tracker.size === 2) + assert(id3 != id1) + assert(id3 != id2) + assert(tracker(id3).eq(obj3)) + + tracker.clear() + assert(tracker.size === 0) + assert(tracker.get(id1) === None) + assert(tracker.get(id2) === None) + assert(tracker.get(id3) === None) + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala 
similarity index 67% rename from common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java rename to core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala index f15ec8d294..085cc267ca 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java +++ b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala @@ -15,18 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network.util; +package org.apache.spark.api.r -import java.util.NoSuchElementException; +import org.apache.spark.SparkFunSuite -/** Uses System properties to obtain config values. */ -public class SystemPropertyConfigProvider extends ConfigProvider { - @Override - public String get(String name) { - String value = System.getProperty(name); - if (value == null) { - throw new NoSuchElementException(name); - } - return value; +class RBackendSuite extends SparkFunSuite { + test("close() clears jvmObjectTracker") { + val backend = new RBackend + val tracker = backend.jvmObjectTracker + val id = tracker.addAndGetId(new Object) + backend.close() + assert(tracker.get(id) === None) + assert(tracker.size === 0) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 973676398a..6646068d50 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -137,6 +137,18 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + test("Cache broadcast to disk") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + assert(broadcast.value.sum === 10) + } + /** * 
Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index c9b3d657c2..f50cb38311 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -142,7 +142,7 @@ private[deploy] object IvyTestUtils { |} """.stripMargin val sourceFile = - new JavaSourceFromString(new File(dir, className).getAbsolutePath, contents) + new JavaSourceFromString(new File(dir, className).toURI.getPath, contents) createCompiledClass(className, dir, sourceFile, Seq.empty) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 6268880229..9417930d02 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -461,7 +461,7 @@ class SparkSubmitSuite val tempDir = Utils.createTempDir() val srcDir = new File(tempDir, "sparkrtest") srcDir.mkdirs() - val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").getAbsolutePath, + val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").toURI.getPath, """package sparkrtest; | |public class DummyClass implements java.io.Serializable { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 2c41c432d1..027f412c75 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -66,7 +66,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("Parse application logs") { - val provider = new 
FsHistoryProvider(createTestConf()) + val clock = new ManualClock(12345678) + val provider = new FsHistoryProvider(createTestConf(), clock) // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) @@ -109,12 +110,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) } + // For completed files, lastUpdated would be lastModified time. list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L, newAppComplete.lastModified(), "test", true)) list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + + // For Inprogress files, lastUpdated would be current loading time. list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, - newAppIncomplete.lastModified(), "test", false)) + clock.getTimeMillis(), "test", false)) // Make sure the UI can be rendered. 
list.foreach { case info => @@ -299,6 +303,48 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(!log2.exists()) } + test("log cleaner for inProgress files") { + val firstFileModifiedTime = TimeUnit.SECONDS.toMillis(10) + val secondFileModifiedTime = TimeUnit.SECONDS.toMillis(20) + val maxAge = TimeUnit.SECONDS.toMillis(40) + val clock = new ManualClock(0) + val provider = new FsHistoryProvider( + createTestConf().set("spark.history.fs.cleaner.maxAge", s"${maxAge}ms"), clock) + + val log1 = newLogFile("inProgressApp1", None, inProgress = true) + writeFile(log1, true, None, + SparkListenerApplicationStart( + "inProgressApp1", Some("inProgressApp1"), 3L, "test", Some("attempt1")) + ) + + clock.setTime(firstFileModifiedTime) + provider.checkForLogs() + + val log2 = newLogFile("inProgressApp2", None, inProgress = true) + writeFile(log2, true, None, + SparkListenerApplicationStart( + "inProgressApp2", Some("inProgressApp2"), 23L, "test2", Some("attempt2")) + ) + + clock.setTime(secondFileModifiedTime) + provider.checkForLogs() + + // This should not trigger any cleanup + updateAndCheck(provider)(list => list.size should be(2)) + + // Should trigger cleanup for first file but not second one + clock.setTime(firstFileModifiedTime + maxAge + 1) + updateAndCheck(provider)(list => list.size should be(1)) + assert(!log1.exists()) + assert(log2.exists()) + + // Should cleanup the second file as well. 
+ clock.setTime(secondFileModifiedTime + maxAge + 1) + updateAndCheck(provider)(list => list.size should be(0)) + assert(!log1.exists()) + assert(!log2.exists()) + } + test("Event log copy") { val provider = new FsHistoryProvider(createTestConf()) val logs = (1 to 2).map { i => diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 715811a46f..d3b79dd3e3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -75,7 +75,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.testing", "true") provider = new FsHistoryProvider(conf) provider.checkForLogs() - val securityManager = new SecurityManager(conf) + val securityManager = HistoryServer.createSecurityManager(conf) server = new HistoryServer(conf, provider, securityManager, 18080) server.initialize() @@ -288,7 +288,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers provider = new FsHistoryProvider(conf) provider.checkForLogs() - val securityManager = new SecurityManager(conf) + val securityManager = HistoryServer.createSecurityManager(conf) server = new HistoryServer(conf, provider, securityManager, 18080) server.initialize() @@ -349,6 +349,17 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } + /** + * Verify that the security manager needed for the history server can be instantiated + * when `spark.authenticate` is `true`, rather than raise an `IllegalArgumentException`. 
+ */ + test("security manager starts with spark.authenticate set") { + val conf = new SparkConf() + .set("spark.testing", "true") + .set(SecurityManager.SPARK_AUTH_CONF, "true") + HistoryServer.createSecurityManager(conf) + } + test("incomplete apps get refreshed") { implicit val webDriver: WebDriver = new HtmlUnitDriver @@ -368,7 +379,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.cache.window", "250ms") .remove("spark.testing") val provider = new FsHistoryProvider(myConf) - val securityManager = new SecurityManager(myConf) + val securityManager = HistoryServer.createSecurityManager(myConf) sc = new SparkContext("local", "test", myConf) val logDirUri = logDir.toURI diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 683eeeeb6d..742500d87d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -51,9 +51,11 @@ class ExecutorSuite extends SparkFunSuite { when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) when(mockEnv.memoryManager).thenReturn(mockMemoryManager) when(mockEnv.closureSerializer).thenReturn(serializer) + val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array() + val serializedTask = Task.serializeWithDependencies( - new FakeTask(0, 0), + new FakeTask(0, 0, Nil, fakeTaskMetrics), HashMap[String, Long](), HashMap[String, Long](), serializer.newInstance()) diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala index cac15a1dc4..c88cc13654 100644 --- a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import 
org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.util.Utils class LauncherBackendSuite extends SparkFunSuite with Matchers { @@ -35,6 +36,8 @@ class LauncherBackendSuite extends SparkFunSuite with Matchers { tests.foreach { case (name, master) => test(s"$name: launcher handle") { + // These tests fail on Windows because of its cmd length limitation (up to 8K). + assume(!Utils.isWindows) testWithMaster(master) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index f8054f5fd7..a73b300ec2 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -61,7 +61,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext pw.close() // Path to tmpFile - tmpFilePath = "file://" + tmpFile.getAbsolutePath + tmpFilePath = tmpFile.toURI.toString } after { @@ -181,7 +181,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext sc.textFile(tmpFilePath, 4) .map(key => (key, 1)) .reduceByKey(_ + _) - .saveAsTextFile("file://" + tmpFile.getAbsolutePath) + .saveAsTextFile(tmpFile.toURI.toString) sc.listenerBus.waitUntilEmpty(500) assert(inputRead == numRecords) @@ -197,7 +197,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext val numPartitions = 2 val cartVector = 0 to 9 val cartFile = new File(tmpDir, getClass.getSimpleName + "_cart.txt") - val cartFilePath = "file://" + cartFile.getAbsolutePath + val cartFilePath = cartFile.toURI.toString // write files to disk so we can read them later. 
sc.parallelize(cartVector).saveAsTextFile(cartFilePath) diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 58664e77d2..b29a53cffe 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -199,10 +199,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim val f = sc.parallelize(1 to 100, 4) .mapPartitions(itr => { Thread.sleep(20); itr }) .countAsync() - val e = intercept[SparkException] { + intercept[TimeoutException] { ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) } - assert(e.getCause.isInstanceOf[TimeoutException]) } private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 7293aa9a25..287ae6ff6e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -32,109 +32,104 @@ import org.apache.spark._ import org.apache.spark.util.Utils class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { + val envCommand = if (Utils.isWindows) { + "cmd.exe /C set" + } else { + "printenv" + } test("basic pipe") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assume(testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat")) + val piped = nums.pipe(Seq("cat")) - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") - } else { - assert(true) - } + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") } test("basic pipe with 
tokenization") { - if (testCommandAvailable("wc")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - - // verify that both RDD.pipe(command: String) and RDD.pipe(command: String, env) work good - for (piped <- Seq(nums.pipe("wc -l"), nums.pipe("wc -l", Map[String, String]()))) { - val c = piped.collect() - assert(c.size === 2) - assert(c(0).trim === "2") - assert(c(1).trim === "2") - } - } else { - assert(true) + assume(testCommandAvailable("wc")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + // verify that both RDD.pipe(command: String) and RDD.pipe(command: String, env) work good + for (piped <- Seq(nums.pipe("wc -l"), nums.pipe("wc -l", Map[String, String]()))) { + val c = piped.collect() + assert(c.size === 2) + assert(c(0).trim === "2") + assert(c(1).trim === "2") } } test("failure in iterating over pipe input") { - if (testCommandAvailable("cat")) { - val nums = - sc.makeRDD(Array(1, 2, 3, 4), 2) - .mapPartitionsWithIndex((index, iterator) => { - new Iterator[Int] { - def hasNext = true - def next() = { - throw new SparkException("Exception to simulate bad scenario") - } - } - }) - - val piped = nums.pipe(Seq("cat")) - - intercept[SparkException] { - piped.collect() - } + assume(testCommandAvailable("cat")) + val nums = + sc.makeRDD(Array(1, 2, 3, 4), 2) + .mapPartitionsWithIndex((index, iterator) => { + new Iterator[Int] { + def hasNext = true + def next() = { + throw new SparkException("Exception to simulate bad scenario") + } + } + }) + + val piped = nums.pipe(Seq("cat")) + + intercept[SparkException] { + piped.collect() } } test("advanced pipe") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val bl = sc.broadcast(List("0")) - - val piped = nums.pipe(Seq("cat"), + assume(testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => { + bl.value.foreach(f); f("\u0001") + }, + (i: 
Int, f: String => Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str => str.split("\t")(0)). + pipe(Seq("cat"), Map[String, String](), (f: String => Unit) => { bl.value.foreach(f); f("\u0001") }, - (i: Int, f: String => Unit) => f(i + "_")) - - val c = piped.collect() - - assert(c.size === 8) - assert(c(0) === "0") - assert(c(1) === "\u0001") - assert(c(2) === "1_") - assert(c(3) === "2_") - assert(c(4) === "0") - assert(c(5) === "\u0001") - assert(c(6) === "3_") - assert(c(7) === "4_") - - val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str => str.split("\t")(0)). - pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => { - bl.value.foreach(f); f("\u0001") - }, - (i: Tuple2[String, Iterable[String]], f: String => Unit) => { - for (e <- i._2) { - f(e + "_") - } - }).collect() - assert(d.size === 8) - assert(d(0) === "0") - assert(d(1) === "\u0001") - assert(d(2) === "b\t2_") - assert(d(3) === "b\t4_") - assert(d(4) === "0") - assert(d(5) === "\u0001") - assert(d(6) === "a\t1_") - assert(d(7) === "a\t3_") - } else { - assert(true) - } + (i: Tuple2[String, Iterable[String]], f: String => Unit) => { + for (e <- i._2) { + f(e + "_") + } + }).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") } test("pipe with empty partition") { @@ -142,67 +137,67 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { val piped = data.pipe("wc -c") assert(piped.count == 8) val charCounts = 
piped.map(_.trim.toInt).collect().toSet - assert(Set(0, 4, 5) == charCounts) + val expected = if (Utils.isWindows) { + // Note that newline character on Windows is \r\n which are two. + Set(0, 5, 6) + } else { + Set(0, 4, 5) + } + assert(expected == charCounts) } test("pipe with env variable") { - if (testCommandAvailable("printenv")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) - val c = piped.collect() - assert(c.size === 2) - assert(c(0) === "LALALA") - assert(c(1) === "LALALA") - } else { - assert(true) - } + assume(testCommandAvailable(envCommand)) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(s"$envCommand MY_TEST_ENV", Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.length === 2) + // On Windows, `cmd.exe /C set` is used which prints out it as `varname=value` format + // whereas `printenv` usually prints out `value`. So, `varname=` is stripped here for both. 
+ assert(c(0).stripPrefix("MY_TEST_ENV=") === "LALALA") + assert(c(1).stripPrefix("MY_TEST_ENV=") === "LALALA") } test("pipe with process which cannot be launched due to bad command") { - if (!testCommandAvailable("some_nonexistent_command")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val command = Seq("some_nonexistent_command") - val piped = nums.pipe(command) - val exception = intercept[SparkException] { - piped.collect() - } - assert(exception.getMessage.contains(command.mkString(" "))) + assume(!testCommandAvailable("some_nonexistent_command")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val command = Seq("some_nonexistent_command") + val piped = nums.pipe(command) + val exception = intercept[SparkException] { + piped.collect() } + assert(exception.getMessage.contains(command.mkString(" "))) } test("pipe with process which is launched but fails with non-zero exit status") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val command = Seq("cat", "nonexistent_file") - val piped = nums.pipe(command) - val exception = intercept[SparkException] { - piped.collect() - } - assert(exception.getMessage.contains(command.mkString(" "))) + assume(testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val command = Seq("cat", "nonexistent_file") + val piped = nums.pipe(command) + val exception = intercept[SparkException] { + piped.collect() } + assert(exception.getMessage.contains(command.mkString(" "))) } test("basic pipe with separate working directory") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat"), separateWorkingDir = true) - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") - val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true) - val collectPwd = pipedPwd.collect() - assert(collectPwd(0).contains("tasks/")) - val pipedLs = 
nums.pipe(Seq("ls"), separateWorkingDir = true, bufferSize = 16384).collect() - // make sure symlinks were created - assert(pipedLs.length > 0) - // clean up top level tasks directory - Utils.deleteRecursively(new File("tasks")) - } else { - assert(true) - } + assume(testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat"), separateWorkingDir = true) + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true) + val collectPwd = pipedPwd.collect() + assert(collectPwd(0).contains("tasks/")) + val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true, bufferSize = 16384).collect() + // make sure symlinks were created + assert(pipedLs.length > 0) + // clean up top level tasks directory + Utils.deleteRecursively(new File("tasks")) } test("test pipe exports map_input_file") { @@ -219,36 +214,35 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } def testExportInputFile(varName: String) { - if (testCommandAvailable("printenv")) { - val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable], - classOf[Text], 2) { - override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition()) + assume(testCommandAvailable(envCommand)) + val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable], + classOf[Text], 2) { + override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition()) - override val getDependencies = List[Dependency[_]]() + override val getDependencies = List[Dependency[_]]() - override def compute(theSplit: Partition, context: TaskContext) = { - new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1), - new Text("b")))) - } + override def compute(theSplit: Partition, context: TaskContext) = { + new 
InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1), + new Text("b")))) } - val hadoopPart1 = generateFakeHadoopPartition() - val pipedRdd = - new PipedRDD( - nums, - PipedRDD.tokenize("printenv " + varName), - Map(), - null, - null, - false, - 4092, - Codec.defaultCharsetCodec.name) - val tContext = TaskContext.empty() - val rddIter = pipedRdd.compute(hadoopPart1, tContext) - val arr = rddIter.toArray - assert(arr(0) == "/some/path") - } else { - // printenv isn't available so just pass the test } + val hadoopPart1 = generateFakeHadoopPartition() + val pipedRdd = + new PipedRDD( + nums, + PipedRDD.tokenize(s"$envCommand $varName"), + Map(), + null, + null, + false, + 4092, + Codec.defaultCharsetCodec.name) + val tContext = TaskContext.empty() + val rddIter = pipedRdd.compute(hadoopPart1, tContext) + val arr = rddIter.toArray + // On Windows, `cmd.exe /C set` is used which prints out it as `varname=value` format + // whereas `printenv` usually prints out `value`. So, `varname=` is stripped here for both. 
+ assert(arr(0).stripPrefix(s"$varName=") === "/some/path") } def generateFakeHadoopPartition(): HadoopPartition = { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index aa0705987d..acdf21df9a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -870,19 +870,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { verify(endpoint, never()).onDisconnected(any()) verify(endpoint, never()).onNetworkError(any(), any()) } - - test("isInRPCThread") { - val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint { - override val rpcEnv = env - - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case m => context.reply(rpcEnv.isInRPCThread) - } - }) - assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true) - assert(env.isInRPCThread === false) - env.stop(rpcEndpointRef) - } } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index b2e7ec5df0..6b314d2ae3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -17,10 +17,356 @@ package org.apache.spark.scheduler -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark._ import org.apache.spark.internal.config +import org.apache.spark.util.ManualClock + +class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with MockitoSugar + with LocalSparkContext { + + private val clock = new ManualClock(0) + + private var blacklist: BlacklistTracker = _ + private var scheduler: 
TaskSchedulerImpl = _ + private var conf: SparkConf = _ + + override def beforeEach(): Unit = { + conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.BLACKLIST_ENABLED.key, "true") + scheduler = mockTaskSchedWithConf(conf) + + clock.setTime(0) + blacklist = new BlacklistTracker(conf, clock) + } + + override def afterEach(): Unit = { + if (blacklist != null) { + blacklist = null + } + if (scheduler != null) { + scheduler.stop() + scheduler = null + } + super.afterEach() + } + + // All executors and hosts used in tests should be in this set, so that [[assertEquivalentToSet]] + // works. It's OK if it's got extraneous entries + val allExecutorAndHostIds = { + (('A' to 'Z')++ (1 to 100).map(_.toString)) + .flatMap{ suffix => + Seq(s"host$suffix", s"host-$suffix") + } + }.toSet + + /** + * It's easier to write our tests as if we could directly look at the sets of nodes & executors in + * the blacklist. However the api doesn't expose a set, so this is a simple way to test + * something similar, since we know the universe of values that might appear in these sets. + */ + def assertEquivalentToSet(f: String => Boolean, expected: Set[String]): Unit = { + allExecutorAndHostIds.foreach { id => + val actual = f(id) + val exp = expected.contains(id) + assert(actual === exp, raw"""for string "$id" """) + } + } -class BlacklistTrackerSuite extends SparkFunSuite { + def mockTaskSchedWithConf(conf: SparkConf): TaskSchedulerImpl = { + sc = new SparkContext(conf) + val scheduler = mock[TaskSchedulerImpl] + when(scheduler.sc).thenReturn(sc) + when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + scheduler + } + + def createTaskSetBlacklist(stageId: Int = 0): TaskSetBlacklist = { + new TaskSetBlacklist(conf, stageId, clock) + } + + test("executors can be blacklisted with only a few failures per stage") { + // For many different stages, executor 1 fails a task, then executor 2 succeeds the task, + // and then the task set is done. 
Not enough failures to blacklist the executor *within* + // any particular taskset, but we still blacklist the executor overall eventually. + // Also, we intentionally have a mix of task successes and failures -- there are even some + // successes after the executor is blacklisted. The idea here is those tasks get scheduled + // before the executor is blacklisted. We might get successes after blacklisting (because the + // executor might be flaky but not totally broken). But successes should not unblacklist the + // executor. + val failuresUntilBlacklisted = conf.get(config.MAX_FAILURES_PER_EXEC) + var failuresSoFar = 0 + (0 until failuresUntilBlacklisted * 10).foreach { stageId => + val taskSetBlacklist = createTaskSetBlacklist(stageId) + if (stageId % 2 == 0) { + // fail one task in every other taskset + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + failuresSoFar += 1 + } + blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) + assert(failuresSoFar == stageId / 2 + 1) + if (failuresSoFar < failuresUntilBlacklisted) { + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } else { + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + } + } + } + + // If an executor has many task failures, but the task set ends up failing, it shouldn't be + // counted against the executor. + test("executors aren't blacklisted as a result of tasks in failed task sets") { + val failuresUntilBlacklisted = conf.get(config.MAX_FAILURES_PER_EXEC) + // for many different stages, executor 1 fails a task, and then the taskSet fails. 
+ (0 until failuresUntilBlacklisted * 10).foreach { stage => + val taskSetBlacklist = createTaskSetBlacklist(stage) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + } + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + + Seq(true, false).foreach { succeedTaskSet => + val label = if (succeedTaskSet) "success" else "failure" + test(s"stage blacklist updates correctly on stage $label") { + // Within one taskset, an executor fails a few times, so it's blacklisted for the taskset. + // But if the taskset fails, we shouldn't blacklist the executor after the stage. + val taskSetBlacklist = createTaskSetBlacklist(0) + // We trigger enough failures for both the taskset blacklist, and the application blacklist. + val numFailures = math.max(conf.get(config.MAX_FAILURES_PER_EXEC), + conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)) + (0 until numFailures).foreach { index => + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = index) + } + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + if (succeedTaskSet) { + // The task set succeeded elsewhere, so we should count those failures against our executor, + // and it should be blacklisted for the entire application. + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + } else { + // The task set failed, so we don't count these failures against the executor for other + // stages. + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + } + } + + test("blacklisted executors and nodes get recovered with time") { + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole + // application. 
+ (0 until 4).foreach { partition => + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole + // application. Since that's the second executor that is blacklisted on the same node, we also + // blacklist that node. + (0 until 4).foreach { partition => + taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) + assert(blacklist.nodeBlacklist() === Set("hostA")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) + + // Advance the clock and then make sure hostA and executors 1 and 2 have been removed from the + // blacklist. + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + blacklist.applyBlacklistTimeout() + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + + // Fail one more task, but executor isn't put back into blacklist since the count of failures + // on that executor should have been reset to 0. 
+ val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + blacklist.updateBlacklistForSuccessfulTaskSet(2, 0, taskSetBlacklist2.execToFailures) + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + + test("blacklist can handle lost executors") { + // The blacklist should still work if an executor is killed completely. We should still + // be able to blacklist the entire node. + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + // Lets say that executor 1 dies completely. We get some task failures, but + // the taskset then finishes successfully (elsewhere). + (0 until 4).foreach { partition => + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + } + blacklist.handleRemovedExecutor("1") + blacklist.updateBlacklistForSuccessfulTaskSet( + stageId = 0, + stageAttemptId = 0, + taskSetBlacklist0.execToFailures) + assert(blacklist.isExecutorBlacklisted("1")) + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS / 2) + + // Now another executor gets spun up on that host, but it also dies. + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + (0 until 4).foreach { partition => + taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + } + blacklist.handleRemovedExecutor("2") + blacklist.updateBlacklistForSuccessfulTaskSet( + stageId = 1, + stageAttemptId = 0, + taskSetBlacklist1.execToFailures) + // We've now had two bad executors on the hostA, so we should blacklist the entire node. + assert(blacklist.isExecutorBlacklisted("1")) + assert(blacklist.isExecutorBlacklisted("2")) + assert(blacklist.isNodeBlacklisted("hostA")) + + // Advance the clock so that executor 1 should no longer be explicitly blacklisted, but + // everything else should still be blacklisted. 
+ clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS / 2 + 1) + blacklist.applyBlacklistTimeout() + assert(!blacklist.isExecutorBlacklisted("1")) + assert(blacklist.isExecutorBlacklisted("2")) + assert(blacklist.isNodeBlacklisted("hostA")) + // make sure we don't leak memory + assert(!blacklist.executorIdToBlacklistStatus.contains("1")) + assert(!blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Advance the timeout again so now hostA should be removed from the blacklist. + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS / 2) + blacklist.applyBlacklistTimeout() + assert(!blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) + } + + test("task failures expire with time") { + // Verifies that 2 failures within the timeout period cause an executor to be blacklisted, but + // if task failures are spaced out by more than the timeout period, the first failure is timed + // out, and the executor isn't blacklisted. + var stageId = 0 + def failOneTaskInTaskSet(exec: String): Unit = { + val taskSetBlacklist = createTaskSetBlacklist(stageId = stageId) + taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0) + blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) + stageId += 1 + } + failOneTaskInTaskSet(exec = "1") + // We have one sporadic failure on exec 2, but that's it. Later checks ensure that we never + // blacklist executor 2 despite this one failure. + failOneTaskInTaskSet(exec = "2") + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + assert(blacklist.nextExpiryTime === Long.MaxValue) + + // We advance the clock past the expiry time. + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + val t0 = clock.getTimeMillis() + blacklist.applyBlacklistTimeout() + assert(blacklist.nextExpiryTime === Long.MaxValue) + failOneTaskInTaskSet(exec = "1") + + // Because the 2nd failure on executor 1 happened past the expiry time, nothing should have been + // blacklisted. 
+ assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + + // Now we add one more failure, within the timeout, and it should be counted. + clock.setTime(t0 + blacklist.BLACKLIST_TIMEOUT_MILLIS - 1) + val t1 = clock.getTimeMillis() + failOneTaskInTaskSet(exec = "1") + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + assert(blacklist.nextExpiryTime === t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + // Add failures on executor 3, make sure it gets put on the blacklist. + clock.setTime(t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS - 1) + val t2 = clock.getTimeMillis() + failOneTaskInTaskSet(exec = "3") + failOneTaskInTaskSet(exec = "3") + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "3")) + assert(blacklist.nextExpiryTime === t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + // Now we go past the timeout for executor 1, so it should be dropped from the blacklist. + clock.setTime(t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("3")) + assert(blacklist.nextExpiryTime === t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + // Make sure that we update correctly when we go from having blacklisted executors to + // just having tasks with timeouts. + clock.setTime(t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS - 1) + failOneTaskInTaskSet(exec = "4") + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("3")) + assert(blacklist.nextExpiryTime === t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + clock.setTime(t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + // we've got one task failure still, but we don't bother setting nextExpiryTime to it, to + // avoid wasting time checking for expiry of individual task failures. 
+ assert(blacklist.nextExpiryTime === Long.MaxValue) + } + + test("task failure timeout works as expected for long-running tasksets") { + // This ensures that we don't trigger spurious blacklisting for long tasksets, when the taskset + // finishes long after the task failures. We create two tasksets, each with one failure. + // Individually they shouldn't cause any blacklisting since there is only one failure. + // Furthermore, we space the failures out so far that even when both tasksets have completed, + // we still don't trigger any blacklisting. + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) + // Taskset1 has one failure immediately + taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0) + // Then we have a *long* delay, much longer than the timeout, before any other failures or + // taskset completion + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS * 5) + // After the long delay, we have one failure on taskset 2, on the same executor + taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0) + // Finally, we complete both tasksets. Its important here to complete taskset2 *first*. We + // want to make sure that when taskset 1 finishes, even though we've now got two task failures, + // we realize that the task failure we just added was well before the timeout. 
+ clock.advance(1) + blacklist.updateBlacklistForSuccessfulTaskSet(stageId = 2, 0, taskSetBlacklist2.execToFailures) + clock.advance(1) + blacklist.updateBlacklistForSuccessfulTaskSet(stageId = 1, 0, taskSetBlacklist1.execToFailures) + + // Make sure nothing was blacklisted + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + + test("only blacklist nodes for the application when enough executors have failed on that " + + "specific host") { + // we blacklist executors on two different hosts -- make sure that doesn't lead to any + // node blacklisting + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) + taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + blacklist.updateBlacklistForSuccessfulTaskSet(1, 0, taskSetBlacklist1.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + + // Finally, blacklist another executor on the same node as the original blacklisted executor, + // and make sure this time we *do* blacklist the node. 
+ val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 0) + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 1) + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) + } test("blacklist still respects legacy configs") { val conf = new SparkConf().setMaster("local") @@ -68,6 +414,8 @@ class BlacklistTrackerSuite extends SparkFunSuite { config.MAX_TASK_ATTEMPTS_PER_NODE, config.MAX_FAILURES_PER_EXEC_STAGE, config.MAX_FAILED_EXEC_PER_NODE_STAGE, + config.MAX_FAILURES_PER_EXEC, + config.MAX_FAILED_EXEC_PER_NODE, config.BLACKLIST_TIMEOUT_CONF ).foreach { config => conf.set(config.key, "0") diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 230e2c34d0..4c3d0b1021 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -119,19 +119,20 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } test("Event log name") { + val baseDirUri = Utils.resolveURI("/base-dir") // without compression - assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( - Utils.resolveURI("/base-dir"), "app1", None)) + assert(s"${baseDirUri.toString}/app1" === EventLoggingListener.getLogPath( + baseDirUri, "app1", None)) // with compression - assert(s"file:/base-dir/app1.lzf" === - EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", None, Some("lzf"))) + assert(s"${baseDirUri.toString}/app1.lzf" === + EventLoggingListener.getLogPath(baseDirUri, "app1", None, Some("lzf"))) // illegal characters in app ID 
- assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === - EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + assert(s"${baseDirUri.toString}/a-fine-mind_dollar_bills__1" === + EventLoggingListener.getLogPath(baseDirUri, "a fine:mind$dollar{bills}.1", None)) // illegal characters in app ID with compression - assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === - EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + assert(s"${baseDirUri.toString}/a-fine-mind_dollar_bills__1.lz4" === + EventLoggingListener.getLogPath(baseDirUri, "a fine:mind$dollar{bills}.1", None, Some("lz4"))) } @@ -289,7 +290,7 @@ object EventLoggingListenerSuite { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") - conf.set("spark.eventLog.dir", logDir.toString) + conf.set("spark.eventLog.dir", logDir.toUri.toString) compressionCodec.foreach { codec => conf.set("spark.eventLog.compress", "true") conf.set("spark.io.compression.codec", codec) diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index a757041299..fe6de2bd98 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -17,12 +17,20 @@ package org.apache.spark.scheduler +import java.util.Properties + +import org.apache.spark.SparkEnv import org.apache.spark.TaskContext +import org.apache.spark.executor.TaskMetrics class FakeTask( stageId: Int, partitionId: Int, - prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) { + prefLocs: Seq[TaskLocation] = Nil, + serializedTaskMetrics: Array[Byte] = + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) { + override def runTask(context: TaskContext): Int = 0 override def 
preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 83288db92b..8c4e389e86 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -158,10 +158,9 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { 0 until rdd.partitions.size, resultHandler, () => Unit) // It's an error if the job completes successfully even though no committer was authorized, // so throw an exception if the job was allowed to complete. - val e = intercept[SparkException] { + intercept[TimeoutException] { ThreadUtils.awaitResult(futureAction, 5 seconds) } - assert(e.getCause.isInstanceOf[TimeoutException]) assert(tempDir.list().size === 0) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index c28aa06623..2ba63da881 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -28,6 +28,8 @@ import scala.reflect.ClassTag import org.scalactic.TripleEquals import org.scalatest.Assertions.AssertionsHelper +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.TaskState._ @@ -157,8 +159,16 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } // When a job fails, we terminate before waiting for all the task end events to come in, // so there might still be a running task set. So we only check these conditions - // when the job succeeds - assert(taskScheduler.runningTaskSets.isEmpty) + // when the job succeeds. 
+ // When the final task of a taskset completes, we post + // the event to the DAGScheduler event loop before we finish processing in the taskscheduler + // thread. It's possible the DAGScheduler thread processes the event, finishes the job, + // and notifies the job waiter before our original thread in the task scheduler finishes + // handling the event and marks the taskset as complete. So its ok if we need to wait a + // *little* bit longer for the original taskscheduler thread to finish up to deal w/ the race. + eventually(timeout(1 second), interval(10 millis)) { + assert(taskScheduler.runningTaskSets.isEmpty) + } assert(!backend.hasTasks) } else { assert(failure != null) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9eda79ace1..7004128308 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -62,7 +62,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, + closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { task.run(0, 0, null) } @@ -83,7 +84,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) + 0, 0, 
taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, + closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { task.run(0, 0, null) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index ee95e4ff7d..c9e682f53c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -171,7 +171,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local val tempDir = Utils.createTempDir() val srcDir = new File(tempDir, "repro/") srcDir.mkdirs() - val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").toURI.getPath, """package repro; | |public class MyException extends Exception { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index a0b6268331..304dc9d47e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -21,14 +21,15 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.mockito.Matchers.{anyInt, anyString, eq => meq} -import org.mockito.Mockito.{atLeast, atMost, never, spy, verify, when} +import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq} +import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when} import org.scalatest.BeforeAndAfterEach import org.scalatest.mock.MockitoSugar import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.storage.BlockManagerId class FakeSchedulerBackend extends 
SchedulerBackend { def start() {} @@ -44,6 +45,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B var failedTaskSetReason: String = null var failedTaskSet = false + var blacklist: BlacklistTracker = null var taskScheduler: TaskSchedulerImpl = null var dagScheduler: DAGScheduler = null @@ -82,11 +84,12 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } def setupSchedulerWithMockTaskSetBlacklist(): TaskSchedulerImpl = { + blacklist = mock[BlacklistTracker] val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") conf.set(config.BLACKLIST_ENABLED, true) sc = new SparkContext(conf) taskScheduler = - new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4)) { + new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4), Some(blacklist)) { override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = { val tsm = super.createTaskSetManager(taskSet, maxFailures) // we need to create a spied tsm just so we can set the TaskSetBlacklist @@ -408,6 +411,95 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } assert(tsm.isZombie) } + + // the tasksSets complete, so the tracker should be notified of the successful ones + verify(blacklist, times(1)).updateBlacklistForSuccessfulTaskSet( + stageId = 0, + stageAttemptId = 0, + failuresByExec = stageToMockTaskSetBlacklist(0).execToFailures) + verify(blacklist, times(1)).updateBlacklistForSuccessfulTaskSet( + stageId = 1, + stageAttemptId = 0, + failuresByExec = stageToMockTaskSetBlacklist(1).execToFailures) + // but we shouldn't update for the failed taskset + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet( + stageId = meq(2), + stageAttemptId = anyInt(), + failuresByExec = anyObject()) + } + + test("scheduled tasks obey node and executor blacklists") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + (0 to 2).foreach { stageId => + 
val taskSet = FakeTask.createTaskSet(numTasks = 2, stageId = stageId, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + } + + val offers = IndexedSeq( + new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1), + new WorkerOffer("executor2", "host1", 1), + new WorkerOffer("executor3", "host2", 10), + new WorkerOffer("executor4", "host3", 1) + ) + + // setup our mock blacklist: + // host1, executor0 & executor3 are completely blacklisted + // This covers everything *except* one core on executor4 / host3, so that everything is still + // schedulable. + when(blacklist.isNodeBlacklisted("host1")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor0")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor3")).thenReturn(true) + + val stageToTsm = (0 to 2).map { stageId => + val tsm = taskScheduler.taskSetManagerForAttempt(stageId, 0).get + stageId -> tsm + }.toMap + + val firstTaskAttempts = taskScheduler.resourceOffers(offers).flatten + firstTaskAttempts.foreach { task => logInfo(s"scheduled $task on ${task.executorId}") } + assert(firstTaskAttempts.size === 1) + assert(firstTaskAttempts.head.executorId === "executor4") + ('0' until '2').foreach { hostNum => + verify(blacklist, atLeast(1)).isNodeBlacklisted("host" + hostNum) + } + } + + test("abort stage when all executors are blacklisted") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val taskSet = FakeTask.createTaskSet(numTasks = 10, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + val tsm = stageToMockTaskSetManager(0) + + // first just submit some offers so the scheduler knows about all the executors + taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 2), + WorkerOffer("executor1", "host0", 2), + WorkerOffer("executor2", "host0", 2), + WorkerOffer("executor3", "host1", 2) + )) + + // now say our blacklist updates to blacklist a bunch of resources, but *not* everything + 
when(blacklist.isNodeBlacklisted("host1")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor0")).thenReturn(true) + + // make an offer on the blacklisted resources. We won't schedule anything, but also won't + // abort yet, since we know of other resources that work + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 2), + WorkerOffer("executor3", "host1", 2) + )).flatten.size === 0) + assert(!tsm.isZombie) + + // now update the blacklist so that everything really is blacklisted + when(blacklist.isExecutorBlacklisted("executor1")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor2")).thenReturn(true) + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 2), + WorkerOffer("executor3", "host1", 2) + )).flatten.size === 0) + assert(tsm.isZombie) + verify(tsm).abort(anyString(), anyObject()) } /** @@ -650,6 +742,17 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } + test("scheduler checks for executors that can be expired from blacklist") { + taskScheduler = setupScheduler() + + taskScheduler.submitTasks(FakeTask.createTaskSet(1, 0)) + taskScheduler.resourceOffers(IndexedSeq( + new WorkerOffer("executor0", "host0", 1) + )).flatten + + verify(blacklist).applyBlacklistTimeout() + } + test("if an executor is lost then the state for its running tasks is cleaned up (SPARK-18553)") { sc = new SparkContext("local", "TaskSchedulerImplSuite") val taskScheduler = new TaskSchedulerImpl(sc) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 8c902af568..6b52c10b2c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -85,9 
+85,9 @@ class TaskSetBlacklistSuite extends SparkFunSuite { Seq("exec1", "exec2").foreach { exec => assert( - execToFailures(exec).taskToFailureCount === Map( - 0 -> 1, - 1 -> 1 + execToFailures(exec).taskToFailureCountAndFailureTime === Map( + 0 -> (1, 0), + 1 -> (1, 0) ) ) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index abc8fff30e..2f5b029a96 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -183,7 +183,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdates = taskSet.tasks.head.metrics.internalAccums // Offer a host with NO_PREF as the constraint, @@ -236,7 +236,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execC", "host2")) val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "execB"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // An executor that is not NODE_LOCAL should be rejected. 
assert(manager.resourceOffer("execC", "host2", ANY) === None) @@ -257,7 +257,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq() // Last task has no locality prefs ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) @@ -286,7 +286,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq() // Last task has no locality prefs ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) assert(manager.resourceOffer("exec3", "host2", PROCESS_LOCAL).get.index === 1) @@ -306,7 +306,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host2")) ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -344,7 +344,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host3")) ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", 
ANY).get.index === 0) @@ -376,7 +376,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -393,7 +393,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. @@ -426,7 +426,10 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // affinity to exec1 on host1 - which we will fail. val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, 4, clock) + // We don't directly use the application blacklist, but its presence triggers blacklisting + // within the taskset. 
+ val blacklistTrackerOpt = Some(new BlacklistTracker(conf, clock)) + val manager = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock) { val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) @@ -515,7 +518,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host2", "execC")), Seq()) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) // Add a new executor @@ -546,7 +549,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execB")), Seq(TaskLocation("host2", "execC")), Seq()) - val manager = new TaskSetManager(sched, taskSet, 1, new ManualClock) + val manager = new TaskSetManager(sched, taskSet, 1, clock = new ManualClock) sched.addExecutor("execA", "host1") manager.executorAdded() sched.addExecutor("execC", "host2") @@ -579,7 +582,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host1", "execA"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) // Set allowed locality to ANY @@ -670,7 +673,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(), Seq(TaskLocation("host3", "execC"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index 
=== 0) assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) @@ -698,7 +701,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(), Seq(TaskLocation("host3"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // node-local tasks are scheduled without delay assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) @@ -720,7 +723,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(ExecutorCacheTaskLocation("host1", "execA")), Seq(ExecutorCacheTaskLocation("host2", "execB"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // process-local tasks are scheduled first assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 2) @@ -740,7 +743,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(ExecutorCacheTaskLocation("host1", "execA")), Seq(ExecutorCacheTaskLocation("host2", "execB"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // process-local tasks are scheduled first assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 1) @@ -760,7 +763,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2", "execB.1"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(ANY))) // 
Add a new executor @@ -794,7 +797,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host2")), Seq(TaskLocation("hdfs_cache_host3"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) sched.removeExecutor("execA") manager.executorAdded() @@ -822,7 +825,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") val clock = new ManualClock() - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => task.metrics.internalAccums } @@ -876,7 +879,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc.conf.set("spark.speculation.multiplier", "0.0") sc.conf.set("spark.speculation.quantile", "0.6") val clock = new ManualClock() - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => task.metrics.internalAccums } @@ -980,17 +983,17 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, new ManualClock) + val manager = new TaskSetManager(sched, taskSet, 
MAX_TASK_FAILURES, clock = new ManualClock) assert(manager.name === "TaskSet_0.0") // Make sure a task set with the same stage ID but different attempt ID has a unique name val taskSet2 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 1) - val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, new ManualClock) + val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, clock = new ManualClock) assert(manager2.name === "TaskSet_0.1") // Make sure a task set with the same attempt ID but different stage ID also has a unique name val taskSet3 = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 1) - val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, new ManualClock) + val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, clock = new ManualClock) assert(manager3.name === "TaskSet_1.1") } diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index a61ec74c7d..0f3a4a0361 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -24,6 +24,7 @@ import com.google.common.io.ByteStreams import org.apache.spark._ import org.apache.spark.internal.config._ +import org.apache.spark.network.util.CryptoUtils import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.storage.TempShuffleBlockId @@ -33,11 +34,11 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { test("crypto configuration conversion") { val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c" val sparkVal1 = "val1" - val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c" + val cryptoKey1 = s"${CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX}a.b.c" val sparkKey2 = 
SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c" val sparkVal2 = "val2" - val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c" + val cryptoKey2 = s"${CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX}A.b.c" val conf = new SparkConf() conf.set(sparkKey1, sparkVal1) conf.set(sparkKey2, sparkVal2) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e3ec99685f..e56e440380 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.InputStream +import java.io.{File, InputStream, IOException} import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException @@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + val in = mock(classOf[InputStream]) + when(in.read(any())).thenReturn(1) + when(in.read(any(), any(), any())).thenReturn(1) + when(mockManagedBuffer.createInputStream()).thenReturn(in) mockManagedBuffer } @@ -99,8 +103,10 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + true) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -172,8 +178,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -201,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() ) // Semaphore to coordinate event sequence in two different threads. 
@@ -235,8 +243,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + true) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -247,4 +257,148 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } intercept[FetchFailedException] { iterator.next() } } + + test("retry corrupt blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) + sem.release() + } + } + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 100), + 48 * 1024 * 1024, + Int.MaxValue, + true) + + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be returned without an exception + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 0)) + + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + sem.release() + } + } + }) + + // The next block is corrupt local block (the second one is corrupt and retried) + intercept[FetchFailedException] { iterator.next() } + + sem.acquire() + intercept[FetchFailedException] { iterator.next() } + } + + test("retry corrupt blocks (disabled)") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + sem.release() + } + } + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 100), + 48 * 1024 * 1024, + Int.MaxValue, + false) + + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be returned without an exception + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 0)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 1, 0)) + val (id3, _) = iterator.next() + assert(id3 === ShuffleBlockId(0, 2, 0)) + } + } diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index 667a4db6f7..55c5dd5e24 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite { assert(sample3.length === 10) } + test("SPARK-18678 reservoirSampleAndCount with tiny input") { + val input = Seq(0, 1) + val counts = new Array[Int](input.size) + for (i <- 0 until 500) { + val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1) + assert(inputSize === 2) + assert(samples.length === 1) + counts(samples.head) += 1 + } + // If correct, should be true with prob ~ 0.99999707 + assert(math.abs(counts(0) - counts(1)) <= 100) + } + test("computeFraction") { // 
test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001 diff --git a/dev/.rat-excludes b/dev/.rat-excludes index a3efddeaa5..6be1c72bc6 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -102,3 +102,4 @@ org.apache.spark.scheduler.ExternalClusterManager .Rbuildignore org.apache.spark.deploy.yarn.security.ServiceCredentialProvider spark-warehouse +structured-streaming/* diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index aa42750f26..b08577c47c 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -150,7 +150,7 @@ if [[ "$1" == "package" ]]; then NAME=$1 FLAGS=$2 ZINC_PORT=$3 - BUILD_PIP_PACKAGE=$4 + BUILD_PACKAGE=$4 cp -r spark spark-$SPARK_VERSION-bin-$NAME cd spark-$SPARK_VERSION-bin-$NAME @@ -172,11 +172,30 @@ if [[ "$1" == "package" ]]; then MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` - if [ -z "$BUILD_PIP_PACKAGE" ]; then - echo "Creating distribution without PIP package" + if [ -z "$BUILD_PACKAGE" ]; then + echo "Creating distribution without PIP/R package" ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. + elif [[ "$BUILD_PACKAGE" == "withr" ]]; then + echo "Creating distribution with R package" + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --r $FLAGS \ + -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log + cd .. + + echo "Copying and signing R source package" + R_DIST_NAME=SparkR_$SPARK_VERSION.tar.gz + cp spark-$SPARK_VERSION-bin-$NAME/R/$R_DIST_NAME . 
+ + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output $R_DIST_NAME.asc \ + --detach-sig $R_DIST_NAME + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 $R_DIST_NAME > \ + $R_DIST_NAME.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 $R_DIST_NAME > \ + $R_DIST_NAME.sha else echo "Creating distribution with PIP package" ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ @@ -219,7 +238,7 @@ if [[ "$1" == "package" ]]; then FLAGS="-Psparkr -Phive -Phive-thriftserver -Pyarn -Pmesos" make_binary_release "hadoop2.3" "-Phadoop-2.3 $FLAGS" "3033" & make_binary_release "hadoop2.4" "-Phadoop-2.4 $FLAGS" "3034" & - make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" & + make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" "withr" & make_binary_release "hadoop2.7" "-Phadoop-2.7 $FLAGS" "3036" "withpip" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn -Pmesos" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn -Pmesos" "3038" & @@ -232,6 +251,8 @@ if [[ "$1" == "package" ]]; then # Put to new directory: LFTP mkdir -p $dest_dir LFTP mput -O $dest_dir 'spark-*' + LFTP mput -O $dest_dir 'pyspark-*' + LFTP mput -O $dest_dir 'SparkR_*' # Delete /latest directory and rename new upload to /latest LFTP "rm -r -f $REMOTE_PARENT_DIR/latest || exit 0" LFTP mv $dest_dir "$REMOTE_PARENT_DIR/latest" @@ -239,6 +260,7 @@ if [[ "$1" == "package" ]]; then LFTP mkdir -p $dest_dir LFTP mput -O $dest_dir 'spark-*' LFTP mput -O $dest_dir 'pyspark-*' + LFTP mput -O $dest_dir 'SparkR_*' exit 0 fi diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 89bfcef4d9..9cbab3d895 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -122,13 +122,13 @@ metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar -netty-3.8.0.Final.jar +netty-3.9.9.Final.jar 
netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.3.jar +paranamer-2.6.jar parquet-column-1.8.1.jar parquet-common-1.8.1.jar parquet-encoding-1.8.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 8df3858825..63ce6c66fd 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -129,13 +129,13 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar +netty-3.9.9.Final.jar netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.3.jar +paranamer-2.6.jar parquet-column-1.8.1.jar parquet-common-1.8.1.jar parquet-encoding-1.8.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 71e7fb6dd2..122d5c27d0 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -129,13 +129,13 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar +netty-3.9.9.Final.jar netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.3.jar +paranamer-2.6.jar parquet-column-1.8.1.jar parquet-common-1.8.1.jar parquet-encoding-1.8.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index ba31391495..776aabd111 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -137,13 +137,13 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar +netty-3.9.9.Final.jar netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.3.jar +paranamer-2.6.jar parquet-column-1.8.1.jar parquet-common-1.8.1.jar parquet-encoding-1.8.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index b129e5a99e..524e824073 100644 
--- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -138,13 +138,13 @@ metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar -netty-3.8.0.Final.jar +netty-3.9.9.Final.jar netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar osgi-resource-locator-1.0.1.jar -paranamer-2.3.jar +paranamer-2.6.jar parquet-column-1.8.1.jar parquet-common-1.8.1.jar parquet-encoding-1.8.1.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 49b46fbc3f..6c5ae0d629 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -34,6 +34,7 @@ DISTDIR="$SPARK_HOME/dist" MAKE_TGZ=false MAKE_PIP=false +MAKE_R=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -41,7 +42,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for making binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--pip] [--mvn ]" + cl_options="[--name] [--tgz] [--pip] [--r] [--mvn ]" echo "make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -71,6 +72,9 @@ while (( "$#" )); do --pip) MAKE_PIP=true ;; + --r) + MAKE_R=true + ;; --mvn) MVN="$2" shift @@ -98,6 +102,13 @@ if [ -z "$JAVA_HOME" ]; then echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm" fi fi + + if [ -z "$JAVA_HOME" ]; then + if [ `command -v java` ]; then + # If java is in /usr/bin/java, we want /usr + JAVA_HOME="$(dirname $(dirname $(which java)))" + fi + fi fi if [ -z "$JAVA_HOME" ]; then @@ -208,11 +219,30 @@ cp -r "$SPARK_HOME/data" "$DISTDIR" # Make pip package if [ "$MAKE_PIP" == "true" ]; then echo "Building python distribution package" - cd $SPARK_HOME/python + pushd "$SPARK_HOME/python" > /dev/null python setup.py sdist - cd .. 
+ popd > /dev/null +else + echo "Skipping building python distribution package" +fi + +# Make R package - this is used for both CRAN release and packing R layout into distribution +if [ "$MAKE_R" == "true" ]; then + echo "Building R source package" + R_PACKAGE_VERSION=`grep Version $SPARK_HOME/R/pkg/DESCRIPTION | awk '{print $NF}'` + pushd "$SPARK_HOME/R" > /dev/null + # Build source package and run full checks + # Install source package to get it to generate vignettes, etc. + # Do not source the check-cran.sh - it should be run from where it is for it to set SPARK_HOME + NO_TESTS=1 CLEAN_INSTALL=1 "$SPARK_HOME/"R/check-cran.sh + # Move R source package to match the Spark release version if the versions are not the same. + # NOTE(shivaram): `mv` throws an error on Linux if source and destination are same file + if [ "$R_PACKAGE_VERSION" != "$VERSION" ]; then + mv $SPARK_HOME/R/SparkR_"$R_PACKAGE_VERSION".tar.gz $SPARK_HOME/R/SparkR_"$VERSION".tar.gz + fi + popd > /dev/null else - echo "Skipping creating pip installable PySpark" + echo "Skipping building R source package" fi # Copy other things @@ -221,6 +251,12 @@ cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" + +# Remove the python distribution from dist/ if we built it +if [ "$MAKE_PIP" == "true" ]; then + rm -f $DISTDIR/python/dist/pyspark-*.tar.gz +fi + cp -r "$SPARK_HOME/sbin" "$DISTDIR" # Copy SparkR if it exists if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 1d1e72facc..bb286af763 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -80,7 +80,7 @@ def pr_message(build_display_name, short_commit_hash, commit_url, str(' ' + post_msg + '.') if post_msg else '.') - return '**[Test build %s %s](%sconsoleFull)** for PR %s at commit [`%s`](%s)%s' % str_args + return '**[Test build %s %s](%stestReport)** for PR %s at 
commit [`%s`](%s)%s' % str_args def run_pr_checks(pr_tests, ghprb_actual_commit, sha1): diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b34ab51f3b..10ad1fe3aa 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -245,7 +245,8 @@ def __hash__(self): name="streaming-kafka-0-10", dependencies=[streaming], source_file_regexes=[ - "external/kafka-0-10", + # The ending "/" is necessary otherwise it will include "sql-kafka" codes + "external/kafka-0-10/", "external/kafka-0-10-assembly", ], sbt_test_goals=[ @@ -469,7 +470,7 @@ def __hash__(self): name="yarn", dependencies=[], source_file_regexes=[ - "yarn/", + "resource-managers/yarn/", "common/network-yarn/", ], build_profile_flags=["-Pyarn"], @@ -485,7 +486,7 @@ def __hash__(self): mesos = Module( name="mesos", dependencies=[], - source_file_regexes=["mesos/"], + source_file_regexes=["resource-managers/mesos/"], build_profile_flags=["-Pmesos"], sbt_test_goals=["mesos/test"] ) diff --git a/docs/README.md b/docs/README.md index ffd3b5712b..90e10a104b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -69,4 +69,5 @@ may take some time as it generates all of the scaladoc. The jekyll plugin also PySpark docs using [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala, Python, R API docs, run `SKIP_API=1 -jekyll`. +jekyll`. In addition, `SKIP_SCALADOC=1`, `SKIP_PYTHONDOC=1`, and `SKIP_RDOC=1` can be used to skip a single +step of the corresponding language. diff --git a/docs/_config.yml b/docs/_config.yml index e4fc093fe7..83bb30598d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 2.1.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.1.0 +SPARK_VERSION: 2.2.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.2.0 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 1.0.0 diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index f926d67e6b..95e3ba35e9 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -113,33 +113,41 @@ File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } end - # Build Sphinx docs for Python + if not (ENV['SKIP_PYTHONDOC'] == '1') + # Build Sphinx docs for Python - puts "Moving to python/docs directory and building sphinx." - cd("../python/docs") - system("make html") || raise("Python doc generation failed") + puts "Moving to python/docs directory and building sphinx." + cd("../python/docs") + system("make html") || raise("Python doc generation failed") - puts "Moving back into home dir." - cd("../../") + puts "Moving back into docs dir." + cd("../../docs") + + puts "Making directory api/python" + mkdir_p "api/python" + + puts "cp -r ../python/docs/_build/html/. api/python" + cp_r("../python/docs/_build/html/.", "api/python") + end - puts "Making directory api/python" - mkdir_p "docs/api/python" + if not (ENV['SKIP_RDOC'] == '1') + # Build SparkR API docs - puts "cp -r python/docs/_build/html/. docs/api/python" - cp_r("python/docs/_build/html/.", "docs/api/python") + puts "Moving to R directory and building roxygen docs." + cd("../R") + system("./create-docs.sh") || raise("R doc generation failed") - # Build SparkR API docs - puts "Moving to R directory and building roxygen docs." - cd("R") - system("./create-docs.sh") || raise("R doc generation failed") + puts "Moving back into docs dir." + cd("../docs") - puts "Moving back into home dir." - cd("../") + puts "Making directory api/R" + mkdir_p "api/R" - puts "Making directory api/R" - mkdir_p "docs/api/R" + puts "cp -r ../R/pkg/html/. api/R" + cp_r("../R/pkg/html/.", "api/R") - puts "cp -r R/pkg/html/. 
docs/api/R" - cp_r("R/pkg/html/.", "docs/api/R") + puts "cp ../R/pkg/DESCRIPTION api" + cp("../R/pkg/DESCRIPTION", "api") + end end diff --git a/docs/configuration.md b/docs/configuration.md index d8800e93da..39bfb3a05b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -699,6 +699,15 @@ Apart from these, the following properties are also available, and may be useful This is the URL where your proxy is running. This URL is for proxy which is running in front of Spark Master. This is useful when running proxy for authentication e.g. OAuth proxy. Make sure this is a complete URL including scheme (http/https) and port to reach your proxy. + + spark.ui.showConsoleProgress + true + + Show the progress bar in the console. The progress bar shows the progress of stages + that run for longer than 500ms. If multiple stages run at the same time, multiple + progress bars will be displayed on the same line. + + spark.worker.ui.retainedExecutors 1000 @@ -1306,6 +1315,14 @@ Apart from these, the following properties are also available, and may be useful other "spark.blacklist" configuration options. + + spark.blacklist.timeout + 1h + + (Experimental) How long a node or executor is blacklisted for the entire application, before it + is unconditionally removed from the blacklist to attempt running new tasks. + + spark.blacklist.task.maxTaskAttemptsPerExecutor 1 @@ -1323,7 +1340,7 @@ Apart from these, the following properties are also available, and may be useful - spark.blacklist.stage.maxFailedTasksPerExecutor + spark.blacklist.stage.maxFailedTasksPerExecutor 2 (Experimental) How many different tasks must fail on one executor, within one stage, before the @@ -1338,6 +1355,28 @@ Apart from these, the following properties are also available, and may be useful the entire node is marked as failed for the stage. 
+ + spark.blacklist.application.maxFailedTasksPerExecutor + 2 + + (Experimental) How many different tasks must fail on one executor, in successful task sets, + before the executor is blacklisted for the entire application. Blacklisted executors will + be automatically added back to the pool of available resources after the timeout specified by + spark.blacklist.timeout. Note that with dynamic allocation, though, the executors + may get marked as idle and be reclaimed by the cluster manager. + + + + spark.blacklist.application.maxFailedExecutorsPerNode + 2 + + (Experimental) How many different executors must be blacklisted for the entire application, + before the node is blacklisted for the entire application. Blacklisted nodes will + be automatically added back to the pool of available resources after the timeout specified by + spark.blacklist.timeout. Note that with dynamic allocation, though, the executors + on the node may get marked as idle and be reclaimed by the cluster manager. + + spark.speculation false @@ -1384,6 +1423,48 @@ Apart from these, the following properties are also available, and may be useful Should be greater than or equal to 1. Number of allowed retries = this value - 1. + + spark.task.reaper.enabled + false + + Enables monitoring of killed / interrupted tasks. When set to true, any task which is killed + will be monitored by the executor until that task actually finishes executing. See the other + spark.task.reaper.* configurations for details on how to control the exact behavior + of this monitoring. When set to false (the default), task killing will use an older code + path which lacks such monitoring. + + + + spark.task.reaper.pollingInterval + 10s + + When spark.task.reaper.enabled = true, this setting controls the frequency at which + executors will poll the status of killed tasks. 
If a killed task is still running when polled + then a warning will be logged and, by default, a thread-dump of the task will be logged + (this thread dump can be disabled via the spark.task.reaper.threadDump setting, + which is documented below). + + + + spark.task.reaper.threadDump + true + + When spark.task.reaper.enabled = true, this setting controls whether task thread + dumps are logged during periodic polling of killed tasks. Set this to false to disable + collection of thread dumps. + + + + spark.task.reaper.killTimeout + -1 + + When spark.task.reaper.enabled = true, this setting specifies a timeout after + which the executor JVM will kill itself if a killed task has not stopped running. The default + value, -1, disables this mechanism and prevents the executor from self-destructing. The purpose + of this setting is to act as a safety-net to prevent runaway uncancellable tasks from rendering + an executor unusable. + + #### Dynamic Allocation @@ -1549,14 +1630,15 @@ Apart from these, the following properties are also available, and may be useful - spark.authenticate.encryption.aes.enabled + spark.network.aes.enabled false - Enable AES for over-the-wire encryption + Enable AES for over-the-wire encryption. This is supported for RPC and the block transfer service. + This option has precedence over SASL-based encryption if both are enabled. - spark.authenticate.encryption.aes.cipher.keySize + spark.network.aes.keySize 16 The bytes of AES cipher key which is effective when AES cipher is enabled. AES @@ -1564,14 +1646,12 @@ Apart from these, the following properties are also available, and may be useful - spark.authenticate.encryption.aes.cipher.class - null + spark.network.aes.config.* + None - Specify the underlying implementation class of crypto cipher. Set null here to use default. - In order to use OpenSslCipher users should install openssl. 
Currently, there are two cipher - classes available in Commons Crypto library: - org.apache.commons.crypto.cipher.OpenSslCipher - org.apache.commons.crypto.cipher.JceCipher + Configuration values for the commons-crypto library, such as which cipher implementations to + use. The config name should be the name of commons-crypto configuration without the + "commons.crypto" prefix. @@ -1649,7 +1729,7 @@ Apart from these, the following properties are also available, and may be useful -#### Encryption +#### TLS / SSL diff --git a/docs/index.md b/docs/index.md index c5d34cb5c4..57b9fa848f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -131,7 +131,6 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) -* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) * [Spark Community](http://spark.apache.org/community.html) resources, including local meetups * [StackOverflow tag `apache-spark`](http://stackoverflow.com/questions/tagged/apache-spark) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 12a03d3c91..2747f2df7c 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -59,17 +59,25 @@ Given $n$ weighted observations $(w_i, a_i, b_i)$: The number of features for each observation is $m$. 
We use the following weighted least squares formulation: `\[ -minimize_{x}\frac{1}{2} \sum_{i=1}^n \frac{w_i(a_i^T x -b_i)^2}{\sum_{k=1}^n w_k} + \frac{1}{2}\frac{\lambda}{\delta}\sum_{j=1}^m(\sigma_{j} x_{j})^2 +\min_{\mathbf{x}}\frac{1}{2} \sum_{i=1}^n \frac{w_i(\mathbf{a}_i^T \mathbf{x} -b_i)^2}{\sum_{k=1}^n w_k} + \frac{\lambda}{\delta}\left[\frac{1}{2}(1 - \alpha)\sum_{j=1}^m(\sigma_j x_j)^2 + \alpha\sum_{j=1}^m |\sigma_j x_j|\right] \]` -where $\lambda$ is the regularization parameter, $\delta$ is the population standard deviation of the label +where $\lambda$ is the regularization parameter, $\alpha$ is the elastic-net mixing parameter, $\delta$ is the population standard deviation of the label and $\sigma_j$ is the population standard deviation of the j-th feature column. -This objective function has an analytic solution and it requires only one pass over the data to collect necessary statistics to solve. -Unlike the original dataset which can only be stored in a distributed system, -these statistics can be loaded into memory on a single machine if the number of features is relatively small, and then we can solve the objective function through Cholesky factorization on the driver. +This objective function requires only one pass over the data to collect the statistics necessary to solve it. For an +$n \times m$ data matrix, these statistics require only $O(m^2)$ storage and so can be stored on a single machine when $m$ (the number of features) is +relatively small. We can then solve the normal equations on a single machine using local methods like direct Cholesky factorization or iterative optimization programs. -WeightedLeastSquares only supports L2 regularization and provides options to enable or disable regularization and standardization. -In order to make the normal equation approach efficient, WeightedLeastSquares requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. 
+Spark MLlib currently supports two types of solvers for the normal equations: Cholesky factorization and Quasi-Newton methods (L-BFGS/OWL-QN). Cholesky factorization +depends on a positive definite covariance matrix (i.e. columns of the data matrix must be linearly independent) and will fail if this condition is violated. Quasi-Newton methods +are still capable of providing a reasonable solution even when the covariance matrix is not positive definite, so the normal equation solver can also fall back to +Quasi-Newton methods in this case. This fallback is currently always enabled for the `LinearRegression` and `GeneralizedLinearRegression` estimators. + +`WeightedLeastSquares` supports L1, L2, and elastic-net regularization and provides options to enable or disable regularization and standardization. In the case where no +L1 regularization is applied (i.e. $\alpha = 0$), there exists an analytical solution and either Cholesky or Quasi-Newton solver may be used. When $\alpha > 0$ no analytical +solution exists and we instead use the Quasi-Newton solver to find the coefficients iteratively. + +In order to make the normal equation approach efficient, `WeightedLeastSquares` requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead. ## Iteratively reweighted least squares (IRLS) @@ -83,6 +91,6 @@ It solves certain optimization problems iteratively through the following proced * solve a weighted least squares (WLS) problem by WeightedLeastSquares. * repeat above steps until convergence. -Since it involves solving a weighted least squares (WLS) problem by WeightedLeastSquares in each iteration, +Since it involves solving a weighted least squares (WLS) problem by `WeightedLeastSquares` in each iteration, it also requires the number of features to be no more than 4096. Currently IRLS is used as the default solver of [GeneralizedLinearRegression](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression). 
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 43cc79b9c0..782ee58188 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -75,6 +75,13 @@ More details on parameters can be found in the [Python API documentation](api/py {% include_example python/ml/logistic_regression_with_elastic_net.py %} +
+ +More details on parameters can be found in the [R API documentation](api/R/spark.logit.html). + +{% include_example binomial r/ml/logit.R %} +
+ The `spark.ml` implementation of logistic regression also supports @@ -114,9 +121,15 @@ Continuing the earlier example: {% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} -
-Logistic regression model summary is not yet supported in Python. +[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future. + +Continuing the earlier example: + +{% include_example python/ml/logistic_regression_summary_example.py %}
@@ -165,6 +178,13 @@ model with elastic net regularization. {% include_example python/ml/multiclass_logistic_regression_with_elastic_net.py %} +
+ +More details on parameters can be found in the [R API documentation](api/R/spark.logit.html). + +{% include_example multinomial r/ml/logit.R %} +
+ @@ -236,6 +256,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat {% include_example python/ml/random_forest_classifier_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.randomForest.html) for more details. + +{% include_example classification r/ml/randomForest.R %} +
+ ## Gradient-boosted tree classifier @@ -269,6 +297,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat {% include_example python/ml/gradient_boosted_tree_classifier_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.gbt.html) for more details. + +{% include_example classification r/ml/gbt.R %} +
+ ## Multilayer perceptron classifier @@ -318,6 +354,13 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat {% include_example python/ml/multilayer_perceptron_classification.py %} +
+ +Refer to the [R API docs](api/R/spark.mlp.html) for more details. + +{% include_example r/ml/mlp.R %} +
+ @@ -389,6 +432,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat {% include_example python/ml/naive_bayes_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.naiveBayes.html) for more details. + +{% include_example r/ml/naiveBayes.R %} +
+ @@ -566,6 +617,13 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. {% include_example python/ml/generalized_linear_regression_example.py %} +
+ +Refer to the [R API docs](api/R/spark.glm.html) for more details. + +{% include_example r/ml/glm.R %} +
+ @@ -635,6 +693,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. {% include_example python/ml/random_forest_regressor_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.randomForest.html) for more details. + +{% include_example regression r/ml/randomForest.R %} +
+ ## Gradient-boosted tree regression @@ -668,6 +734,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. {% include_example python/ml/gradient_boosted_tree_regressor_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.gbt.html) for more details. + +{% include_example regression r/ml/gbt.R %} +
+ @@ -755,6 +829,13 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. {% include_example python/ml/aft_survival_regression.py %} +
+ +Refer to the [R API docs](api/R/spark.survreg.html) for more details. + +{% include_example r/ml/survreg.R %} +
+ @@ -825,6 +906,14 @@ Refer to the [`IsotonicRegression` Python docs](api/python/pyspark.ml.html#pyspa {% include_example python/ml/isotonic_regression_example.py %} + +
+ +Refer to the [`IsotonicRegression` R API docs](api/R/spark.isoreg.html) for more details on the API. + +{% include_example r/ml/isoreg.R %} +
+ # Linear methods diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index eedacb12bc..d8b6553c5b 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -86,6 +86,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering. {% include_example python/ml/kmeans_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.kmeans.html) for more details. + +{% include_example r/ml/kmeans.R %} +
+ ## Latent Dirichlet allocation (LDA) @@ -118,6 +126,14 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering. {% include_example python/ml/lda_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.lda.html) for more details. + +{% include_example r/ml/lda.R %} +
+ ## Bisecting k-means @@ -233,4 +249,12 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering. {% include_example python/ml/gaussian_mixture_example.py %} + +
+ +Refer to the [R API docs](api/R/spark.gaussianMixture.html) for more details. + +{% include_example r/ml/gaussianMixture.R %} +
+ diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 4d19b4069a..cfe835172a 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -149,4 +149,12 @@ als = ALS(maxIter=5, regParam=0.01, implicitPrefs=True, {% endhighlight %} + +
+ +Refer to the [R API docs](api/R/spark.als.html) for more details. + +{% include_example r/ml/als.R %} +
+ diff --git a/docs/ml-features.md b/docs/ml-features.md index 53c822c335..ca1ccc4050 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -9,6 +9,7 @@ This section covers algorithms for working with features, roughly divided into t * Extraction: Extracting features from "raw" data * Transformation: Scaling, converting, or modifying features * Selection: Selecting a subset from a larger set of features +* Locality Sensitive Hashing (LSH): This class of algorithms combines aspects of feature transformation with other algorithms. **Table of Contents** @@ -1480,3 +1481,113 @@ for more details on the API. {% include_example python/ml/chisq_selector_example.py %} + +# Locality Sensitive Hashing +[Locality Sensitive Hashing (LSH)](https://en.wikipedia.org/wiki/Locality-sensitive_hashing) is an important class of hashing techniques, which is commonly used in clustering, approximate nearest neighbor search and outlier detection with large datasets. + +The general idea of LSH is to use a family of functions ("LSH families") to hash data points into buckets, so that the data points which are close to each other are in the same buckets with high probability, while data points that are far away from each other are very likely in different buckets. An LSH family is formally defined as follows. + +In a metric space `(M, d)`, where `M` is a set and `d` is a distance function on `M`, an LSH family is a family of functions `h` that satisfy the following properties: +`\[ +\forall p, q \in M,\\ +d(p,q) \leq r1 \Rightarrow Pr(h(p)=h(q)) \geq p1\\ +d(p,q) \geq r2 \Rightarrow Pr(h(p)=h(q)) \leq p2 +\]` +This LSH family is called `(r1, r2, p1, p2)`-sensitive. + +In Spark, different LSH families are implemented in separate classes (e.g., `MinHash`), and APIs for feature transformation, approximate similarity join and approximate nearest neighbor are provided in each class. 
+ +In LSH, we define a false positive as a pair of distant input features (with `$d(p,q) \geq r2$`) which are hashed into the same bucket, and we define a false negative as a pair of nearby features (with `$d(p,q) \leq r1$`) which are hashed into different buckets. + +## LSH Operations + +We describe the major types of operations which LSH can be used for. A fitted LSH model has methods for each of these operations. + +### Feature Transformation +Feature transformation is the basic functionality to add hashed values as a new column. This can be useful for dimensionality reduction. Users can specify input and output column names by setting `inputCol` and `outputCol`. + +LSH also supports multiple LSH hash tables. Users can specify the number of hash tables by setting `numHashTables`. This is also used for [OR-amplification](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Amplification) in approximate similarity join and approximate nearest neighbor. Increasing the number of hash tables will increase the accuracy but will also increase communication cost and running time. + +The type of `outputCol` is `Seq[Vector]` where the dimension of the array equals `numHashTables`, and the dimensions of the vectors are currently set to 1. In future releases, we will implement AND-amplification so that users can specify the dimensions of these vectors. + +### Approximate Similarity Join +Approximate similarity join takes two datasets and approximately returns pairs of rows in the datasets whose distance is smaller than a user-defined threshold. Approximate similarity join supports both joining two different datasets and self-joining. Self-joining will produce some duplicate pairs. + +Approximate similarity join accepts both transformed and untransformed datasets as input. If an untransformed dataset is used, it will be transformed automatically. In this case, the hash signature will be created as `outputCol`. 
+ +In the joined dataset, the original datasets can be queried in `datasetA` and `datasetB`. A distance column will be added to the output dataset to show the true distance between each pair of rows returned. + +### Approximate Nearest Neighbor Search +Approximate nearest neighbor search takes a dataset (of feature vectors) and a key (a single feature vector), and it approximately returns a specified number of rows in the dataset that are closest to the vector. + +Approximate nearest neighbor search accepts both transformed and untransformed datasets as input. If an untransformed dataset is used, it will be transformed automatically. In this case, the hash signature will be created as `outputCol`. + +A distance column will be added to the output dataset to show the true distance between each output row and the searched key. + +**Note:** Approximate nearest neighbor search will return fewer than `k` rows when there are not enough candidates in the hash bucket. + +## LSH Algorithms + +### Bucketed Random Projection for Euclidean Distance + +[Bucketed Random Projection](https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions) is an LSH family for Euclidean distance. The Euclidean distance is defined as follows: +`\[ +d(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_i (x_i - y_i)^2} +\]` +Its LSH family projects feature vectors `$\mathbf{x}$` onto a random unit vector `$\mathbf{v}$` and partitions the projected results into hash buckets: +`\[ +h(\mathbf{x}) = \Big\lfloor \frac{\mathbf{x} \cdot \mathbf{v}}{r} \Big\rfloor +\]` +where `r` is a user-defined bucket length. The bucket length can be used to control the average size of hash buckets (and thus the number of buckets). A larger bucket length (i.e., fewer buckets) increases the probability of features being hashed to the same bucket (increasing the numbers of true and false positives). + +Bucketed Random Projection accepts arbitrary vectors as input features, and supports both sparse and dense vectors. + +
+
+ +Refer to the [BucketedRandomProjectionLSH Scala docs](api/scala/index.html#org.apache.spark.ml.feature.BucketedRandomProjectionLSH) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala %} +
+ +
+ +Refer to the [BucketedRandomProjectionLSH Java docs](api/java/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java %} +
+
+ +### MinHash for Jaccard Distance +[MinHash](https://en.wikipedia.org/wiki/MinHash) is an LSH family for Jaccard distance where input features are sets of natural numbers. Jaccard distance of two sets is defined by the cardinality of their intersection and union: +`\[ +d(\mathbf{A}, \mathbf{B}) = 1 - \frac{|\mathbf{A} \cap \mathbf{B}|}{|\mathbf{A} \cup \mathbf{B}|} +\]` +MinHash applies a random hash function `g` to each element in the set and takes the minimum of all hashed values: +`\[ +h(\mathbf{A}) = \min_{a \in \mathbf{A}}(g(a)) +\]` + +The input sets for MinHash are represented as binary vectors, where the vector indices represent the elements themselves and the non-zero values in the vector represent the presence of that element in the set. While both dense and sparse vectors are supported, typically sparse vectors are recommended for efficiency. For example, `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))` means there are 10 elements in the space. This set contains element 2, element 3 and element 5. All non-zero values are treated as binary "1" values. + +**Note:** Empty sets cannot be transformed by MinHash, which means any input vector must have at least 1 non-zero entry. + +
+
+ +Refer to the [MinHashLSH Scala docs](api/scala/index.html#org.apache.spark.ml.feature.MinHashLSH) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/MinHashLSHExample.scala %} +
+ +
+ +Refer to the [MinHashLSH Java docs](api/java/org/apache/spark/ml/feature/MinHashLSH.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java %} +
+
diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 4607ad3ba6..971761961b 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -35,6 +35,18 @@ The primary Machine Learning API for Spark is now the [DataFrame](sql-programmin * The DataFrame-based API for MLlib provides a uniform API across ML algorithms and across multiple languages. * DataFrames facilitate practical ML Pipelines, particularly feature transformations. See the [Pipelines guide](ml-pipeline.html) for details. +*What is "Spark ML"?* + +* "Spark ML" is not an official name but is occasionally used to refer to the MLlib DataFrame-based API. + This is mainly due to the `org.apache.spark.ml` Scala package name used by the DataFrame-based API, + and the "Spark ML Pipelines" term we used initially to emphasize the pipeline concept. + +*Is MLlib deprecated?* + +* No. MLlib includes both the RDD-based API and the DataFrame-based API. + The RDD-based API is now in maintenance mode. + But neither API is deprecated, nor MLlib as a whole. + # Dependencies MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on @@ -60,152 +72,34 @@ MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, and the migration guide below will explain all changes between releases. -## From 1.6 to 2.0 +## From 2.0 to 2.1 ### Breaking changes - -There were several breaking changes in Spark 2.0, which are outlined below. - -**Linear algebra classes for DataFrame-based APIs** - -Spark's linear algebra dependencies were moved to a new project, `mllib-local` -(see [SPARK-13944](https://issues.apache.org/jira/browse/SPARK-13944)). -As part of this change, the linear algebra classes were copied to a new package, `spark.ml.linalg`.
-The DataFrame-based APIs in `spark.ml` now depend on the `spark.ml.linalg` classes, -leading to a few breaking changes, predominantly in various model classes -(see [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810) for a full list). - -**Note:** the RDD-based APIs in `spark.mllib` continue to depend on the previous package `spark.mllib.linalg`. - -_Converting vectors and matrices_ - -While most pipeline components support backward compatibility for loading, -some existing `DataFrames` and pipelines in Spark versions prior to 2.0, that contain vector or matrix -columns, may need to be migrated to the new `spark.ml` vector and matrix types. -Utilities for converting `DataFrame` columns from `spark.mllib.linalg` to `spark.ml.linalg` types -(and vice versa) can be found in `spark.mllib.util.MLUtils`. - -There are also utility methods available for converting single instances of -vectors and matrices. Use the `asML` method on a `mllib.linalg.Vector` / `mllib.linalg.Matrix` -for converting to `ml.linalg` types, and -`mllib.linalg.Vectors.fromML` / `mllib.linalg.Matrices.fromML` -for converting to `mllib.linalg` types. - -
-
- -{% highlight scala %} -import org.apache.spark.mllib.util.MLUtils - -// convert DataFrame columns -val convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF) -val convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF) -// convert a single vector or matrix -val mlVec: org.apache.spark.ml.linalg.Vector = mllibVec.asML -val mlMat: org.apache.spark.ml.linalg.Matrix = mllibMat.asML -{% endhighlight %} - -Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) for further detail. -
- -
- -{% highlight java %} -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.sql.Dataset; - -// convert DataFrame columns -Dataset convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF); -Dataset convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF); -// convert a single vector or matrix -org.apache.spark.ml.linalg.Vector mlVec = mllibVec.asML(); -org.apache.spark.ml.linalg.Matrix mlMat = mllibMat.asML(); -{% endhighlight %} - -Refer to the [`MLUtils` Java docs](api/java/org/apache/spark/mllib/util/MLUtils.html) for further detail. -
- -
- -{% highlight python %} -from pyspark.mllib.util import MLUtils - -# convert DataFrame columns -convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF) -convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF) -# convert a single vector or matrix -mlVec = mllibVec.asML() -mlMat = mllibMat.asML() -{% endhighlight %} - -Refer to the [`MLUtils` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) for further detail. -
-
- + **Deprecated methods removed** -Several deprecated methods were removed in the `spark.mllib` and `spark.ml` packages: - -* `setScoreCol` in `ml.evaluation.BinaryClassificationEvaluator` -* `weights` in `LinearRegression` and `LogisticRegression` in `spark.ml` -* `setMaxNumIterations` in `mllib.optimization.LBFGS` (marked as `DeveloperApi`) -* `treeReduce` and `treeAggregate` in `mllib.rdd.RDDFunctions` (these functions are available on `RDD`s directly, and were marked as `DeveloperApi`) -* `defaultStategy` in `mllib.tree.configuration.Strategy` -* `build` in `mllib.tree.Node` -* libsvm loaders for multiclass and load/save labeledData methods in `mllib.util.MLUtils` - -A full list of breaking changes can be found at [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810). +* `setLabelCol` in `feature.ChiSqSelectorModel` +* `numTrees` in `classification.RandomForestClassificationModel` (This now refers to the Param called `numTrees`) +* `numTrees` in `regression.RandomForestRegressionModel` (This now refers to the Param called `numTrees`) +* `model` in `regression.LinearRegressionSummary` +* `validateParams` in `PipelineStage` +* `validateParams` in `Evaluator` ### Deprecations and changes of behavior **Deprecations** -Deprecations in the `spark.mllib` and `spark.ml` packages include: - -* [SPARK-14984](https://issues.apache.org/jira/browse/SPARK-14984): - In `spark.ml.regression.LinearRegressionSummary`, the `model` field has been deprecated. -* [SPARK-13784](https://issues.apache.org/jira/browse/SPARK-13784): - In `spark.ml.regression.RandomForestRegressionModel` and `spark.ml.classification.RandomForestClassificationModel`, - the `numTrees` parameter has been deprecated in favor of `getNumTrees` method. -* [SPARK-13761](https://issues.apache.org/jira/browse/SPARK-13761): - In `spark.ml.param.Params`, the `validateParams` method has been deprecated. - We move all functionality in overridden methods to the corresponding `transformSchema`. 
-* [SPARK-14829](https://issues.apache.org/jira/browse/SPARK-14829): - In `spark.mllib` package, `LinearRegressionWithSGD`, `LassoWithSGD`, `RidgeRegressionWithSGD` and `LogisticRegressionWithSGD` have been deprecated. - We encourage users to use `spark.ml.regression.LinearRegresson` and `spark.ml.classification.LogisticRegresson`. -* [SPARK-14900](https://issues.apache.org/jira/browse/SPARK-14900): - In `spark.mllib.evaluation.MulticlassMetrics`, the parameters `precision`, `recall` and `fMeasure` have been deprecated in favor of `accuracy`. -* [SPARK-15644](https://issues.apache.org/jira/browse/SPARK-15644): - In `spark.ml.util.MLReader` and `spark.ml.util.MLWriter`, the `context` method has been deprecated in favor of `session`. -* In `spark.ml.feature.ChiSqSelectorModel`, the `setLabelCol` method has been deprecated since it was not used by `ChiSqSelectorModel`. +* [SPARK-18592](https://issues.apache.org/jira/browse/SPARK-18592): + Deprecate all Param setter methods except for input/output column Params for `DecisionTreeClassificationModel`, `GBTClassificationModel`, `RandomForestClassificationModel`, `DecisionTreeRegressionModel`, `GBTRegressionModel` and `RandomForestRegressionModel` **Changes of behavior** -Changes of behavior in the `spark.mllib` and `spark.ml` packages include: - -* [SPARK-7780](https://issues.apache.org/jira/browse/SPARK-7780): - `spark.mllib.classification.LogisticRegressionWithLBFGS` directly calls `spark.ml.classification.LogisticRegresson` for binary classification now. - This will introduce the following behavior changes for `spark.mllib.classification.LogisticRegressionWithLBFGS`: - * The intercept will not be regularized when training binary classification model with L1/L2 Updater. - * If users set without regularization, training with or without feature scaling will return the same solution by the same convergence rate. 
-* [SPARK-13429](https://issues.apache.org/jira/browse/SPARK-13429): - In order to provide better and consistent result with `spark.ml.classification.LogisticRegresson`, - the default value of `spark.mllib.classification.LogisticRegressionWithLBFGS`: `convergenceTol` has been changed from 1E-4 to 1E-6. -* [SPARK-12363](https://issues.apache.org/jira/browse/SPARK-12363): - Fix a bug of `PowerIterationClustering` which will likely change its result. -* [SPARK-13048](https://issues.apache.org/jira/browse/SPARK-13048): - `LDA` using the `EM` optimizer will keep the last checkpoint by default, if checkpointing is being used. -* [SPARK-12153](https://issues.apache.org/jira/browse/SPARK-12153): - `Word2Vec` now respects sentence boundaries. Previously, it did not handle them correctly. -* [SPARK-10574](https://issues.apache.org/jira/browse/SPARK-10574): - `HashingTF` uses `MurmurHash3` as default hash algorithm in both `spark.ml` and `spark.mllib`. -* [SPARK-14768](https://issues.apache.org/jira/browse/SPARK-14768): - The `expectedType` argument for PySpark `Param` was removed. -* [SPARK-14931](https://issues.apache.org/jira/browse/SPARK-14931): - Some default `Param` values, which were mismatched between pipelines in Scala and Python, have been changed. -* [SPARK-13600](https://issues.apache.org/jira/browse/SPARK-13600): - `QuantileDiscretizer` now uses `spark.sql.DataFrameStatFunctions.approxQuantile` to find splits (previously used custom sampling logic). - The output buckets will differ for same input data and params. +* [SPARK-17870](https://issues.apache.org/jira/browse/SPARK-17870): + Fixed a bug in `ChiSqSelector` which will likely change its result. Now `ChiSqSelector` uses the pValue rather than the raw statistic to select a fixed number of top features. +* [SPARK-3261](https://issues.apache.org/jira/browse/SPARK-3261): + `KMeans` returns potentially fewer than k cluster centers in cases where k distinct centroids aren't available or aren't selected.
+* [SPARK-17389](https://issues.apache.org/jira/browse/SPARK-17389): + `KMeans` reduces the default number of steps from 5 to 2 for the k-means|| initialization mode. ## Previous Spark versions diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md index 82bf9d7760..58c3747ea6 100644 --- a/docs/ml-migration-guides.md +++ b/docs/ml-migration-guides.md @@ -7,6 +7,153 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). +## From 1.6 to 2.0 + +### Breaking changes + +There were several breaking changes in Spark 2.0, which are outlined below. + +**Linear algebra classes for DataFrame-based APIs** + +Spark's linear algebra dependencies were moved to a new project, `mllib-local` +(see [SPARK-13944](https://issues.apache.org/jira/browse/SPARK-13944)). +As part of this change, the linear algebra classes were copied to a new package, `spark.ml.linalg`. +The DataFrame-based APIs in `spark.ml` now depend on the `spark.ml.linalg` classes, +leading to a few breaking changes, predominantly in various model classes +(see [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810) for a full list). + +**Note:** the RDD-based APIs in `spark.mllib` continue to depend on the previous package `spark.mllib.linalg`. + +_Converting vectors and matrices_ + +While most pipeline components support backward compatibility for loading, +some existing `DataFrames` and pipelines in Spark versions prior to 2.0, that contain vector or matrix +columns, may need to be migrated to the new `spark.ml` vector and matrix types. +Utilities for converting `DataFrame` columns from `spark.mllib.linalg` to `spark.ml.linalg` types +(and vice versa) can be found in `spark.mllib.util.MLUtils`. + +There are also utility methods available for converting single instances of +vectors and matrices. 
Use the `asML` method on a `mllib.linalg.Vector` / `mllib.linalg.Matrix` +for converting to `ml.linalg` types, and +`mllib.linalg.Vectors.fromML` / `mllib.linalg.Matrices.fromML` +for converting to `mllib.linalg` types. + +
+
+ +{% highlight scala %} +import org.apache.spark.mllib.util.MLUtils + +// convert DataFrame columns +val convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF) +val convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF) +// convert a single vector or matrix +val mlVec: org.apache.spark.ml.linalg.Vector = mllibVec.asML +val mlMat: org.apache.spark.ml.linalg.Matrix = mllibMat.asML +{% endhighlight %} + +Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) for further detail. +
+ +
+ +{% highlight java %} +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.Dataset; + +// convert DataFrame columns +Dataset convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF); +Dataset convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF); +// convert a single vector or matrix +org.apache.spark.ml.linalg.Vector mlVec = mllibVec.asML(); +org.apache.spark.ml.linalg.Matrix mlMat = mllibMat.asML(); +{% endhighlight %} + +Refer to the [`MLUtils` Java docs](api/java/org/apache/spark/mllib/util/MLUtils.html) for further detail. +
+ +
+ +{% highlight python %} +from pyspark.mllib.util import MLUtils + +# convert DataFrame columns +convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF) +convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF) +# convert a single vector or matrix +mlVec = mllibVec.asML() +mlMat = mllibMat.asML() +{% endhighlight %} + +Refer to the [`MLUtils` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) for further detail. +
+
+ +**Deprecated methods removed** + +Several deprecated methods were removed in the `spark.mllib` and `spark.ml` packages: + +* `setScoreCol` in `ml.evaluation.BinaryClassificationEvaluator` +* `weights` in `LinearRegression` and `LogisticRegression` in `spark.ml` +* `setMaxNumIterations` in `mllib.optimization.LBFGS` (marked as `DeveloperApi`) +* `treeReduce` and `treeAggregate` in `mllib.rdd.RDDFunctions` (these functions are available on `RDD`s directly, and were marked as `DeveloperApi`) +* `defaultStategy` in `mllib.tree.configuration.Strategy` +* `build` in `mllib.tree.Node` +* libsvm loaders for multiclass and load/save labeledData methods in `mllib.util.MLUtils` + +A full list of breaking changes can be found at [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810). + +### Deprecations and changes of behavior + +**Deprecations** + +Deprecations in the `spark.mllib` and `spark.ml` packages include: + +* [SPARK-14984](https://issues.apache.org/jira/browse/SPARK-14984): + In `spark.ml.regression.LinearRegressionSummary`, the `model` field has been deprecated. +* [SPARK-13784](https://issues.apache.org/jira/browse/SPARK-13784): + In `spark.ml.regression.RandomForestRegressionModel` and `spark.ml.classification.RandomForestClassificationModel`, + the `numTrees` parameter has been deprecated in favor of the `getNumTrees` method. +* [SPARK-13761](https://issues.apache.org/jira/browse/SPARK-13761): + In `spark.ml.param.Params`, the `validateParams` method has been deprecated. + We move all functionality in overridden methods to the corresponding `transformSchema`. +* [SPARK-14829](https://issues.apache.org/jira/browse/SPARK-14829): + In `spark.mllib` package, `LinearRegressionWithSGD`, `LassoWithSGD`, `RidgeRegressionWithSGD` and `LogisticRegressionWithSGD` have been deprecated. + We encourage users to use `spark.ml.regression.LinearRegression` and `spark.ml.classification.LogisticRegression`.
+* [SPARK-14900](https://issues.apache.org/jira/browse/SPARK-14900): + In `spark.mllib.evaluation.MulticlassMetrics`, the parameters `precision`, `recall` and `fMeasure` have been deprecated in favor of `accuracy`. +* [SPARK-15644](https://issues.apache.org/jira/browse/SPARK-15644): + In `spark.ml.util.MLReader` and `spark.ml.util.MLWriter`, the `context` method has been deprecated in favor of `session`. +* In `spark.ml.feature.ChiSqSelectorModel`, the `setLabelCol` method has been deprecated since it was not used by `ChiSqSelectorModel`. + +**Changes of behavior** + +Changes of behavior in the `spark.mllib` and `spark.ml` packages include: + +* [SPARK-7780](https://issues.apache.org/jira/browse/SPARK-7780): + `spark.mllib.classification.LogisticRegressionWithLBFGS` directly calls `spark.ml.classification.LogisticRegression` for binary classification now. + This will introduce the following behavior changes for `spark.mllib.classification.LogisticRegressionWithLBFGS`: + * The intercept will not be regularized when training binary classification model with L1/L2 Updater. + * If users train without regularization, training with or without feature scaling will return the same solution at the same convergence rate. +* [SPARK-13429](https://issues.apache.org/jira/browse/SPARK-13429): + In order to provide better and consistent result with `spark.ml.classification.LogisticRegression`, + the default value of `spark.mllib.classification.LogisticRegressionWithLBFGS`: `convergenceTol` has been changed from 1E-4 to 1E-6. +* [SPARK-12363](https://issues.apache.org/jira/browse/SPARK-12363): + Fix a bug of `PowerIterationClustering` which will likely change its result. +* [SPARK-13048](https://issues.apache.org/jira/browse/SPARK-13048): + `LDA` using the `EM` optimizer will keep the last checkpoint by default, if checkpointing is being used. +* [SPARK-12153](https://issues.apache.org/jira/browse/SPARK-12153): + `Word2Vec` now respects sentence boundaries.
Previously, it did not handle them correctly. +* [SPARK-10574](https://issues.apache.org/jira/browse/SPARK-10574): + `HashingTF` uses `MurmurHash3` as default hash algorithm in both `spark.ml` and `spark.mllib`. +* [SPARK-14768](https://issues.apache.org/jira/browse/SPARK-14768): + The `expectedType` argument for PySpark `Param` was removed. +* [SPARK-14931](https://issues.apache.org/jira/browse/SPARK-14931): + Some default `Param` values, which were mismatched between pipelines in Scala and Python, have been changed. +* [SPARK-13600](https://issues.apache.org/jira/browse/SPARK-13600): + `QuantileDiscretizer` now uses `spark.sql.DataFrameStatFunctions.approxQuantile` to find splits (previously used custom sampling logic). + The output buckets will differ for same input data and params. + ## From 1.5 to 1.6 There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are diff --git a/docs/monitoring.md b/docs/monitoring.md index 2eef4568d0..7a1de52668 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -44,10 +44,8 @@ The spark jobs themselves must be configured to log events, and to log them to t writable directory. 
For example, if the server was configured with a log directory of `hdfs://namenode/shared/spark-logs`, then the client-side options would be: -``` -spark.eventLog.enabled true -spark.eventLog.dir hdfs://namenode/shared/spark-logs -``` + spark.eventLog.enabled true + spark.eventLog.dir hdfs://namenode/shared/spark-logs The history server can be configured as follows: diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4267b8cae8..a4017b5b97 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -347,7 +347,7 @@ Some notes on reading files with Spark: Apart from text files, Spark's Scala API also supports several other data formats: -* `SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file. +* `SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file. Partitioning is determined by data locality which, in some cases, may result in too few partitions. For those cases, `wholeTextFiles` provides an optional second argument for controlling the minimal number of partitions. * For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). 
In addition, Spark allows you to specify native types for a few common Writables; for example, `sequenceFile[Int, String]` will automatically read IntWritables and Texts. @@ -1345,14 +1345,15 @@ therefore be efficiently supported in parallel. They can be used to implement co MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. -If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of -running stages (NOTE: this is not yet supported in Python). +As a user, you can create named or unnamed accumulators. As seen in the image below, a named accumulator (in this instance `counter`) will display in the web UI for the stage that modifies that accumulator. Spark displays the value for each accumulator modified by a task in the "Tasks" table.

Accumulators in the Spark UI

+Tracking accumulators in the UI can be useful for understanding the progress of +running stages (NOTE: this is not yet supported in Python). +
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4d1fafc07b..d4144c86e9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -518,7 +518,7 @@ instructions: pre-packaged distribution. 1. Locate the `spark--yarn-shuffle.jar`. This should be under `$SPARK_HOME/common/network-yarn/target/scala-` if you are building Spark yourself, and under -`lib` if you are using a distribution. +`yarn` if you are using a distribution. 1. Add this jar to the classpath of all `NodeManager`s in your cluster. 1. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to diff --git a/docs/sparkr.md b/docs/sparkr.md index d26949226b..d7ffd9b3f1 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -40,7 +40,9 @@ sparkR.session() You can also start SparkR from RStudio. You can connect your R program to a Spark cluster from RStudio, R shell, Rscript or other R IDEs. To start, make sure SPARK_HOME is set in environment (you can check [Sys.getenv](https://stat.ethz.ch/R-manual/R-devel/library/base/html/Sys.getenv.html)), -load the SparkR package, and call `sparkR.session` as below. In addition to calling `sparkR.session`, +load the SparkR package, and call `sparkR.session` as below. It will check for the Spark installation, and, if not found, it will be downloaded and cached automatically. Alternatively, you can also run `install.spark` manually. + +In addition to calling `sparkR.session`, you could also specify certain Spark driver properties. 
Normally these [Application properties](configuration.html#application-properties) and [Runtime Environment](configuration.html#runtime-environment) cannot be set programmatically, as the @@ -510,39 +512,50 @@ head(teenagers) # Machine Learning -SparkR supports the following machine learning algorithms currently: `Generalized Linear Model`, `Accelerated Failure Time (AFT) Survival Regression Model`, `Naive Bayes Model` and `KMeans Model`. -Under the hood, SparkR uses MLlib to train the model. -Users can call `summary` to print a summary of the fitted model, [predict](api/R/predict.html) to make predictions on new data, and [write.ml](api/R/write.ml.html)/[read.ml](api/R/read.ml.html) to save/load fitted models. -SparkR supports a subset of the available R formula operators for model fitting, including ‘~’, ‘.’, ‘:’, ‘+’, and ‘-‘. - ## Algorithms -### Generalized Linear Model +SparkR supports the following machine learning algorithms currently: + +#### Classification + +* [`spark.logit`](api/R/spark.logit.html): [`Logistic Regression`](ml-classification-regression.html#logistic-regression) +* [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron (MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier) +* [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive Bayes`](ml-classification-regression.html#naive-bayes) + +#### Regression -[spark.glm()](api/R/spark.glm.html) or [glm()](api/R/glm.html) fits generalized linear model against a Spark DataFrame. -Currently "gaussian", "binomial", "poisson" and "gamma" families are supported. 
-{% include_example glm r/ml.R %} +* [`spark.survreg`](api/R/spark.survreg.html): [`Accelerated Failure Time (AFT) Survival Model`](ml-classification-regression.html#survival-regression) +* [`spark.glm`](api/R/spark.glm.html) or [`glm`](api/R/glm.html): [`Generalized Linear Model (GLM)`](ml-classification-regression.html#generalized-linear-regression) +* [`spark.isoreg`](api/R/spark.isoreg.html): [`Isotonic Regression`](ml-classification-regression.html#isotonic-regression) -### Accelerated Failure Time (AFT) Survival Regression Model +#### Tree -[spark.survreg()](api/R/spark.survreg.html) fits an accelerated failure time (AFT) survival regression model on a SparkDataFrame. -Note that the formula of [spark.survreg()](api/R/spark.survreg.html) does not support operator '.' currently. -{% include_example survreg r/ml.R %} +* [`spark.gbt`](api/R/spark.gbt.html): `Gradient Boosted Trees for` [`Regression`](ml-classification-regression.html#gradient-boosted-tree-regression) `and` [`Classification`](ml-classification-regression.html#gradient-boosted-tree-classifier) +* [`spark.randomForest`](api/R/spark.randomForest.html): `Random Forest for` [`Regression`](ml-classification-regression.html#random-forest-regression) `and` [`Classification`](ml-classification-regression.html#random-forest-classifier) -### Naive Bayes Model +#### Clustering -[spark.naiveBayes()](api/R/spark.naiveBayes.html) fits a Bernoulli naive Bayes model against a SparkDataFrame. Only categorical data is supported. 
-{% include_example naiveBayes r/ml.R %} +* [`spark.gaussianMixture`](api/R/spark.gaussianMixture.html): [`Gaussian Mixture Model (GMM)`](ml-clustering.html#gaussian-mixture-model-gmm) +* [`spark.kmeans`](api/R/spark.kmeans.html): [`K-Means`](ml-clustering.html#k-means) +* [`spark.lda`](api/R/spark.lda.html): [`Latent Dirichlet Allocation (LDA)`](ml-clustering.html#latent-dirichlet-allocation-lda) -### KMeans Model +#### Collaborative Filtering + +* [`spark.als`](api/R/spark.als.html): [`Alternating Least Squares (ALS)`](ml-collaborative-filtering.html#collaborative-filtering) + +#### Statistics + +* [`spark.kstest`](api/R/spark.kstest.html): `Kolmogorov-Smirnov Test` + +Under the hood, SparkR uses MLlib to train the model. Please refer to the corresponding section of MLlib user guide for example code. +Users can call `summary` to print a summary of the fitted model, [predict](api/R/predict.html) to make predictions on new data, and [write.ml](api/R/write.ml.html)/[read.ml](api/R/read.ml.html) to save/load fitted models. +SparkR supports a subset of the available R formula operators for model fitting, including ‘~’, ‘.’, ‘:’, ‘+’, and ‘-‘. -[spark.kmeans()](api/R/spark.kmeans.html) fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). -{% include_example kmeans r/ml.R %} ## Model persistence The following example shows how to save/load a MLlib model by SparkR. -{% include_example read_write r/ml.R %} +{% include_example read_write r/ml/ml.R %} # R Function Name Conflicts diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c7ad06c639..6287e2be95 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -526,6 +526,11 @@ By default `saveAsTable` will create a "managed table", meaning that the locatio be controlled by the metastore. Managed tables will also have their data deleted automatically when a table is dropped. 
+Currently, `saveAsTable` does not expose an API supporting the creation of an "External table" from a `DataFrame`. +However, this functionality can be achieved by providing a `path` option to the `DataFrameWriter` with `path` as the key +and the location of the external table as its value (String) when saving the table with `saveAsTable`. When an External table +is dropped, only its metadata is removed. + ## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. @@ -1851,7 +1856,8 @@ You can access them by doing
Property NameDefaultMeaning
The value type in Scala of the data type of this field (For example, Int for a StructField with the data type IntegerType) - StructField(name, dataType, nullable) + StructField(name, dataType, [nullable])
+ Note: The default value of nullable is true.
@@ -2139,7 +2145,8 @@ from pyspark.sql.types import * The value type in Python of the data type of this field (For example, Int for a StructField with the data type IntegerType) - StructField(name, dataType, nullable) + StructField(name, dataType, [nullable])
+ Note: The default value of nullable is True. @@ -2260,7 +2267,7 @@ from pyspark.sql.types import * vector or list list(type="array", elementType=elementType, containsNull=[containsNull])
- Note: The default value of containsNull is True. + Note: The default value of containsNull is TRUE. @@ -2268,7 +2275,7 @@ from pyspark.sql.types import * environment list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
- Note: The default value of valueContainsNull is True. + Note: The default value of valueContainsNull is TRUE. @@ -2285,7 +2292,8 @@ from pyspark.sql.types import * The value type in R of the data type of this field (For example, integer for a StructField with the data type IntegerType) - list(name=name, type=dataType, nullable=nullable) + list(name=name, type=dataType, nullable=[nullable])
+ Note: The default value of nullable is TRUE. diff --git a/examples/pom.xml b/examples/pom.xml index 90bbd3fbb9..91c2e81ebe 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java new file mode 100644 index 0000000000..ca3ee5a285 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.BucketedRandomProjectionLSH; +import org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaBucketedRandomProjectionLSHExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaBucketedRandomProjectionLSHExample") + .getOrCreate(); + + // $example on$ + List dataA = Arrays.asList( + RowFactory.create(0, Vectors.dense(1.0, 1.0)), + RowFactory.create(1, Vectors.dense(1.0, -1.0)), + RowFactory.create(2, Vectors.dense(-1.0, -1.0)), + RowFactory.create(3, Vectors.dense(-1.0, 1.0)) + ); + + List dataB = Arrays.asList( + RowFactory.create(4, Vectors.dense(1.0, 0.0)), + RowFactory.create(5, Vectors.dense(-1.0, 0.0)), + RowFactory.create(6, Vectors.dense(0.0, 1.0)), + RowFactory.create(7, Vectors.dense(0.0, -1.0)) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("keys", new VectorUDT(), false, Metadata.empty()) + }); + Dataset dfA = spark.createDataFrame(dataA, schema); + Dataset dfB = spark.createDataFrame(dataB, schema); + + Vector key = Vectors.dense(1.0, 0.0); + + BucketedRandomProjectionLSH mh = new BucketedRandomProjectionLSH() + .setBucketLength(2.0) + .setNumHashTables(3) + .setInputCol("keys") + .setOutputCol("values"); + + 
BucketedRandomProjectionLSHModel model = mh.fit(dfA); + + // Feature Transformation + model.transform(dfA).show(); + // Cache the transformed columns + Dataset transformedA = model.transform(dfA).cache(); + Dataset transformedB = model.transform(dfB).cache(); + + // Approximate similarity join + model.approxSimilarityJoin(dfA, dfB, 1.5).show(); + model.approxSimilarityJoin(transformedA, transformedB, 1.5).show(); + // Self Join + model.approxSimilarityJoin(dfA, dfA, 2.5).filter("datasetA.id < datasetB.id").show(); + + // Approximate nearest neighbor search + model.approxNearestNeighbors(dfA, key, 2).show(); + model.approxNearestNeighbors(transformedA, key, 2).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java new file mode 100644 index 0000000000..9dbbf6d117 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.feature.MinHashLSH; +import org.apache.spark.ml.feature.MinHashLSHModel; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off$ + +public class JavaMinHashLSHExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaMinHashLSHExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})), + RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})), + RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0})) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("keys", new VectorUDT(), false, Metadata.empty()) + }); + Dataset dataFrame = spark.createDataFrame(data, schema); + + MinHashLSH mh = new MinHashLSH() + .setNumHashTables(1) + .setInputCol("keys") + .setOutputCol("values"); + + MinHashLSHModel model = mh.fit(dataFrame); + model.transform(dataFrame).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py new file mode 100644 index 0000000000..bd440a1fbe --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -0,0 +1,68 @@ +# +# 
Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating Logistic Regression Summary. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py +""" + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("LogisticRegressionSummary") \ + .getOrCreate() + + # Load training data + training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # $example on$ + # Extract the summary from the returned LogisticRegressionModel instance trained + # in the earlier example + trainingSummary = lrModel.summary + + # Obtain the objective per iteration + objectiveHistory = trainingSummary.objectiveHistory + print("objectiveHistory:") + for objective in objectiveHistory: + print(objective) + + # Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. 
+ trainingSummary.roc.show() + print("areaUnderROC: " + str(trainingSummary.areaUnderROC)) + + # Set the model threshold to maximize F-Measure + fMeasure = trainingSummary.fMeasureByThreshold + maxFMeasure = fMeasure.groupBy().max('F-Measure').select('max(F-Measure)').head() + bestThreshold = fMeasure.where(fMeasure['F-Measure'] == maxFMeasure['max(F-Measure)']) \ + .select('threshold').head()['threshold'] + lr.setThreshold(bestThreshold) + # $example off$ + + spark.stop() diff --git a/examples/src/main/r/ml.R b/examples/src/main/r/ml.R deleted file mode 100644 index a8a1274ac9..0000000000 --- a/examples/src/main/r/ml.R +++ /dev/null @@ -1,148 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -# To run this example use -# ./bin/spark-submit examples/src/main/r/ml.R - -# Load SparkR library into your R session -library(SparkR) - -# Initialize SparkSession -sparkR.session(appName = "SparkR-ML-example") - -############################ spark.glm and glm ############################################## -# $example on:glm$ -irisDF <- suppressWarnings(createDataFrame(iris)) -# Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") - -# Model summary -summary(gaussianGLM) - -# Prediction -gaussianPredictions <- predict(gaussianGLM, gaussianTestDF) -showDF(gaussianPredictions) - -# Fit a generalized linear model with glm (R-compliant) -gaussianGLM2 <- glm(Sepal_Length ~ Sepal_Width + Species, gaussianDF, family = "gaussian") -summary(gaussianGLM2) - -# Fit a generalized linear model of family "binomial" with spark.glm -binomialDF <- filter(irisDF, irisDF$Species != "setosa") -binomialTestDF <- binomialDF -binomialGLM <- spark.glm(binomialDF, Species ~ Sepal_Length + Sepal_Width, family = "binomial") - -# Model summary -summary(binomialGLM) - -# Prediction -binomialPredictions <- predict(binomialGLM, binomialTestDF) -showDF(binomialPredictions) -# $example off:glm$ -############################ spark.survreg ############################################## -# $example on:survreg$ -# Use the ovarian dataset available in R survival package -library(survival) - -# Fit an accelerated failure time (AFT) survival regression model with spark.survreg -ovarianDF <- suppressWarnings(createDataFrame(ovarian)) -aftDF <- ovarianDF -aftTestDF <- ovarianDF -aftModel <- spark.survreg(aftDF, Surv(futime, fustat) ~ ecog_ps + rx) - -# Model summary -summary(aftModel) - -# Prediction -aftPredictions <- predict(aftModel, aftTestDF) -showDF(aftPredictions) -# $example off:survreg$ -############################ spark.naiveBayes 
############################################## -# $example on:naiveBayes$ -# Fit a Bernoulli naive Bayes model with spark.naiveBayes -titanic <- as.data.frame(Titanic) -titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) -nbDF <- titanicDF -nbTestDF <- titanicDF -nbModel <- spark.naiveBayes(nbDF, Survived ~ Class + Sex + Age) - -# Model summary -summary(nbModel) - -# Prediction -nbPredictions <- predict(nbModel, nbTestDF) -showDF(nbPredictions) -# $example off:naiveBayes$ -############################ spark.kmeans ############################################## -# $example on:kmeans$ -# Fit a k-means model with spark.kmeans -irisDF <- suppressWarnings(createDataFrame(iris)) -kmeansDF <- irisDF -kmeansTestDF <- irisDF -kmeansModel <- spark.kmeans(kmeansDF, ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, - k = 3) - -# Model summary -summary(kmeansModel) - -# Get fitted result from the k-means model -showDF(fitted(kmeansModel)) - -# Prediction -kmeansPredictions <- predict(kmeansModel, kmeansTestDF) -showDF(kmeansPredictions) -# $example off:kmeans$ -############################ model read/write ############################################## -# $example on:read_write$ -irisDF <- suppressWarnings(createDataFrame(iris)) -# Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") - -# Save and then load a fitted MLlib model -modelPath <- tempfile(pattern = "ml", fileext = ".tmp") -write.ml(gaussianGLM, modelPath) -gaussianGLM2 <- read.ml(modelPath) - -# Check model summary -summary(gaussianGLM2) - -# Check model prediction -gaussianPredictions <- predict(gaussianGLM2, gaussianTestDF) -showDF(gaussianPredictions) - -unlink(modelPath) -# $example off:read_write$ -############################ fit models with spark.lapply ##################################### - -# Perform distributed training of multiple 
models with spark.lapply -families <- c("gaussian", "poisson") -train <- function(family) { - model <- glm(Sepal.Length ~ Sepal.Width + Species, iris, family = family) - summary(model) -} -model.summaries <- spark.lapply(families, train) - -# Print the summary of each model -print(model.summaries) - - -# Stop the SparkSession now -sparkR.session.stop() diff --git a/examples/src/main/r/ml/als.R b/examples/src/main/r/ml/als.R new file mode 100644 index 0000000000..383bbba190 --- /dev/null +++ b/examples/src/main/r/ml/als.R @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/als.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-als-example") + +# $example on$ +# Load training data +data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), + list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) +df <- createDataFrame(data, c("userId", "movieId", "rating")) +training <- df +test <- df + +# Fit a recommendation model using ALS with spark.als +model <- spark.als(training, maxIter = 5, regParam = 0.01, userCol = "userId", + itemCol = "movieId", ratingCol = "rating") + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off$ diff --git a/examples/src/main/r/ml/gaussianMixture.R b/examples/src/main/r/ml/gaussianMixture.R new file mode 100644 index 0000000000..54b69acc83 --- /dev/null +++ b/examples/src/main/r/ml/gaussianMixture.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/gaussianMixture.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-gaussianMixture-example") + +# $example on$ +# Load training data +df <- read.df("data/mllib/sample_kmeans_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a gaussian mixture clustering model with spark.gaussianMixture +model <- spark.gaussianMixture(training, ~ features, k = 2) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off$ diff --git a/examples/src/main/r/ml/gbt.R b/examples/src/main/r/ml/gbt.R new file mode 100644 index 0000000000..be16c2aa66 --- /dev/null +++ b/examples/src/main/r/ml/gbt.R @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/gbt.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-gbt-example") + +# GBT classification model + +# $example on:classification$ +# Load training data +df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a GBT classification model with spark.gbt +model <- spark.gbt(training, label ~ features, "classification", maxIter = 10) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off:classification$ + +# GBT regression model + +# $example on:regression$ +# Load training data +df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a GBT regression model with spark.gbt +model <- spark.gbt(training, label ~ features, "regression", maxIter = 10) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off:regression$ diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R new file mode 100644 index 0000000000..599071790a --- /dev/null +++ b/examples/src/main/r/ml/glm.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/glm.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-glm-example") + +# $example on$ +irisDF <- suppressWarnings(createDataFrame(iris)) +# Fit a generalized linear model of family "gaussian" with spark.glm +gaussianDF <- irisDF +gaussianTestDF <- irisDF +gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") + +# Model summary +summary(gaussianGLM) + +# Prediction +gaussianPredictions <- predict(gaussianGLM, gaussianTestDF) +showDF(gaussianPredictions) + +# Fit a generalized linear model with glm (R-compliant) +gaussianGLM2 <- glm(Sepal_Length ~ Sepal_Width + Species, gaussianDF, family = "gaussian") +summary(gaussianGLM2) + +# Fit a generalized linear model of family "binomial" with spark.glm +# Note: Filter out "setosa" from label column (two labels left) to match "binomial" family. 
+binomialDF <- filter(irisDF, irisDF$Species != "setosa") +binomialTestDF <- binomialDF +binomialGLM <- spark.glm(binomialDF, Species ~ Sepal_Length + Sepal_Width, family = "binomial") + +# Model summary +summary(binomialGLM) + +# Prediction +binomialPredictions <- predict(binomialGLM, binomialTestDF) +showDF(binomialPredictions) +# $example off$ diff --git a/examples/src/main/r/ml/isoreg.R b/examples/src/main/r/ml/isoreg.R new file mode 100644 index 0000000000..75dce97ed9 --- /dev/null +++ b/examples/src/main/r/ml/isoreg.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/isoreg.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-isoreg-example") + +# $example on$ +# Load training data +df <- read.df("data/mllib/sample_isotonic_regression_libsvm_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit an isotonic regression model with spark.isoreg +model <- spark.isoreg(training, label ~ features, isotonic = FALSE) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off$ diff --git a/examples/src/main/r/ml/kmeans.R b/examples/src/main/r/ml/kmeans.R new file mode 100644 index 0000000000..043b21b038 --- /dev/null +++ b/examples/src/main/r/ml/kmeans.R @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/kmeans.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-kmeans-example") + +# $example on$ +# Fit a k-means model with spark.kmeans +irisDF <- suppressWarnings(createDataFrame(iris)) +kmeansDF <- irisDF +kmeansTestDF <- irisDF +kmeansModel <- spark.kmeans(kmeansDF, ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + k = 3) + +# Model summary +summary(kmeansModel) + +# Get fitted result from the k-means model +showDF(fitted(kmeansModel)) + +# Prediction +kmeansPredictions <- predict(kmeansModel, kmeansTestDF) +showDF(kmeansPredictions) +# $example off$ diff --git a/examples/src/main/r/ml/kstest.R b/examples/src/main/r/ml/kstest.R new file mode 100644 index 0000000000..12625f7d3e --- /dev/null +++ b/examples/src/main/r/ml/kstest.R @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/kstest.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-kstest-example") + +# $example on$ +# Load training data +data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) +df <- createDataFrame(data) +training <- df +test <- df + +# Conduct the two-sided Kolmogorov-Smirnov (KS) test with spark.kstest +model <- spark.kstest(df, "test", "norm") + +# Model summary +summary(model) +# $example off$ diff --git a/examples/src/main/r/ml/lda.R b/examples/src/main/r/ml/lda.R new file mode 100644 index 0000000000..7b187d155a --- /dev/null +++ b/examples/src/main/r/ml/lda.R @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/lda.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-lda-example") + +# $example on$ +# Load training data +df <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a latent dirichlet allocation model with spark.lda +model <- spark.lda(training, k = 10, maxIter = 10) + +# Model summary +summary(model) + +# Posterior probabilities +posterior <- spark.posterior(model, test) +showDF(posterior) + +# The log perplexity of the LDA model +logPerplexity <- spark.perplexity(model, test) +print(paste0("The upper bound bound on perplexity: ", logPerplexity)) +# $example off$ diff --git a/examples/src/main/r/ml/logit.R b/examples/src/main/r/ml/logit.R new file mode 100644 index 0000000000..a2ac882ed0 --- /dev/null +++ b/examples/src/main/r/ml/logit.R @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/logit.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-logit-example") + +# Binomial logistic regression + +# $example on:binomial$ +# Load training data +df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a binomial logistic regression model with spark.logit +model <- spark.logit(training, label ~ features, maxIter = 10, regParam = 0.3, elasticNetParam = 0.8) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off:binomial$ + +# Multinomial logistic regression + +# $example on:multinomial$ +# Load training data +df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a multinomial logistic regression model with spark.logit +model <- spark.logit(training, label ~ features, maxIter = 10, regParam = 0.3, elasticNetParam = 0.8) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off:multinomial$ diff --git a/examples/src/main/r/ml/ml.R b/examples/src/main/r/ml/ml.R new file mode 100644 index 0000000000..d601590c22 --- /dev/null +++ b/examples/src/main/r/ml/ml.R @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/ml.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-example") + +############################ model read/write ############################################## +# $example on:read_write$ +irisDF <- suppressWarnings(createDataFrame(iris)) +# Fit a generalized linear model of family "gaussian" with spark.glm +gaussianDF <- irisDF +gaussianTestDF <- irisDF +gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") + +# Save and then load a fitted MLlib model +modelPath <- tempfile(pattern = "ml", fileext = ".tmp") +write.ml(gaussianGLM, modelPath) +gaussianGLM2 <- read.ml(modelPath) + +# Check model summary +summary(gaussianGLM2) + +# Check model prediction +gaussianPredictions <- predict(gaussianGLM2, gaussianTestDF) +showDF(gaussianPredictions) + +unlink(modelPath) +# $example off:read_write$ + +############################ fit models with spark.lapply ##################################### +# Perform distributed training of multiple models with spark.lapply +costs <- exp(seq(from = log(1), to = log(1000), length.out = 5)) +train <- function(cost) { + stopifnot(requireNamespace("e1071", quietly = TRUE)) + model <- e1071::svm(Species ~ ., data = iris, cost = cost) + summary(model) +} + +model.summaries <- spark.lapply(costs, train) + +# Print the summary of each model +print(model.summaries) + +# Stop the SparkSession now +sparkR.session.stop() diff --git 
a/examples/src/main/r/ml/mlp.R b/examples/src/main/r/ml/mlp.R new file mode 100644 index 0000000000..d28fc069bd --- /dev/null +++ b/examples/src/main/r/ml/mlp.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/mlp.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-mlp-example") + +# $example on$ +# Load training data +df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") +training <- df +test <- df + +# specify layers for the neural network: +# input layer of size 4 (features), two intermediate of size 5 and 4 +# and output of size 3 (classes) +layers = c(4, 5, 4, 3) + +# Fit a multi-layer perceptron neural network model with spark.mlp +model <- spark.mlp(training, label ~ features, maxIter = 100, + layers = layers, blockSize = 128, seed = 1234) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off$ diff --git a/examples/src/main/r/ml/naiveBayes.R b/examples/src/main/r/ml/naiveBayes.R new file mode 100644 index 0000000000..9c416599b4 --- /dev/null +++ 
b/examples/src/main/r/ml/naiveBayes.R @@ -0,0 +1,41 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/naiveBayes.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-naiveBayes-example") + +# $example on$ +# Fit a Bernoulli naive Bayes model with spark.naiveBayes +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +nbDF <- titanicDF +nbTestDF <- titanicDF +nbModel <- spark.naiveBayes(nbDF, Survived ~ Class + Sex + Age) + +# Model summary +summary(nbModel) + +# Prediction +nbPredictions <- predict(nbModel, nbTestDF) +showDF(nbPredictions) +# $example off$ diff --git a/examples/src/main/r/ml/randomForest.R b/examples/src/main/r/ml/randomForest.R new file mode 100644 index 0000000000..d1b96b62a0 --- /dev/null +++ b/examples/src/main/r/ml/randomForest.R @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/randomForest.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-randomForest-example") + +# Random forest classification model + +# $example on:classification$ +# Load training data +df <- read.df("data/mllib/sample_libsvm_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a random forest classification model with spark.randomForest +model <- spark.randomForest(training, label ~ features, "classification", numTrees = 10) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off:classification$ + +# Random forest regression model + +# $example on:regression$ +# Load training data +df <- read.df("data/mllib/sample_linear_regression_data.txt", source = "libsvm") +training <- df +test <- df + +# Fit a random forest regression model with spark.randomForest +model <- spark.randomForest(training, label ~ features, "regression", numTrees = 10) + +# Model summary +summary(model) + +# Prediction +predictions <- predict(model, test) +showDF(predictions) +# $example off:regression$ diff --git a/examples/src/main/r/ml/survreg.R b/examples/src/main/r/ml/survreg.R new file mode 100644 index 0000000000..f728b8b5d8 --- /dev/null +++ b/examples/src/main/r/ml/survreg.R @@ -0,0 +1,43 
@@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/survreg.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-survreg-example") + +# $example on$ +# Use the ovarian dataset available in R survival package +library(survival) + +# Fit an accelerated failure time (AFT) survival regression model with spark.survreg +ovarianDF <- suppressWarnings(createDataFrame(ovarian)) +aftDF <- ovarianDF +aftTestDF <- ovarianDF +aftModel <- spark.survreg(aftDF, Surv(futime, fustat) ~ ecog_ps + rx) + +# Model summary +summary(aftModel) + +# Prediction +aftPredictions <- predict(aftModel, aftTestDF) +showDF(aftPredictions) +# $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala new file mode 100644 index 0000000000..686cc39d3b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.BucketedRandomProjectionLSH +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +object BucketedRandomProjectionLSHExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession + val spark = SparkSession + .builder + .appName("BucketedRandomProjectionLSHExample") + .getOrCreate() + + // $example on$ + val dfA = spark.createDataFrame(Seq( + (0, Vectors.dense(1.0, 1.0)), + (1, Vectors.dense(1.0, -1.0)), + (2, Vectors.dense(-1.0, -1.0)), + (3, Vectors.dense(-1.0, 1.0)) + )).toDF("id", "keys") + + val dfB = spark.createDataFrame(Seq( + (4, Vectors.dense(1.0, 0.0)), + (5, Vectors.dense(-1.0, 0.0)), + (6, Vectors.dense(0.0, 1.0)), + (7, Vectors.dense(0.0, -1.0)) + )).toDF("id", "keys") + + val key = Vectors.dense(1.0, 0.0) + + val brp = new BucketedRandomProjectionLSH() + .setBucketLength(2.0) + .setNumHashTables(3) + .setInputCol("keys") + .setOutputCol("values") + + val model = brp.fit(dfA) + + // Feature Transformation + model.transform(dfA).show() + // Cache the transformed columns + val transformedA = model.transform(dfA).cache() + val transformedB = 
model.transform(dfB).cache() + + // Approximate similarity join + model.approxSimilarityJoin(dfA, dfB, 1.5).show() + model.approxSimilarityJoin(transformedA, transformedB, 1.5).show() + // Self Join + model.approxSimilarityJoin(dfA, dfA, 2.5).filter("datasetA.id < datasetB.id").show() + + // Approximate nearest neighbor search + model.approxNearestNeighbors(dfA, key, 2).show() + model.approxNearestNeighbors(transformedA, key, 2).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala new file mode 100644 index 0000000000..f4fc3cf411 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.MinHashLSH +import org.apache.spark.ml.linalg.Vectors +// $example off$ +import org.apache.spark.sql.SparkSession + +object MinHashLSHExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession + val spark = SparkSession + .builder + .appName("MinHashLSHExample") + .getOrCreate() + + // $example on$ + val dfA = spark.createDataFrame(Seq( + (0, Vectors.sparse(6, Seq((0, 1.0), (1, 1.0), (2, 1.0)))), + (1, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (4, 1.0)))), + (2, Vectors.sparse(6, Seq((0, 1.0), (2, 1.0), (4, 1.0)))) + )).toDF("id", "keys") + + val dfB = spark.createDataFrame(Seq( + (3, Vectors.sparse(6, Seq((1, 1.0), (3, 1.0), (5, 1.0)))), + (4, Vectors.sparse(6, Seq((2, 1.0), (3, 1.0), (5, 1.0)))), + (5, Vectors.sparse(6, Seq((1, 1.0), (2, 1.0), (4, 1.0)))) + )).toDF("id", "keys") + + val key = Vectors.sparse(6, Seq((1, 1.0), (3, 1.0))) + + val mh = new MinHashLSH() + .setNumHashTables(3) + .setInputCol("keys") + .setOutputCol("values") + + val model = mh.fit(dfA) + + // Feature Transformation + model.transform(dfA).show() + // Cache the transformed columns + val transformedA = model.transform(dfA).cache() + val transformedB = model.transform(dfB).cache() + + // Approximate similarity join + model.approxSimilarityJoin(dfA, dfB, 0.6).show() + model.approxSimilarityJoin(transformedA, transformedB, 0.6).show() + // Self Join + model.approxSimilarityJoin(dfA, dfA, 0.6).filter("datasetA.id < datasetB.id").show() + + // Approximate nearest neighbor search + model.approxNearestNeighbors(dfA, key, 2).show() + model.approxNearestNeighbors(transformedA, key, 2).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/ComputeSVDbyGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/ComputeSVDbyGramExample.scala new file 
mode 100644 index 0000000000..ddc4ccd990 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/ComputeSVDbyGramExample.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix} +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object ComputeSVDbyGramExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("TallSkinnySVDExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Seq( + (0L, Vectors.dense(0.0, 1.0, 2.0)), + (1L, Vectors.dense(3.0, 4.0, 5.0)), + (3L, Vectors.dense(9.0, 0.0, 1.0)) + ).map(x => IndexedRow(x._1, x._2)) + + val indexedRows = sc.parallelize(data, 2) + + val mat = new IndexedRowMatrix(indexedRows) + + // Compute the singular value decompositions of mat. 
+ val svd: SingularValueDecomposition[IndexedRowMatrix, Matrix] = + mat.computeSVDbyGram(computeU = true) + val U: IndexedRowMatrix = svd.U // The U factor is a RowMatrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. + // $example off$ + val collect = U.rows.collect() + println("U factor is:") + collect.foreach { vector => println(vector) } + println(s"Singular values are: $s") + println(s"V factor is:\n$V") + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PartialSVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PartialSVDExample.scala new file mode 100644 index 0000000000..9df2c4a6a4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PartialSVDExample.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.DenseMatrix +import org.apache.spark.mllib.linalg.distributed.BlockMatrix +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.linalg.Vector +// $example off$ + +object PartialSVDExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("PartialSVDExample") + val sc = new SparkContext(conf) + + // $example on$ + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 0.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 1.0)))) + + val mat = new BlockMatrix(sc.parallelize(blocks, 2), 2, 2) + + // Compute the top 4 singular values and corresponding singular vectors. + val svd: SingularValueDecomposition[BlockMatrix, BlockMatrix] = + mat.partialSVD(4, sc, computeU = true) + val U: BlockMatrix = svd.U // The U factor is a BlockMatrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: BlockMatrix = svd.V // The V factor is a BlockMatrix. 
+ // $example off$ + val collectU = U.toIndexedRowMatrix().toRowMatrix().rows.collect() + println("U factor is:") + collectU.foreach { vector => println(vector) } + println(s"Singular values are: $s") + println("V factor is:") + val collectV = V.toIndexedRowMatrix().toRowMatrix().rows.collect() + collectV.foreach { vector => println(vector) } + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVDExample.scala new file mode 100644 index 0000000000..220f2d3655 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVDExample.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix} +// $example on$ +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +object TallSkinnySVDExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("TallSkinnySVDExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = Seq( + (0L, Vectors.dense(0.0, 1.0, 2.0)), + (1L, Vectors.dense(3.0, 4.0, 5.0)), + (3L, Vectors.dense(9.0, 0.0, 1.0)) + ).map(x => IndexedRow(x._1, x._2)) + + val indexedRows = sc.parallelize(data, 2) + + val mat = new IndexedRowMatrix(indexedRows) + + // Compute the top 3 singular values and corresponding singular vectors. + val svd: SingularValueDecomposition[IndexedRowMatrix, Matrix] = + mat.tallSkinnySVD(3, sc, computeU = true) + val U: IndexedRowMatrix = svd.U // The U factor is a RowMatrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. 
+ // $example off$ + val collect = U.rows.collect() + println("U factor is:") + collect.foreach { vector => println(vector) } + println(s"Singular values are: $s") + println(s"V factor is:\n$V") + } +} +// scalastyle:on println diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 57d553b75b..8948df2da8 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -96,7 +96,7 @@ org.apache.spark spark-tags_${scala.binary.version} - ${project.version} + test-jar test diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index fb0292a5f1..f8ef8a9913 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 5e9275c8e6..6d547c46d6 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -93,6 +93,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 7b68ca7373..46901d64ed 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -69,6 +69,18 @@ org.apache.spark spark-tags_${scala.binary.version}
+ + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml index 1bc206e867..8fc46d7af2 100644 --- a/external/java8-tests/pom.xml +++ b/external/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -73,6 +73,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 4f5045326a..295142cbfd 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index ebff5fd07a..6cf448e65e 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -88,6 +88,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 3f438e9918..3f396a7e6b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -86,7 +86,7 @@ private[kafka010] case class CachedKafkaConsumer private( var toFetchOffset = offset while (toFetchOffset != UNKNOWN_OFFSET) { try { - return fetchData(toFetchOffset, 
pollTimeoutMs) + return fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss) } catch { case e: OffsetOutOfRangeException => // When there is some error thrown, it's better to use a new consumer to drop all cached @@ -159,14 +159,18 @@ private[kafka010] case class CachedKafkaConsumer private( } /** - * Get the record at `offset`. + * Get the record for the given offset if available. Otherwise it will either throw error + * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset), + * or null. * * @throws OffsetOutOfRangeException if `offset` is out of range * @throws TimeoutException if cannot fetch the record in `pollTimeoutMs` milliseconds. */ private def fetchData( offset: Long, - pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + untilOffset: Long, + pollTimeoutMs: Long, + failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) { // This is the first fetch, or the last pre-fetched data has been drained. // Seek to the offset because we may call seekToBeginning or seekToEnd before this. @@ -190,10 +194,31 @@ private[kafka010] case class CachedKafkaConsumer private( } else { val record = fetchedData.next() nextOffsetInFetchedData = record.offset + 1 - // `seek` is always called before "poll". So "record.offset" must be same as "offset". - assert(record.offset == offset, - s"The fetched data has a different offset: expected $offset but was ${record.offset}") - record + // In general, Kafka uses the specified offset as the start point, and tries to fetch the next + // available offset. Hence we need to handle offset mismatch. 
+ if (record.offset > offset) { + // This may happen when some records aged out but their offsets already got verified + if (failOnDataLoss) { + reportDataLoss(true, s"Cannot fetch records in [$offset, ${record.offset})") + // Never happen as "reportDataLoss" will throw an exception + null + } else { + if (record.offset >= untilOffset) { + reportDataLoss(false, s"Skip missing records in [$offset, $untilOffset)") + null + } else { + reportDataLoss(false, s"Skip missing records in [$offset, ${record.offset})") + record + } + } + } else if (record.offset < offset) { + // This should not happen. If it does happen, then we probably misunderstand Kafka internal + // mechanism. + throw new IllegalStateException( + s"Tried to fetch $offset but the returned record offset was ${record.offset}") + } else { + record + } } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 13d717092a..868edb5dcd 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -81,7 +81,14 @@ private object JsonUtils { */ def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { val result = new HashMap[String, HashMap[Int, Long]]() - partitionOffsets.foreach { case (tp, off) => + implicit val ordering = new Ordering[TopicPartition] { + override def compare(x: TopicPartition, y: TopicPartition): Int = { + Ordering.Tuple2[String, Int].compare((x.topic, x.partition), (y.topic, y.partition)) + } + } + val partitions = partitionOffsets.keySet.toSeq.sorted // sort for more determinism + partitions.foreach { tp => + val off = partitionOffsets(tp) val parts = result.getOrElse(tp.topic, new HashMap[Int, Long]) parts += tp.partition -> off result += tp.topic -> parts diff --git 
a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index d9ab4bb4f8..43b8d9d6d7 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer, OffsetOutOfRangeException} +import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetOutOfRangeException} import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener import org.apache.kafka.common.TopicPartition @@ -81,14 +81,16 @@ import org.apache.spark.util.UninterruptibleThread * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers * and not use wrong broker addresses. 
*/ -private[kafka010] case class KafkaSource( +private[kafka010] class KafkaSource( sqlContext: SQLContext, consumerStrategy: ConsumerStrategy, + driverKafkaParams: ju.Map[String, Object], executorKafkaParams: ju.Map[String, Object], sourceOptions: Map[String, String], metadataPath: String, startingOffsets: StartingOffsets, - failOnDataLoss: Boolean) + failOnDataLoss: Boolean, + driverGroupIdPrefix: String) extends Source with Logging { private val sc = sqlContext.sparkContext @@ -102,16 +104,36 @@ private[kafka010] case class KafkaSource( sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt private val offsetFetchAttemptIntervalMs = - sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "10").toLong + sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "1000").toLong private val maxOffsetsPerTrigger = sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong) + private var groupId: String = null + + private var nextId = 0 + + private def nextGroupId(): String = { + groupId = driverGroupIdPrefix + "-" + nextId + nextId += 1 + groupId + } + /** * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the * offsets and never commits them. */ - private val consumer = consumerStrategy.createConsumer() + private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer() + + /** + * Create a consumer using the new generated group id. We always use a new consumer to avoid + * just using a broken consumer to retry on Kafka errors, which likely will fail again. 
+ */ + private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized { + val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams) + newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId()) + consumerStrategy.createConsumer(newKafkaParams) + } /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only @@ -171,6 +193,11 @@ private[kafka010] case class KafkaSource( Some(KafkaSourceOffset(offsets)) } + private def resetConsumer(): Unit = synchronized { + consumer.close() + consumer = createConsumer() + } + /** Proportionally distribute limit number of offsets among topicpartitions */ private def rateLimit( limit: Long, @@ -441,13 +468,12 @@ private[kafka010] case class KafkaSource( try { result = Some(body) } catch { - case x: OffsetOutOfRangeException => - reportDataLoss(x.getMessage) case NonFatal(e) => lastException = e logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e) attempt += 1 Thread.sleep(offsetFetchAttemptIntervalMs) + resetConsumer() } } case _ => @@ -511,12 +537,12 @@ private[kafka010] object KafkaSource { )) sealed trait ConsumerStrategy { - def createConsumer(): Consumer[Array[Byte], Array[Byte]] + def createConsumer(kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] } - case class AssignStrategy(partitions: Array[TopicPartition], kafkaParams: ju.Map[String, Object]) - extends ConsumerStrategy { - override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + case class AssignStrategy(partitions: Array[TopicPartition]) extends ConsumerStrategy { + override def createConsumer( + kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = { val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) consumer.assign(ju.Arrays.asList(partitions: _*)) consumer @@ -525,9 +551,9 @@ private[kafka010] object KafkaSource { override def toString: String = s"Assign[${partitions.mkString(", ")}]" } - case class 
SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object]) - extends ConsumerStrategy { - override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + case class SubscribeStrategy(topics: Seq[String]) extends ConsumerStrategy { + override def createConsumer( + kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = { val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) consumer.subscribe(topics.asJava) consumer @@ -536,10 +562,10 @@ private[kafka010] object KafkaSource { override def toString: String = s"Subscribe[${topics.mkString(", ")}]" } - case class SubscribePatternStrategy( - topicPattern: String, kafkaParams: ju.Map[String, Object]) + case class SubscribePatternStrategy(topicPattern: String) extends ConsumerStrategy { - override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + override def createConsumer( + kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = { val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) consumer.subscribe( ju.regex.Pattern.compile(topicPattern), diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 585ced875c..aa01238f91 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -85,14 +85,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider case None => LatestOffsets } - val kafkaParamsForStrategy = + val kafkaParamsForDriver = ConfigUpdater("source", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - // So that consumers in Kafka source do not mess with any existing group 
id - .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver") - // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial // offsets by itself instead of counting on KafkaConsumer. .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") @@ -129,17 +126,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => - AssignStrategy( - JsonUtils.partitions(value), - kafkaParamsForStrategy) + AssignStrategy(JsonUtils.partitions(value)) case ("subscribe", value) => - SubscribeStrategy( - value.split(",").map(_.trim()).filter(_.nonEmpty), - kafkaParamsForStrategy) + SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty)) case ("subscribepattern", value) => - SubscribePatternStrategy( - value.trim(), - kafkaParamsForStrategy) + SubscribePatternStrategy(value.trim()) case _ => // Should never reach here as we are already matching on // matched strategy names @@ -152,11 +143,13 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider new KafkaSource( sqlContext, strategy, + kafkaParamsForDriver, kafkaParamsForExecutors, parameters, metadataPath, startingOffsets, - failOnDataLoss) + failOnDataLoss, + driverGroupIdPrefix = s"$uniqueGroupId-driver") } private def validateOptions(parameters: Map[String, String]): Unit = { diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-offset-version-2.1.0.txt b/external/kafka-0-10-sql/src/test/resources/kafka-source-offset-version-2.1.0.txt new file mode 100644 index 0000000000..6410031743 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-offset-version-2.1.0.txt @@ -0,0 +1 @@ +{"topic1":{"0":456,"1":789},"topic2":{"0":0}} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala index 881018fd95..10b35c74f4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala @@ -89,4 +89,17 @@ class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext { Array(0 -> batch0Serialized, 1 -> batch1Serialized)) } } + + test("read Spark 2.1.0 offset format") { + val offset = readFromResource("kafka-source-offset-version-2.1.0.txt") + assert(KafkaSourceOffset(offset) === + KafkaSourceOffset(("topic1", 0, 456L), ("topic1", 1, 789L), ("topic2", 0, 0L))) + } + + private def readFromResource(file: String): SerializedOffset = { + import scala.io.Source + val input = getClass.getResource(s"/$file").toURI + val str = Source.fromFile(input).mkString + SerializedOffset(str) + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index 2d6ccb22dd..544fbc5ec3 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -31,11 +31,12 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkContext import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession} abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { 
@@ -447,7 +448,7 @@ class KafkaSourceSuite extends KafkaSourceTest { AddKafkaData(Set(topic), 1, 2, 3), CheckAnswer(2, 3, 4), AssertOnQuery { query => - val recordsRead = query.recentProgresses.map(_.numInputRows).sum + val recordsRead = query.recentProgress.map(_.numInputRows).sum recordsRead == 3 } ) @@ -811,6 +812,11 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}" + override def createSparkSession(): TestSparkSession = { + // Set maxRetries to 3 to handle NPE from `poll` when deleting a topic + new TestSparkSession(new SparkContext("local[2,3]", "test-sql-context", sparkConf)) + } + override def beforeAll(): Unit = { super.beforeAll() testUtils = new KafkaTestUtils { @@ -839,7 +845,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared } } - ignore("stress test for failOnDataLoss=false") { + test("stress test for failOnDataLoss=false") { val reader = spark .readStream .format("kafka") @@ -848,6 +854,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared .option("subscribePattern", "failOnDataLoss.*") .option("startingOffsets", "earliest") .option("failOnDataLoss", "false") + .option("fetchOffset.retryIntervalMs", "3000") val kafka = reader.load() .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") .as[(String, String)] diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index f43917e151..fd1689acf6 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -184,7 +184,7 @@ class KafkaTestUtils extends Logging { def deleteTopic(topic: String): Unit = { val partitions = 
zkUtils.getPartitionsForTopics(Seq(topic))(topic).size AdminUtils.deleteTopic(zkUtils, topic) - verifyTopicDeletion(zkUtils, topic, partitions, List(this.server)) + verifyTopicDeletionWithRetries(zkUtils, topic, partitions, List(this.server)) } /** Add new paritions to a Kafka topic */ @@ -286,36 +286,57 @@ class KafkaTestUtils extends Logging { props } + /** Verify topic is deleted in all places, e.g, brokers, zookeeper. */ private def verifyTopicDeletion( + topic: String, + numPartitions: Int, + servers: Seq[KafkaServer]): Unit = { + val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + + import ZkUtils._ + // wait until admin path for delete topic is deleted, signaling completion of topic deletion + assert( + !zkUtils.pathExists(getDeleteTopicPath(topic)), + s"${getDeleteTopicPath(topic)} still exists") + assert(!zkUtils.pathExists(getTopicPath(topic)), s"${getTopicPath(topic)} still exists") + // ensure that the topic-partition has been deleted from all brokers' replica managers + assert(servers.forall(server => topicAndPartitions.forall(tp => + server.replicaManager.getPartition(tp.topic, tp.partition) == None)), + s"topic $topic still exists in the replica manager") + // ensure that logs from all replicas are deleted if delete topic is marked successful + assert(servers.forall(server => topicAndPartitions.forall(tp => + server.getLogManager().getLog(tp).isEmpty)), + s"topic $topic still exists in log mananger") + // ensure that topic is removed from all cleaner offsets + assert(servers.forall(server => topicAndPartitions.forall { tp => + val checkpoints = server.getLogManager().logDirs.map { logDir => + new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + } + checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) + }), s"checkpoint for topic $topic still exists") + // ensure the topic is gone + assert( + !zkUtils.getAllTopics().contains(topic), + s"topic $topic still exists on 
zookeeper") + } + + /** Verify topic is deleted. Retry to delete the topic if not. */ + private def verifyTopicDeletionWithRetries( zkUtils: ZkUtils, topic: String, numPartitions: Int, servers: Seq[KafkaServer]) { - import ZkUtils._ - val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) - def isDeleted(): Boolean = { - // wait until admin path for delete topic is deleted, signaling completion of topic deletion - val deletePath = !zkUtils.pathExists(getDeleteTopicPath(topic)) - val topicPath = !zkUtils.pathExists(getTopicPath(topic)) - // ensure that the topic-partition has been deleted from all brokers' replica managers - val replicaManager = servers.forall(server => topicAndPartitions.forall(tp => - server.replicaManager.getPartition(tp.topic, tp.partition) == None)) - // ensure that logs from all replicas are deleted if delete topic is marked successful - val logManager = servers.forall(server => topicAndPartitions.forall(tp => - server.getLogManager().getLog(tp).isEmpty)) - // ensure that topic is removed from all cleaner offsets - val cleaner = servers.forall(server => topicAndPartitions.forall { tp => - val checkpoints = server.getLogManager().logDirs.map { logDir => - new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() - } - checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) - }) - // ensure the topic is gone - val deleted = !zkUtils.getAllTopics().contains(topic) - deletePath && topicPath && replicaManager && logManager && cleaner && deleted - } - eventually(timeout(60.seconds)) { - assert(isDeleted, s"$topic not deleted after timeout") + eventually(timeout(60.seconds), interval(200.millis)) { + try { + verifyTopicDeletion(topic, numPartitions, servers) + } catch { + case e: Throwable => + // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small + // chance that a topic will be recreated after deletion due to the asynchronous update. 
+ // Hence, delete the topic and retry. + AdminUtils.deleteTopic(zkUtils, topic) + throw e + } } } @@ -331,7 +352,7 @@ class KafkaTestUtils extends Logging { case _ => false } - eventually(timeout(10.seconds)) { + eventually(timeout(60.seconds)) { assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") } } diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index c36d479007..88499240cd 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -89,6 +89,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index bc02b8a662..3fedd9eda1 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 91ccd4a927..8368a1f122 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -89,6 +89,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index f7cb764463..90bb0e4987 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 
57809ff692..b2bac7c938 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -78,6 +78,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 858368d135..393e56a393 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -221,6 +221,12 @@ private[kinesis] class KinesisReceiver[T]( } } + /** Return the current rate limit defined in [[BlockGenerator]]. */ + private[kinesis] def getCurrentLimit: Int = { + assert(blockGenerator != null) + math.min(blockGenerator.getCurrentLimit, Int.MaxValue).toInt + } + /** Get the latest sequence number for the given shard that can be checkpointed through KCL */ private[kinesis] def getLatestSeqNumToCheckpoint(shardId: String): Option[String] = { Option(shardIdToLatestStoredSeqNum.get(shardId)) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index a0ccd086d9..73ccc4ad23 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -68,8 +68,18 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w override def processRecords(batch: List[Record], checkpointer: 
IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { - receiver.addRecords(shardId, batch) - logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + // Limit the number of processed records from Kinesis stream. This is because the KCL cannot + // control the number of aggregated records to be fetched even if we set `MaxRecords` + // in `KinesisClientLibConfiguration`. For example, if we set 10 to the number of max + // records in a worker and a producer aggregates two records into one message, the worker + // may receive 20 records each time the callback function is called. + val maxRecords = receiver.getCurrentLimit + for (start <- 0 until batch.size by maxRecords) { + val miniBatch = batch.subList(start, math.min(start + maxRecords, batch.size)) + receiver.addRecords(shardId, miniBatch) + logDebug(s"Stored: Worker $workerId stored ${miniBatch.size} records " + + s"for shardId $shardId") + } receiver.setCheckpointer(shardId, checkpointer) } catch { case NonFatal(e) => diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index deac9090e2..800502a77d 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -69,6 +69,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) @@ -79,8 +80,23 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft verify(receiverMock, 
times(1)).setCheckpointer(shardId, checkpointerMock) } + test("split into multiple processes if a limitation is set") { + when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.getCurrentLimit).thenReturn(1) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) + recordProcessor.initialize(shardId) + recordProcessor.processRecords(batch, checkpointerMock) + + verify(receiverMock, times(1)).isStopped() + verify(receiverMock, times(1)).addRecords(shardId, batch.subList(0, 1)) + verify(receiverMock, times(1)).addRecords(shardId, batch.subList(1, 2)) + verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock) + } + test("shouldn't store and update checkpointer when receiver is stopped") { when(receiverMock.isStopped()).thenReturn(true) + when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.processRecords(batch, checkpointerMock) @@ -92,6 +108,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft test("shouldn't update checkpointer when exception occurs during store") { when(receiverMock.isStopped()).thenReturn(false) + when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) when( receiverMock.addRecords(anyString, anyListOf(classOf[Record])) ).thenThrow(new RuntimeException()) diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index fab409d3e9..7da27817eb 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 10d5ba93eb..8df33660ea 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml @@ -78,6 +78,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + 
org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + target/scala-${scala.binary.version}/classes diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index feb3f47667..37b6e45359 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -115,9 +115,9 @@ object PageRank extends Logging { val src: VertexId = srcId.getOrElse(-1L) // Initialize the PageRank graph with each edge attribute having - // weight 1/outDegree and each vertex with attribute resetProb. + // weight 1/outDegree and each vertex with attribute 1.0. // When running personalized pagerank, only the source vertex - // has an attribute resetProb. All others are set to 0. + // has an attribute 1.0. All others are set to 0. var rankGraph: Graph[Double, Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } @@ -125,7 +125,7 @@ object PageRank extends Logging { .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src ) // Set the vertex attributes to the initial pagerank values .mapVertices { (id, attr) => - if (!(id != src && personalized)) resetProb else 0.0 + if (!(id != src && personalized)) 1.0 else 0.0 } def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } @@ -150,8 +150,8 @@ object PageRank extends Logging { (src: VertexId, id: VertexId) => resetProb } - rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum + rankGraph = rankGraph.outerJoinVertices(rankUpdates) { + (id, oldRank, msgSumOpt) => rPrb(src, id) + (1.0 - resetProb) * msgSumOpt.getOrElse(0.0) }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -196,7 +196,7 @@ object PageRank extends Logging { // we won't be able to store its 
activations in a sparse vector val zero = Vectors.sparse(sources.size, List()).asBreeze val sourcesInitMap = sources.zipWithIndex.map { case (vid, i) => - val v = Vectors.sparse(sources.size, Array(i), Array(resetProb)).asBreeze + val v = Vectors.sparse(sources.size, Array(i), Array(1.0)).asBreeze (vid, v) }.toMap val sc = graph.vertices.sparkContext @@ -225,11 +225,11 @@ object PageRank extends Logging { ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr), (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src) - rankGraph = rankGraph.joinVertices(rankUpdates) { - (vid, oldRank, msgSum) => - val popActivations: BV[Double] = msgSum :* (1.0 - resetProb) + rankGraph = rankGraph.outerJoinVertices(rankUpdates) { + (vid, oldRank, msgSumOpt) => + val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) :* (1.0 - resetProb) val resetActivations = if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) + sourcesInitMapBC.value(vid) :* resetProb } else { zero } @@ -307,7 +307,7 @@ object PageRank extends Logging { .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to (initialPR, delta = 0) .mapVertices { (id, attr) => - if (id == src) (resetProb, Double.NegativeInfinity) else (0.0, 0.0) + if (id == src) (1.0, Double.NegativeInfinity) else (0.0, 0.0) } .cache() @@ -323,7 +323,7 @@ object PageRank extends Logging { msgSum: Double): (Double, Double) = { val (oldPR, lastDelta) = attr var teleport = oldPR - val delta = if (src==id) 1.0 else 0.0 + val delta = if (src==id) resetProb else 0.0 teleport = oldPR*delta val newPR = teleport + (1.0 - resetProb) * msgSum diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index b6305c8d00..6afbb5a959 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -41,7 +41,7 @@ object GridPageRank { } } // 
compute the pagerank - var pr = Array.fill(nRows * nCols)(resetProb) + var pr = Array.fill(nRows * nCols)(1.0) for (iter <- 0 until nIter) { val oldPr = pr pr = new Array[Double](nRows * nCols) @@ -70,10 +70,10 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPageRank(numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPageRank(numIter = 2, resetProb).vertices.cache() + val staticRanks1 = starGraph.staticPageRank(numIter = 2, resetProb).vertices + val staticRanks2 = starGraph.staticPageRank(numIter = 3, resetProb).vertices.cache() - // Static PageRank should only take 2 iterations to converge + // Static PageRank should only take 3 iterations to converge val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => if (pr1 != pr2) 1 else 0 }.map { case (vid, test) => test }.sum() @@ -203,4 +203,30 @@ class PageRankSuite extends SparkFunSuite with LocalSparkContext { assert(compareRanks(staticRanks, parallelStaticRanks) < errorTol) } } + + test("Loop with source PageRank") { + withSpark { sc => + val edges = sc.parallelize((1L, 2L) :: (2L, 3L) :: (3L, 4L) :: (4L, 2L) :: Nil) + val g = Graph.fromEdgeTuples(edges, 1) + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 50 + val errorTol = 1.0e-5 + + val staticRanks = g.staticPageRank(numIter, resetProb).vertices + val dynamicRanks = g.pageRank(tol, resetProb).vertices + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + + // Computed in igraph 1.0 w/ R bindings: + // > page_rank(graph_from_literal( A -+ B -+ C -+ D -+ B)) + // Alternatively in NetworkX 1.11: + // > nx.pagerank(nx.DiGraph([(1,2),(2,3),(3,4),(4,2)])) + // We multiply by the number of vertices to account for difference in normalization + val igraphPR = Seq(0.0375000, 0.3326045, 0.3202138, 0.3096817).map(_ * 4) + val ranks = VertexRDD(sc.parallelize(1L to 4L zip igraphPR)) + assert(compareRanks(staticRanks, 
ranks) < errorTol) + assert(compareRanks(dynamicRanks, ranks) < errorTol) + + } + } } diff --git a/launcher/pom.xml b/launcher/pom.xml index 6023cf0771..025cd84f20 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml @@ -67,6 +67,17 @@ spark-tags_${scala.binary.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.hadoop diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index c7488082ca..0622fef17c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -26,9 +26,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.Set; import java.util.regex.Pattern; import static org.apache.spark.launcher.CommandBuilderUtils.*; @@ -135,7 +137,7 @@ void addOptionString(List cmd, String options) { List buildClassPath(String appClassPath) throws IOException { String sparkHome = getSparkHome(); - List cp = new ArrayList<>(); + Set cp = new LinkedHashSet<>(); addToClassPath(cp, getenv("SPARK_CLASSPATH")); addToClassPath(cp, appClassPath); @@ -158,12 +160,13 @@ List buildClassPath(String appClassPath) throws IOException { "launcher", "mllib", "repl", + "resource-managers/mesos", + "resource-managers/yarn", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "streaming", - "yarn" + "streaming" ); if (prependClasses) { if (!isTesting) { @@ -200,7 +203,7 @@ List buildClassPath(String appClassPath) throws IOException { addToClassPath(cp, getenv("HADOOP_CONF_DIR")); addToClassPath(cp, getenv("YARN_CONF_DIR")); addToClassPath(cp, 
getenv("SPARK_DIST_CLASSPATH")); - return cp; + return new ArrayList<>(cp); } /** @@ -209,7 +212,7 @@ List buildClassPath(String appClassPath) throws IOException { * @param cp List to which the new entries are appended. * @param entries New classpath entries (separated by File.pathSeparator). */ - private void addToClassPath(List cp, String entries) { + private void addToClassPath(Set cp, String entries) { if (isEmpty(entries)) { return; } diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 8c985fd13a..663f7fb0b0 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml @@ -56,6 +56,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + diff --git a/mllib/pom.xml b/mllib/pom.xml index 4484998a49..82f840b0fc 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml @@ -113,6 +113,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt index d075cc0bab..d6094d774a 100644 --- a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt @@ -125,29 +125,57 @@ just don should now -d -ll -m -o -re -ve -y -ain -aren -couldn -didn -doesn -hadn -hasn -haven -isn -ma -mightn -mustn -needn -shan -shouldn -wasn -weren -won -wouldn +i'll +you'll +he'll +she'll +we'll +they'll +i'd +you'd +he'd +she'd +we'd +they'd +i'm +you're +he's +she's +it's +we're +they're +i've +we've +you've +they've +isn't +aren't +wasn't +weren't +haven't +hasn't +hadn't +don't 
+doesn't +didn't +won't +wouldn't +shan't +shouldn't +mustn't +can't +couldn't +cannot +could +here's +how's +let's +ought +that's +there's +what's +when's +where's +who's +why's +would \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 7e0bc19a7a..9f60f0896e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -81,7 +81,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( * E.g. 10 means that the cache will get checkpointed every 10 iterations. * This is only used if cacheNodeIds is true and if the checkpoint directory is set in * [[org.apache.spark.SparkContext]]. - * Must be >= 1. + * Must be at least 1. * (default = 10) * @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c5fc3c8772..c99b63b25d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -98,7 +98,7 @@ class GBTClassifier @Since("1.4.0") ( * E.g. 10 means that the cache will get checkpointed every 10 iterations. * This is only used if cacheNodeIds is true and if the checkpoint directory is set in * [[org.apache.spark.SparkContext]]. - * Must be >= 1. + * Must be at least 1. 
* (default = 10) * @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f4ab0a074c..e58b30d665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -140,6 +140,14 @@ final class OneVsRestModel private[ml] ( this(uid, Metadata.empty, models.asScala.toArray) } + /** @group setParam */ + @Since("2.1.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.1.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) @@ -175,6 +183,7 @@ final class OneVsRestModel private[ml] ( val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => predictions + ((index, prediction(1))) } + model.setFeaturesCol($(featuresCol)) val transformedDataset = model.transform(df).select(columns: _*) val updatedDataset = transformedDataset .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 34c055dce6..5bbaafeff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -83,7 +83,7 @@ class RandomForestClassifier @Since("1.4.0") ( * E.g. 10 means that the cache will get checkpointed every 10 iterations. * This is only used if cacheNodeIds is true and if the checkpoint directory is set in * [[org.apache.spark.SparkContext]]. - * Must be >= 1. 
+ * Must be at least 1. * (default = 10) * @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index e168a418cb..e02b532ca8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -302,22 +302,19 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE - fit(dataset, handlePersistence) - } - - @Since("2.2.0") - protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = { transformSchema(dataset.schema, logging = true) + + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => OldVectors.fromML(point) } + if (handlePersistence) { instances.persist(StorageLevel.MEMORY_AND_DISK) } + val instr = Instrumentation.create(this, instances) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) - val algo = new MLlibKMeans() .setK($(k)) .setInitializationMode($(initMode)) @@ -329,6 +326,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) + model.setSummary(Some(summary)) instr.logSuccess(model) if (handlePersistence) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index eb4d42f255..d1f3b2af1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -78,9 +78,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override 
val uid: String def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Param for how to handle invalid entries. Options are skip (filter out rows with - * invalid values), error (throw an error), or keep (keep invalid values in a special additional - * bucket). + * Param for how to handle invalid entries. Options are 'skip' (filter out rows with + * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special + * additional bucket). * Default: "error" * @group param */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index b4fcfa2da4..80c7f55e26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -66,9 +66,9 @@ private[feature] trait QuantileDiscretizerBase extends Params def getRelativeError: Double = getOrDefault(relativeError) /** - * Param for how to handle invalid entries. Options are skip (filter out rows with - * invalid values), error (throw an error), or keep (keep invalid values in a special additional - * bucket). + * Param for how to handle invalid entries. Options are 'skip' (filter out rows with + * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special + * additional bucket). 
* Default: "error" * @group param */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 8bcc9fe5d1..78f401f29b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -23,17 +23,12 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} -import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.ml.regression._ -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.r.RWrapperUtils._ +import org.apache.spark.ml.regression._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ private[r] class GeneralizedLinearRegressionWrapper private ( val pipeline: PipelineModel, @@ -48,8 +43,6 @@ private[r] class GeneralizedLinearRegressionWrapper private ( val rNumIterations: Int, val isLoaded: Boolean = false) extends MLWritable { - import GeneralizedLinearRegressionWrapper._ - private val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] @@ -60,16 +53,7 @@ private[r] class GeneralizedLinearRegressionWrapper private ( def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType) def transform(dataset: Dataset[_]): DataFrame = { - if (rFamily == "binomial") { - pipeline.transform(dataset) - .drop(PREDICTED_LABEL_PROB_COL) - .drop(PREDICTED_LABEL_INDEX_COL) - 
.drop(glm.getFeaturesCol) - .drop(glm.getLabelCol) - } else { - pipeline.transform(dataset) - .drop(glm.getFeaturesCol) - } + pipeline.transform(dataset).drop(glm.getFeaturesCol) } override def write: MLWriter = @@ -79,10 +63,6 @@ private[r] class GeneralizedLinearRegressionWrapper private ( private[r] object GeneralizedLinearRegressionWrapper extends MLReadable[GeneralizedLinearRegressionWrapper] { - val PREDICTED_LABEL_PROB_COL = "pred_label_prob" - val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" - val PREDICTED_LABEL_COL = "prediction" - def fit( formula: String, data: DataFrame, @@ -93,7 +73,6 @@ private[r] object GeneralizedLinearRegressionWrapper weightCol: String, regParam: Double): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula().setFormula(formula) - if (family == "binomial") rFormula.setForceIndexLabel(true) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema @@ -111,28 +90,9 @@ private[r] object GeneralizedLinearRegressionWrapper .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - .setLabelCol(rFormula.getLabelCol) - val pipeline = if (family == "binomial") { - // Convert prediction from probability to label index. - val probToPred = new ProbabilityToPrediction() - .setInputCol(PREDICTED_LABEL_PROB_COL) - .setOutputCol(PREDICTED_LABEL_INDEX_COL) - // Convert prediction from label index to original label. 
- val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get - val idxToStr = new IndexToString() - .setInputCol(PREDICTED_LABEL_INDEX_COL) - .setOutputCol(PREDICTED_LABEL_COL) - .setLabels(labels) - - new Pipeline() - .setStages(Array(rFormulaModel, glr.setPredictionCol(PREDICTED_LABEL_PROB_COL), - probToPred, idxToStr)) - .fit(data) - } else { - new Pipeline().setStages(Array(rFormulaModel, glr)).fit(data) - } + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, glr)) + .fit(data) val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] @@ -248,27 +208,3 @@ private[r] object GeneralizedLinearRegressionWrapper } } } - -/** - * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel - * with "binomial" family from probability to prediction according to threshold 0.5. - */ -private[r] class ProbabilityToPrediction private[r] (override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { - - def this() = this(Identifiable.randomUID("probToPred")) - - def setInputCol(value: String): this.type = set(inputCol, value) - - def setOutputCol(value: String): this.type = set(outputCol, value) - - override def transformSchema(schema: StructType): StructType = { - StructType(schema.fields :+ StructField($(outputCol), DoubleType)) - } - - override def transform(dataset: Dataset[_]): DataFrame = { - dataset.withColumn($(outputCol), round(col($(inputCol)))) - } - - override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra) -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 9fe6202980..645bc7247f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -23,8 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,38 +33,48 @@ import org.apache.spark.sql.{DataFrame, Dataset} private[r] class LogisticRegressionWrapper private ( val pipeline: PipelineModel, val features: Array[String], - val isLoaded: Boolean = false) extends MLWritable { + val labels: Array[String]) extends MLWritable { import LogisticRegressionWrapper._ - private val logisticRegressionModel: LogisticRegressionModel = + private val lrModel: LogisticRegressionModel = pipeline.stages(1).asInstanceOf[LogisticRegressionModel] - lazy val totalIterations: Int = logisticRegressionModel.summary.totalIterations - - lazy val objectiveHistory: Array[Double] = logisticRegressionModel.summary.objectiveHistory - - lazy val blrSummary = - logisticRegressionModel.summary.asInstanceOf[BinaryLogisticRegressionSummary] - - lazy val roc: DataFrame = blrSummary.roc - - lazy val areaUnderROC: Double = blrSummary.areaUnderROC - - lazy val pr: DataFrame = blrSummary.pr - - lazy val fMeasureByThreshold: DataFrame = blrSummary.fMeasureByThreshold - - lazy val precisionByThreshold: DataFrame = blrSummary.precisionByThreshold + val rFeatures: Array[String] = if (lrModel.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } - lazy val recallByThreshold: DataFrame = blrSummary.recallByThreshold + val rCoefficients: Array[Double] = { + val numRows = 
lrModel.coefficientMatrix.numRows + val numCols = lrModel.coefficientMatrix.numCols + val numColsWithIntercept = if (lrModel.getFitIntercept) numCols + 1 else numCols + val coefficients: Array[Double] = new Array[Double](numRows * numColsWithIntercept) + val coefficientVectors: Seq[Vector] = lrModel.coefficientMatrix.rowIter.toSeq + var i = 0 + if (lrModel.getFitIntercept) { + while (i < numRows) { + coefficients(i * numColsWithIntercept) = lrModel.interceptVector(i) + System.arraycopy(coefficientVectors(i).toArray, 0, + coefficients, i * numColsWithIntercept + 1, numCols) + i += 1 + } + } else { + while (i < numRows) { + System.arraycopy(coefficientVectors(i).toArray, 0, + coefficients, i * numColsWithIntercept, numCols) + i += 1 + } + } + coefficients + } def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) - .drop(logisticRegressionModel.getFeaturesCol) - .drop(logisticRegressionModel.getLabelCol) - + .drop(lrModel.getFeaturesCol) + .drop(lrModel.getLabelCol) } override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this) @@ -85,9 +96,7 @@ private[r] object LogisticRegressionWrapper family: String, standardization: Boolean, thresholds: Array[Double], - weightCol: String, - aggregationDepth: Int, - probability: String + weightCol: String ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -102,7 +111,7 @@ private[r] object LogisticRegressionWrapper val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline - val logisticRegression = new LogisticRegression() + val lr = new LogisticRegression() .setRegParam(regParam) .setElasticNetParam(elasticNetParam) .setMaxIter(maxIter) @@ -111,16 +120,14 @@ private[r] object LogisticRegressionWrapper .setFamily(family) .setStandardization(standardization) .setWeightCol(weightCol) - .setAggregationDepth(aggregationDepth) .setFeaturesCol(rFormula.getFeaturesCol) 
.setLabelCol(rFormula.getLabelCol) - .setProbabilityCol(probability) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (thresholds.length > 1) { - logisticRegression.setThresholds(thresholds) + lr.setThresholds(thresholds) } else { - logisticRegression.setThreshold(thresholds(0)) + lr.setThreshold(thresholds(0)) } val idxToStr = new IndexToString() @@ -129,10 +136,10 @@ private[r] object LogisticRegressionWrapper .setLabels(labels) val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, logisticRegression, idxToStr)) + .setStages(Array(rFormulaModel, lr, idxToStr)) .fit(data) - new LogisticRegressionWrapper(pipeline, features) + new LogisticRegressionWrapper(pipeline, features, labels) } override def read: MLReader[LogisticRegressionWrapper] = new LogisticRegressionWrapperReader @@ -146,7 +153,8 @@ private[r] object LogisticRegressionWrapper val pipelinePath = new Path(path, "pipeline").toString val rMetadata = ("class" -> instance.getClass.getName) ~ - ("features" -> instance.features.toSeq) + ("features" -> instance.features.toSeq) ~ + ("labels" -> instance.labels.toSeq) val rMetadataJson: String = compact(render(rMetadata)) sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) @@ -164,9 +172,10 @@ private[r] object LogisticRegressionWrapper val rMetadataStr = sc.textFile(rMetadataPath, 1).first() val rMetadata = parse(rMetadataStr) val features = (rMetadata \ "features").extract[Array[String]] + val labels = (rMetadata \ "labels").extract[Array[String]] val pipeline = PipelineModel.load(pipelinePath) - new LogisticRegressionWrapper(pipeline, features, isLoaded = true) + new LogisticRegressionWrapper(pipeline, features, labels) } } -} \ No newline at end of file +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index 0b860e5af9..366f375b58 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -76,7 +76,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC featureSubsetStrategy: String, seed: String, subsamplingRate: Double, - probabilityCol: String, maxMemoryInMB: Int, cacheNodeIds: Boolean): RandomForestClassifierWrapper = { @@ -102,7 +101,6 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setSubsamplingRate(subsamplingRate) .setMaxMemoryInMB(maxMemoryInMB) .setCacheNodeIds(cacheNodeIds) - .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0cdfa7b0b7..01c5cc1c7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -80,7 +80,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * E.g. 10 means that the cache will get checkpointed every 10 iterations. * This is only used if cacheNodeIds is true and if the checkpoint directory is set in * [[org.apache.spark.SparkContext]]. - * Must be >= 1. + * Must be at least 1. 
* (default = 10) * @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 49a3f8b6b5..f8ab3d3a45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -95,7 +95,7 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) * E.g. 10 means that the cache will get checkpointed every 10 iterations. * This is only used if cacheNodeIds is true and if the checkpoint directory is set in * [[org.apache.spark.SparkContext]]. - * Must be >= 1. + * Must be at least 1. * (default = 10) * @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 770a2571bb..3891ae63a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -215,6 +215,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val * Sets the value of param [[weightCol]]. * If this is not set or empty, we treat all instance weights as 1.0. * Default is not set, so all instances have weight one. + * In the Binomial family, weights correspond to number of trials and should be integer. + * Non-integer weights are rounded to integer in AIC calculation. 
* * @group setParam */ @@ -467,10 +469,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def variance(mu: Double): Double = mu * (1.0 - mu) + private def ylogy(y: Double, mu: Double): Double = { + if (y == 0) 0.0 else y * math.log(y / mu) + } + override def deviance(y: Double, mu: Double, weight: Double): Double = { - val my = 1.0 - y - 2.0 * weight * (y * math.log(math.max(y, 1.0) / mu) + - my * math.log(math.max(my, 1.0) / (1.0 - mu))) + 2.0 * weight * (ylogy(y, mu) + ylogy(1.0 - y, 1.0 - mu)) } override def aic( @@ -479,7 +483,13 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine numInstances: Double, weightSum: Double): Double = { -2.0 * predictions.map { case (y: Double, mu: Double, weight: Double) => - weight * dist.Binomial(1, mu).logProbabilityOf(math.round(y).toInt) + // weights for Binomial distribution correspond to number of trials + val wt = math.round(weight).toInt + if (wt == 0) { + 0.0 + } else { + dist.Binomial(wt, mu).logProbabilityOf(math.round(y * weight).toInt) + } }.sum() } @@ -505,7 +515,11 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def initialize(y: Double, weight: Double): Double = { require(y >= 0.0, "The response variable of Poisson family " + s"should be non-negative, but got $y") - y + /* + Force Poisson mean > 0 to avoid numerical instability in IRLS. + R uses y + 0.1 for initialization. See poisson()$initialize. 
+ */ + math.max(y, 0.1) } override def variance(mu: Double): Double = mu diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 67fb648625..ca4a50b825 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -82,7 +82,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S * E.g. 10 means that the cache will get checkpointed every 10 iterations. * This is only used if cacheNodeIds is true and if the checkpoint directory is set in * [[org.apache.spark.SparkContext]]. - * Must be >= 1. + * Must be at least 1. * (default = 10) * @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index ff81a2f03e..e178ac0db9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -18,10 +18,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.collection.mutable.ArrayBuffer +import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, SparseVector => BSV, Vector => BV} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Matrix => BM, +SparseVector => BSV, Vector => BV} -import org.apache.spark.{Partitioner, SparkException} +import org.apache.spark.{Partitioner, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg._ @@ -335,7 +337,10 @@ class BlockMatrix @Since("1.3.0") ( new BlockMatrix(transposedBlocks, colsPerBlock, rowsPerBlock, nCols, nRows) } - /** Collects data and assembles a 
local dense breeze matrix (for test only). */ + /** + * Convert distributed storage of BlockMatrix into locally stored BDM, whereas + * asBreeze works on matrices stored locally and requires no memcopy. + */ private[mllib] def toBreeze(): BDM[Double] = { val localMat = toLocalMatrix() new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) @@ -497,4 +502,264 @@ class BlockMatrix @Since("1.3.0") ( s"A.colsPerBlock: $colsPerBlock, B.rowsPerBlock: ${other.rowsPerBlock}") } } + + /** + * Computes the randomized singular value decomposition of this BlockMatrix. + * Denote this matrix by A (m x n), this will compute matrices U, S, V such + * that A ~ U * S * V', where the columns of U are orthonormal, S is a + * diagonal matrix with non-negative real numbers on the diagonal, and the + * columns of V are orthonormal. + * + * At most k largest non-zero singular values and associated vectors are + * returned. If there are k such values, then the dimensions of the return + * will be: + * - U is a BlockMatrix of size m x k that satisfies U' * U = eye(k), + * - s is a Vector of size k, holding the singular values in + * descending order, + * - V is a BlockMatrix of size n x k that satisfies V' * V = eye(k). + * + * @param k number of singular values to keep. We might return less than k + * if there are numerically zero singular values. + * @param sc SparkContext, used to generate the random gaussian matrix. + * @param computeU whether to compute U. + * @param isGram whether to compute the Gram matrix for matrix + * orthonormalization. + * @param iteration number of normalized power iterations to conduct. + * @param isRandom if true, seed the random matrix generator with a fixed seed. + * @return SingularValueDecomposition(U, s, V). + * + * @note if isGram is true, it will lose half or more of the precision + * of the arithmetic but could accelerate the computation. 
+ */ + @Since("2.0.0") + def partialSVD(k: Int, sc: SparkContext, computeU: Boolean = false, + isGram: Boolean = false, iteration: Int = 2, + isRandom: Boolean = true): + SingularValueDecomposition[BlockMatrix, BlockMatrix] = { + + // Form and store the transpose At. + val At = transpose + + /** + * Generate a random [[BlockMatrix]] to compute the singular + * value decomposition. + * + * @param k number of columns in this random [[BlockMatrix]]. + * @param sc a [[SparkContext]] to generate the random [[BlockMatrix]]. + * @return a random [[BlockMatrix]]. + * + * @note the generated random [[BlockMatrix]] V has colsPerBlock for the + * number of rows in each block, and rowsPerBlock for the number + * of columns in each block. We will perform matrix multiplication + * with A, i.e., A * V. We want the number of rows in each block + * of V to be same as the number of columns in each block of A. + */ + def generateRandomMatrices(k: Int, sc: SparkContext): BlockMatrix = { + val rowPartitions = math.ceil(numCols().toInt * 1.0 / colsPerBlock).toInt + val colPartitions = math.ceil(k * 1.0 / rowsPerBlock).toInt + val lastColBlock = k % rowsPerBlock + val lastRowBlock = numCols().toInt % colsPerBlock + + val numPartitions = rowPartitions * colPartitions + + val data = sc.parallelize(0 until numPartitions, numPartitions).persist(). + mapPartitionsWithIndex{(idx, iter) => + // Whether to set the random seed or not. 
+ val random = if (isRandom) new Random(158342769L + idx.toLong) + else new Random + iter.map(i => ((i/colPartitions, i%colPartitions), { + val (p, q) = { + if (i % colPartitions == colPartitions - 1 && + i / colPartitions == rowPartitions - 1 && lastColBlock > 0 + && lastRowBlock > 0) { + // last columnBlock and last RowBlock + (lastRowBlock, lastColBlock) + } else if (i % colPartitions == colPartitions - 1 + && lastColBlock > 0) { + // last columnBlock + (colsPerBlock, lastColBlock) + } else if (i / colPartitions == rowPartitions - 1 + && lastRowBlock > 0) { + // last rowBlock + (lastRowBlock, rowsPerBlock) + } else { + // not last columnBlock nor last rowBlock + (colsPerBlock, rowsPerBlock) + }} + Matrices.dense(p, q, Array.fill(p * q)(random.nextGaussian())) + })) + } + + new BlockMatrix(data, colsPerBlock, rowsPerBlock) + } + + /** + * Computes the partial singular value decomposition of the [[BlockMatrix]] + * A given a [[BlockMatrix]] Q such that A' ~ A' * Q * Q'. The columns + * of Q are orthonormal. + * + * @param Q a [[BlockMatrix]] with orthonormal columns. + * @param k number of singular values to compute. + * @param computeU whether to compute U. + * @param isGram whether to compute the Gram matrix for matrix + * orthonormalization. + * @return SingularValueDecomposition[U, s, V], U = null if computeU = false. + * + * @note if isGram is true, it will lose half or more of the precision + * of the arithmetic but could accelerate the computation. It will + * orthonormalize twice to make the columns of the matrix be orthonormal + * to 15 digits. + */ + def lastStep(Q: BlockMatrix, k: Int, computeU: Boolean, isGram: Boolean): + SingularValueDecomposition[BlockMatrix, BlockMatrix] = { + // Compute B = At * Q. + val B = At.multiply(Q) + + // Find SVD of B such that B = V * S * X'. + val svdResult = B.toIndexedRowMatrix().tallSkinnySVD( + Math.min(k, B.nCols.toInt), sc, computeU = true, isGram, ifTwice = true) + + // Convert V's type. 
+ val V = svdResult.U.toBlockMatrix(colsPerBlock, rowsPerBlock) + + // Compute U and return svd of A. + if (computeU) { + // U = Q * X. + val XMat = svdResult.V + val U = Q.toIndexedRowMatrix().multiply(XMat). + toBlockMatrix(rowsPerBlock, colsPerBlock) + SingularValueDecomposition(U, svdResult.s, V) + } else { + SingularValueDecomposition(null, svdResult.s, V) + } + } + + val V = generateRandomMatrices(k, sc) + + // V = A * V, with the V on the left now known as x. + val x = multiply(V) + // Orthonormalize V (now known as x). + var y = x.orthonormal(sc, isGram, ifTwice = false) + + for (i <- 0 until iteration) { + // V = At * V, with the V on the left now known as a, and the V on + // the right known as y. + val a = At.multiply(y) + // Orthonormalize V (now known as a). + val b = a.orthonormal(sc, isGram, ifTwice = false) + // V = A * V, with the V on the left now known as c, and the V on + // the right known as b. + val c = multiply(b) + // Orthonormalize V (now known as c). If it is the last iteration, we + // perform orthonormalization twice so the columns of left singular + // vectors of A will be orthonormal to 15 digits. + val ifTwice = if (i == iteration - 1) true else false + y = c.orthonormal(sc, isGram, ifTwice) + } + + // Find SVD of A using V (now known as y). + lastStep(y, k, computeU, isGram) + } + + /** + * Orthonormalize the columns of the [[BlockMatrix]] V by using + * tallSkinnySVD. We convert V to [[IndexedRowMatrix]] first, then apply + * tallSkinnySVD. The columns of the result orthonormal matrix are the left + * singular vectors of the input matrix V. + * + * @param sc SparkContext used to create RDDs if isGram = false. + * @param isGram whether to compute the Gram matrix when computing + * tallSkinnySVD. + * @param ifTwice whether to compute orthonormalization twice to make + * the columns of the matrix be orthonormal to nearly the + * machine precision. + * @return a [[BlockMatrix]] whose columns are orthonormal vectors. 
+ * + * @note if isGram is true, it will lose half or more of the precision + * of the arithmetic but could accelerate the computation. + */ + @Since("2.0.0") + def orthonormal(sc: SparkContext = null, isGram: Boolean = false, + ifTwice: Boolean = true): BlockMatrix = { + // Orthonormalize the columns of the input BlockMatrix. + val Q = toIndexedRowMatrix().tallSkinnySVD(nCols.toInt, + sc, computeU = true, isGram, ifTwice).U + Q.toBlockMatrix(rowsPerBlock, colsPerBlock) + + } + + /** + * Estimate the largest singular value of [[BlockMatrix]] A using + * power method. + * + * @param iteration number of iterations for power method. + * @param sc a [[SparkContext]] generates the normalized + * vector in each iteration. + * @return a [[Double]] estimate of the largest singular value. + */ + @Since("2.0.0") + def spectralNormEst(iteration: Int = 20, sc: SparkContext): Double = { + /** + * Normalize the [[BlockMatrix]] v which has one column such that it has + * unit norm. + * + * @param v the [[BlockMatrix]] which has one column. + * @param sc SparkContext, use to generate the normalized [[BlockMatrix]]. + * @return a [[BlockMatrix]] such that it has unit norm. + */ + def unit(v: BlockMatrix, sc: SparkContext): BlockMatrix = { + + // Find the norm of v. + val vSquareSum = v.blocks.map{ case ((a, b), c) => + c.toArray.map(x => x*x).sum}.sum() + // Normalize v. + val vUnit = v.blocks.map{ case((a, b), c) => + ((a, b), c.map(x => x/math.sqrt(vSquareSum)))} + + new BlockMatrix(vUnit, v.rowsPerBlock, v.colsPerBlock) + } + + // Generate a random vector v. + var v = { + val rowPartitions = math.ceil(numCols().toInt * 1.0 / colsPerBlock).toInt + val lastRowBlock = numCols().toInt % colsPerBlock + + val data = sc.parallelize(0 until rowPartitions, rowPartitions).persist(). 
+ mapPartitionsWithIndex{(idx, iter) => + val random = new Random(951342768L + idx.toLong) + iter.map(i => ((i, 0), { + val p = { + if (i == rowPartitions - 1 && lastRowBlock > 0) { + // last rowBlock + lastRowBlock + } else { + // not last rowBlock + colsPerBlock + }} + Matrices.dense(p, 1, Array.fill(p)(random.nextGaussian())) + })) + } + + new BlockMatrix(data, colsPerBlock, 1) + } + + // Form and store the transpose. + val At = transpose + + // Find the largest singular value of A using power method. + for (i <- 0 until iteration) { + // normalize v (now known as u). + val u = unit(v, sc) + // v = A * u (now known as Av). + val Av = multiply(u) + // normalize v (now known as w). + val w = unit(Av, sc) + // v = A' * v. + v = At.multiply(w) + } + + // Calculate the 2-norm of final v. + math.sqrt(v.blocks.map{ case ((a, b), c) => + c.toArray.map(x => x*x).sum}.sum()) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index d7255d527f..cc3bda060c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -17,12 +17,26 @@ package org.apache.spark.mllib.linalg.distributed -import breeze.linalg.{DenseMatrix => BDM} +import java.util.Arrays +import scala.util.Random + +import breeze.linalg.{*, axpy => brzAxpy, eigSym, shuffle, svd => brzSvd, +DenseMatrix => BDM, DenseVector => BDV, MatrixSingularException, +SparseVector => BSV} +import breeze.linalg.eigSym.EigSym +import breeze.math.{i, Complex} +import breeze.numerics.{sqrt => brzSqrt} +import breeze.signal.{fourierTr, iFourierTr} + +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg._ -import 
org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.mllib.stat.Statistics import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + /** * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. @@ -44,7 +58,7 @@ case class IndexedRow(index: Long, vector: Vector) class IndexedRowMatrix @Since("1.0.0") ( @Since("1.0.0") val rows: RDD[IndexedRow], private var nRows: Long, - private var nCols: Int) extends DistributedMatrix { + private var nCols: Int) extends DistributedMatrix with Logging{ /** Alternative constructor leaving matrix dimensions to be determined automatically. */ @Since("1.0.0") @@ -131,62 +145,740 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Computes the singular value decomposition of this IndexedRowMatrix. - * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. + * Multiplies the Gramian matrix `A^T A` by a dense vector on the right + * without computing `A^T A`. * - * The cost and implementation of this method is identical to that in - * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] - * With the addition of indices. + * @param v a dense vector whose length must match the number of columns of + * this matrix. + * @return a dense vector representing the product. 
+ */ + private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = { + val n = numCols().toInt + val vbr = rows.context.broadcast(v) + rows.treeAggregate(BDV.zeros[Double](n))( + seqOp = (U, r) => { + val rBrz = r.vector.asBreeze + val a = rBrz.dot(vbr.value) + rBrz match { + // use specialized axpy for better performance + case _: BDV[_] => brzAxpy(a, rBrz.asInstanceOf[BDV[Double]], U) + case _: BSV[_] => brzAxpy(a, rBrz.asInstanceOf[BSV[Double]], U) + case _ => throw new UnsupportedOperationException( + s"Do not support vector operation from type" + + s" ${rBrz.getClass.getName}.") + } + U + }, combOp = (U1, U2) => U1 += U2) + } + + /** + * Computes singular value decomposition of this matrix. Denote this matrix + * by A (m x n). This will compute matrices U, S, V such that + * A ~ U * S * V', where S contains the leading k singular values, U and V + * contain the corresponding singular vectors. * - * At most k largest non-zero singular values and associated vectors are returned. - * If there are k such values, then the dimensions of the return will be: + * At most k largest non-zero singular values and associated vectors are + * returned. If there are k such values, then the dimensions of the return + * will be: + * - U is a IndexedRowMatrix of size m x k that satisfies U' * U = eye(k), + * - s is a Vector of size k, holding the singular values in descending + * order, + * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * U is an [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]] of size m x k that - * satisfies U'U = eye(k), - * s is a Vector of size k, holding the singular values in descending order, - * and V is a local Matrix of size n x k that satisfies V'V = eye(k). + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived + * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. 
U, + * the matrix storing the left singular vectors, is computed via matrix + * multiplication as U = A * (V * S^-1^), if requested by user. The actual + * method to use is determined automatically based on the cost: + * - If n is small (n < 100) or k is large compared with n (k > + * n / 2), we compute the Gramian matrix first and then compute its top + * eigenvalues and eigenvectors locally on the driver. This requires a + * single pass with O(n^2^) storage on each executor and on the driver, + * and O(n^2^ k) time on the driver. + * - Otherwise, we compute (A' * A) * v in a distributive way and send it to + * ARPACK's DSAUPD to compute (A' * A)'s top eigenvalues and eigenvectors + * on the driver node. This requires O(k) passes, O(n) storage on each + * executor, and O(n k) storage on the driver. * - * @param k number of singular values to keep. We might return less than k if there are - * numerically zero singular values. See rCond. + * Several internal parameters are set to default values. The reciprocal + * condition number rCond is set to 1e-9. All singular values smaller than + * rCond * sigma(0) are treated as zeros, where sigma(0) is the largest + * singular value. The maximum number of Arnoldi update iterations for + * ARPACK is set to 300 or k * 3, whichever is larger. The numerical + * tolerance for ARPACK's eigen-decomposition is set to 1e-10. + * + * @note The conditions that decide which method to use internally and the + * default parameters are subject to change. + * + * @param k number of leading singular values to keep (0 < k <= n). + * It might return less than k if there are numerically zero + * singular values or there are not enough Ritz values converged + * before the maximum number of Arnoldi update iterations is reached + * (in case that matrix A is ill-conditioned). * @param computeU whether to compute U - * @param rCond the reciprocal condition number. 
All singular values smaller than rCond * sigma(0) - * are treated as zero, where sigma(0) is the largest singular value. - * @return SingularValueDecomposition(U, s, V) + * @param rCond the reciprocal condition number. All singular values smaller + * than rCond * sigma(0) are treated as zero, where sigma(0) is + * the largest singular value. + * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. */ @Since("1.0.0") def computeSVD( - k: Int, - computeU: Boolean = false, - rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + k: Int, + computeU: Boolean = false, + rCond: Double = 1e-9): + SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + // maximum number of Arnoldi update iterations for invoking ARPACK + val maxIter = math.max(300, k * 3) + // numerical tolerance for invoking ARPACK + val tol = 1e-10 + computeSVD(k, computeU, rCond, maxIter, tol, "auto") + } + /** + * The actual SVD implementation, visible for testing. + * + * @param k number of leading singular values to keep (0 < k <= n). + * @param computeU whether to compute U. + * @param rCond the reciprocal condition number. + * @param maxIter max number of iterations (if ARPACK is used). + * @param tol termination tolerance (if ARPACK is used). + * @param mode computation mode (auto: determine automatically which mode to + * use, local-svd: compute gram matrix and computes its full SVD + * locally, local-eigs: compute gram matrix and computes its top + * eigenvalues locally, dist-eigs: compute the top eigenvalues of + * the gram matrix distributively). + * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. 
+ */ + private[mllib] def computeSVD( + k: Int, + computeU: Boolean, + rCond: Double, + maxIter: Int, + tol: Double, + mode: String): + SingularValueDecomposition[IndexedRowMatrix, Matrix] = { val n = numCols().toInt - require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") - val indices = rows.map(_.index) - val svd = toRowMatrix().computeSVD(k, computeU, rCond) - val U = if (computeU) { - val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => - IndexedRow(i, v) + require(k > 0 && k <= n, s"Requested k singular values but got k=$k" + + s" and numCols=$n.") + + object SVDMode extends Enumeration { + val LocalARPACK, LocalLAPACK, DistARPACK = Value + } + + val computeMode = mode match { + case "auto" => + if (k > 5000) { + logWarning(s"computing svd with k=$k and n=$n, please check" + + " necessity") + } + + // TODO: The conditions below are not fully tested. + if (n < 100 || (k > n / 2 && n <= 15000)) { + // If n is small or k is large compared with n, we better compute the + // Gramian matrix first and then compute its eigenvalues locally, + // instead of making multiple passes. + if (k < n / 3) { + SVDMode.LocalARPACK + } else { + SVDMode.LocalLAPACK + } + } else { + // If k is small compared with n, we use ARPACK with distributed + // multiplication. + SVDMode.DistARPACK + } + case "local-svd" => SVDMode.LocalLAPACK + case "local-eigs" => SVDMode.LocalARPACK + case "dist-eigs" => SVDMode.DistARPACK + case _ => throw new IllegalArgumentException(s"Do not support" + + s" mode $mode.") + } + + // Compute the eigen-decomposition of A' * A. 
+ val (sigmaSquares: BDV[Double], u: BDM[Double]) = computeMode match { + case SVDMode.LocalARPACK => + require(k < n, s"k must be smaller than n in local-eigs mode but" + + s" got k=$k and n=$n.") + val G = computeGramianMatrix().asBreeze.asInstanceOf[BDM[Double]] + EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) + case SVDMode.LocalLAPACK => + // breeze (v0.10) svd latent constraint, 7 * n * n + 4 * n < + // Int.MaxValue. + require(n < 17515, s"$n exceeds the breeze svd capability") + val G = computeGramianMatrix().asBreeze.asInstanceOf[BDM[Double]] + val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = + brzSvd(G) + (sigmaSquaresFull, uFull) + case SVDMode.DistARPACK => + if (rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt" + + "performance if its parent RDDs are also uncached.") + } + require(k < n, s"k must be smaller than n in dist-eigs mode but got" + + s" k=$k and n=$n.") + EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, + tol, maxIter) + } + + val sigmas: BDV[Double] = brzSqrt(sigmaSquares) + + // Determine the effective rank. + val sigma0 = sigmas(0) + val threshold = rCond * sigma0 + var i = 0 + // sigmas might have a length smaller than k, if some Ritz values do not satisfy the convergence + // criterion specified by tol after max number of iterations. + // Thus use i < min(k, sigmas.length) instead of i < k. + if (sigmas.length < k) { + logWarning(s"Requested $k singular values but only found ${sigmas.length} converged.") + } + while (i < math.min(k, sigmas.length) && sigmas(i) >= threshold) { + i += 1 + } + val sk = i + + // Warn at the end of the run as well, for increased visibility. 
+ if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == + StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt" + + " performance if its parent RDDs are also uncached.") + } + + val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk)) + val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) + + if (computeU) { + // N = Vk * Sk^{-1} + val N = new BDM[Double](n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) + var i = 0 + var j = 0 + while (j < sk) { + i = 0 + val sigma = sigmas(j) + while (i < n) { + N(i, j) /= sigma + i += 1 + } + j += 1 } - new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt) + val U = this.multiply(Matrices.fromBreeze(N)) + SingularValueDecomposition(U, s, V) } else { - null + SingularValueDecomposition(null, s, V) } - SingularValueDecomposition(U, svd.s, svd.V) } /** * Multiply this matrix by a local matrix on the right. * - * @param B a local matrix whose number of rows must match the number of columns of this matrix - * @return an IndexedRowMatrix representing the product, which preserves partitioning + * @param B a local matrix whose number of rows must match the number of + * columns of this matrix. + * @return an IndexedRowMatrix representing the product, which preserves + * partitioning. */ @Since("1.0.0") def multiply(B: Matrix): IndexedRowMatrix = { - val mat = toRowMatrix().multiply(B) - val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => - IndexedRow(i, v) + val n = numCols().toInt + val k = B.numCols + require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}") + + require(B.isInstanceOf[DenseMatrix], s"Only support dense matrix at" + + s" this time but found ${B.getClass.getName}.") + + val Bb = rows.context.broadcast(B.asBreeze.asInstanceOf[BDM[Double]]. 
+ toDenseVector.toArray) + val AB = rows.mapPartitions { iter => + val Bi = Bb.value + iter.map { row => + val v = BDV.zeros[Double](k) + var i = 0 + while (i < k) { + v(i) = row.vector.asBreeze.dot(new BDV(Bi, i * n, 1, n)) + i += 1 + } + IndexedRow(row.index, Vectors.fromBreeze(v)) + } + } + + new IndexedRowMatrix(AB, nRows, B.numCols) + } + + /** + * Compute QR decomposition for [[IndexedRowMatrix]]. The implementation is + * designed to optimize the QR decomposition (factorization) for the + * [[IndexedRowMatrix]] of a tall and skinny shape. + * Reference: + * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations + * in MapReduce architectures" + * ([[http://dx.doi.org/10.1145/1996092.1996103]]) + * + * @param computeQ whether to computeQ + * @return QRDecomposition(Q, R), Q = null if computeQ = false. + */ + @Since("1.5.0") + def tallSkinnyQR(computeQ: Boolean = false): + QRDecomposition[IndexedRowMatrix, Matrix] = { + /** + * Solve Q*R = A for Q using forward substitution where A = + * [[IndexedRowMatrix]] and R is upper-triangular. If the (i,i)th entry of + * R is close to 0, then we set the ith column of Q to 0, as well. + * + * @param R upper-triangular matrix. + * @return Q [[IndexedRowMatrix]] such that Q*R = A. + */ + def forwardSolve(R: breeze.linalg.DenseMatrix[Double]): + IndexedRowMatrix = { + val m = numRows() + val n = R.cols + val dim = math.min(R.rows, n) + val Bb = rows.context.broadcast(R(0 until dim, 0 until dim).toArray) + + val AB = rows.mapPartitions { iter => + val LHS = Bb.value + val LHSMat = Matrices.dense(dim, dim, LHS).asBreeze + val FNorm = Vectors.norm(Vectors.dense(LHS), 2.0) + iter.map { row => + val RHS = row.vector.asBreeze.toArray + val v = BDV.zeros[Double] (dim) + // We don't use LAPACK here since it will be numerically unstable if + // R is singular. If R is singular, we set the corresponding + // column of Q to 0. 
+ for ( i <- 0 until dim) { + v(i) = if (math.abs(LHSMat(i, i)) > 1.0e-15 * FNorm) { + val sum = (0 until i).map{ j => LHSMat(j, i) * v(j)}.toArray.sum + (RHS(i) - sum) / LHSMat(i, i) + } else 0.0 + } + IndexedRow(row.index, Vectors.fromBreeze(v)) + } + } + new IndexedRowMatrix(AB, m, dim) + } + + val col = numCols().toInt + // partition into blocks of rows, and compute QR for each of them. + val blockQRs = rows.retag(classOf[IndexedRow]).glom(). + filter(_.length != 0).map { partRows => + val bdm = BDM.zeros[Double](partRows.length, col) + var i = 0 + partRows.foreach { row => + bdm(i, ::) := row.vector.asBreeze.t + i += 1 + } + breeze.linalg.qr.reduced(bdm).r + } + + // combine the R part from previous results vertically into a tall matrix + val combinedR = blockQRs.treeReduce { (r1, r2) => + val stackedR = BDM.vertcat(r1, r2) + breeze.linalg.qr.reduced(stackedR).r + } + + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) + val finalQ = if (computeQ) { + try { + forwardSolve(combinedR) + } catch { + case err: MatrixSingularException => + logWarning("R is not invertible and return Q as null") + null + } + } else { + null + } + QRDecomposition(finalQ, finalR) + } + + /** + * Compute SVD decomposition for [[IndexedRowMatrix]] A. The implementation + * is designed to optimize the SVD decomposition (factorization) for the + * [[IndexedRowMatrix]] of a tall and skinny shape. We either: (1) multiply + * the matrix being processed by a random orthogonal matrix in order to mix + * the columns, obviating the need for pivoting; or (2) compute the Gram + * matrix of A. + * + * References: + * Parker, Douglass Stott, and Brad Pierce. The randomizing FFT: an + * alternative to pivoting in Gaussian elimination. University of + * California (Los Angeles). Computer Science Department, 1995. + * Le, Dinh, and D. Stott Parker. "Using randomization to make recursive + * matrix algorithms practical." Journal of Functional Programming + * 9.06 (1999): 605-624. 
+ * Benson, Austin R., David F. Gleich, and James Demmel. "Direct QR + * factorizations for tall-and-skinny matrices in MapReduce architectures." + * Big Data, 2013 IEEE International Conference on. IEEE, 2013. + * Mary, Theo, et al. "Performance of random sampling for computing + * low-rank approximations of a dense matrix on GPUs." Proceedings of the + * International Conference for High Performance Computing, Networking, + * Storage and Analysis. ACM, 2015. + * + * @param k number of singular values to keep. We might return less than k + * if there are numerically zero singular values. See rCond. + * @param sc SparkContext used in an intermediate step which converts an + * upper triangular matrix to RDD[IndexedRow] if isGram = false. + * @param computeU whether to compute U. + * @param isGram whether to compute the Gram matrix for matrix + * orthonormalization. + * @param ifTwice whether to compute orthonormalization twice to make + * the columns of the matrix be orthonormal to nearly the + * machine precision. + * @param iteration number of times to run multiplyDFS if isGram = false. + * @param rCond the reciprocal condition number. All singular values smaller + * than rCond * sigma(0) are treated as zero, where sigma(0) is + * the largest singular value. + * @return SingularValueDecomposition[U, s, V], U = null if computeU = false. + * @note it will lose half or more of the precision of the arithmetic + * but could accelerate the computation if isGram = true. + */ + @Since("2.0.0") + def tallSkinnySVD(k: Int, sc: SparkContext = null, computeU: Boolean = false, + isGram: Boolean = false, ifTwice: Boolean = true, + iteration: Int = 2, rCond: Option[Double] = None): + SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + + /** + * Convert [[Matrix]] to [[RDD[IndexedRow]]]. + * @param mat an [[Matrix]]. + * @param sc SparkContext used to create RDDs. + * @return RDD[IndexedRow]. 
+ */ + def toRDD(mat: Matrix, sc: SparkContext): RDD[IndexedRow] = { + val columns = mat.transpose.toArray.grouped(mat.numCols).zipWithIndex + val rows = columns.toSeq + val vectors = rows.map( row => IndexedRow(row._2.toLong, + new DenseVector(row._1))) + + // Create RDD[IndexedRow] + sc.parallelize(vectors) + } + + require(k > 0 && k <= numCols().toInt, + s"Requested k singular values but got k=$k and" + + s" numCols=$numCols().toInt.") + + // Compute Q and R such that A = Q * R where Q has orthonormal columns. + // When isGram = true, the columns of Q are the left singular vectors of A + // and R is not necessarily upper triangular. When isGram = false, R is upper + // triangular. + val (qMat, rMat) = if (isGram) { + if (ifTwice) { + // Apply computeSVDbyGram twice to A in order to produce the + // factorization A = U1 * S1 * V1' = U2 * (S2 * V2' * S1 * V1') + // = U2 * R. Orthonormalizing twice makes the columns of U2 be + // orthonormal to nearly the machine precision. + val svdResult1 = computeSVDbyGram(computeU = true) + val svdResult2 = svdResult1.U.computeSVDbyGram(computeU = true) + val V1 = svdResult1.V.asBreeze.toDenseMatrix + val V2 = svdResult2.V.asBreeze.toDenseMatrix + // Compute R1 = S1 * V1'. + val R1 = V1.mapPairs{case ((i, j), x) => x * svdResult1.s(j)}.t + // Compute R2 = S2 * V2'. + val R2 = V2.mapPairs{case ((i, j), x) => x * svdResult2.s(j)}.t + // Return U2 and R = R2 * R1. + (svdResult2.U, R2 * R1) + } else { + // Apply computeSVDbyGram to A and directly return the result. + return computeSVDbyGram(computeU) + } + } else { + // Convert the input IndexedRowMatrix A to another IndexedRowMatrix B by + // multiplying with a random matrix, discrete Fourier transform, and + // random shuffle, i.e. A * Q = B where Q = D * F * S. Repeat several + // times, according to the number of iterations specified by iteration + // (default 2). 
+ val (aq, randUnit, randIndex) = multiplyDFS(iteration, + isForward = true, null, null) + + val (qMat, rMat) = if (ifTwice) { + // Apply tallSkinnyQR twice to B in order to produce the + // factorization B = Q1 * R1 = Q2 * R2 * R1 = Q2 * (R2 * R1) = Q * R. + // Orthonormalizing twice makes the columns of the matrix be + // orthonormal to nearly the machine precision. Later parts of the code + // assume that the columns are numerically orthonormal in order to + // simplify the computations. + val qrResult1 = aq.tallSkinnyQR(computeQ = true) + val qrResult2 = qrResult1.Q.tallSkinnyQR(computeQ = true) + // Return Q and R = R2 * R1. + (qrResult2.Q, qrResult2.R.asBreeze.toDenseMatrix * + qrResult1.R.asBreeze.toDenseMatrix) + } else { + // Apply tallSkinnyQR to B such that B = Q * R. + val qrResult = aq.tallSkinnyQR(computeQ = true) + (qrResult.Q, qrResult.R.asBreeze.toDenseMatrix) + } + // Convert R to IndexedRowMatrix. + val RIndexRowMat = new IndexedRowMatrix(toRDD(Matrices.fromBreeze(rMat), + sc)) + // Convert RIndexRowMat back by reverse shuffle, inverse fourier + // transform, and dividing random matrix, + // i.e. R * Q^T = R * S^{-1} * F^{-1} * D^{-1}. Repeat several times, + // according to the number of iterations specified by iteration. + val (rq, _, _) = RIndexRowMat.multiplyDFS(iteration, + isForward = false, randUnit, randIndex) + (qMat, rq.toBreeze()) + } + // Apply SVD on R * Q^T. + val brzSvd.SVD(w, s, vt) = brzSvd.reduced.apply(rMat) + + // Determine the effective rank. + val rConD = if (rCond.isDefined) rCond.get + else if (!isGram) 1e-11 else 1e-6 + + val threshold = rConD * s(0) + var rank = 0 + while (rank < math.min(k, s.length) && s(rank) >= threshold) { + rank += 1 + } + + // Truncate S, V. + val sk = Vectors.fromBreeze(s(0 until rank)) + val VMat = Matrices.dense(rank, numCols().toInt, vt(0 until rank, + 0 until numCols().toInt).toArray).transpose + + if (computeU) { + // Truncate W. 
+ val WMat = Matrices.dense(rMat.rows, rank, + Arrays.copyOfRange(w.toArray, 0, rMat.rows * rank)) + // U = Q * W. + val U = qMat.multiply(WMat) + SingularValueDecomposition(U, sk, VMat) + } else { + SingularValueDecomposition(null, sk, VMat) + } + } + + /** + * Given a m-by-2n or m-by-(2n+1) real [[IndexedRowMatrix]], convert it to + * m-by-n complex [[IndexedRowMatrix]]. Multiply this m-by-n complex + * [[IndexedRowMatrix]] by a random diagonal n-by-n [[BDM[Complex]] D, + * discrete fourier transform F, and random shuffle n-by-n [[BDM[Int]]] S + * with a given [[Int]] k number of times, and convert it back to m-by-2n + * real [[IndexedRowMatrix]]; or backwards, i.e., convert it from real to + * complex, apply reverse random shuffle n-by-n [[BDM[Int]]] S^{-1}, inverse + * fourier transform F^{-1}, dividing the given random diagonal n-by-n + * [[BDM[Complex]] D with a given [[Int]] k number of times, and convert it + * from complex to real. + * + * References: + * Parker, Douglass Stott, and Brad Pierce. The randomizing FFT: an + * alternative to pivoting in Gaussian elimination. University of + * California (Los Angeles). Computer Science Department, 1995. + * Le, Dinh, and D. Stott Parker. "Using randomization to make recursive + * matrix algorithms practical." Journal of Functional Programming + * 9.06 (1999): 605-624. + * Ailon, Nir, and Edo Liberty. "An almost optimal unrestricted fast + * Johnson-Lindenstrauss transform." ACM Transactions on Algorithms (TALG) + * 9.3 (2013): 21. + * + * @note The entries with the same column index of input + * [[IndexedRowMatrix]] are multiplied by the same random number, and + * shuffle to the same place. + * + * @param iteration k number of times applying D, F, and S. + * @param isForward whether to apply D, F, S forwards or backwards. + * If backwards, then needs to specify rUnit and rIndex. + * @param rUnit a complex k-by-n matrix such that each entry is a complex + * number with absolute value 1. 
+ * @param rIndex an integer k-by-n matrix such that each row is a random + * permutation of the integers 1, 2, ..., n. + * @return transformed m-by-2n or m-by-(2n+1) IndexedRowMatrix, a complex + * k-by-n matrix, and an int k-by-n matrix. + */ + @Since("2.0.0") + def multiplyDFS(iteration: Int = 2, isForward: Boolean, rUnit: BDM[Complex], + rIndex: BDM[Int]): + (IndexedRowMatrix, BDM[Complex], BDM[Int]) = { + + /** + * Given a 1-by-2n [[BDV[Double]] arr, either do D, F, S forwards with + * [[Int]] iteration k times if [[Boolean]] isForward is true; + * or S, F, D backwards with [[Int]] iteration k times if [[Boolean]] + * isForward is false. + * + * @param iteration k number of times applying D, F, and S. + * @param isForward whether to apply D, F, S forwards or backwards. + * If backwards, then needs to specify rUnit and rIndex. + * @param randUnit a complex k-by-n such that each entry is a complex number + * with absolute value 1. + * @param randIndex an integer k-by-n matrix such that each row is a random + * permutation of the integers 1, 2, ..., n. + * @param arr a 1-by-2n [[BDV[Double]]]. + * @param index the row index of arr. + * @return a 1-by-2n [[IndexedRow]]. + */ + def dfs(iteration: Int, isForward: Boolean, randUnit: BDM[Complex], + randIndex: BDM[Int], arr: BDV[Double], index: Long): IndexedRow = { + + // Either keep arr or add an extra entry with value 0 so that + // number of indices is even. + val input = { + if (arr.length % 2 == 1) BDV.vertcat(arr, BDV.zeros[Double](1)) + else arr + } + + // convert input from real to complex. + var inputComplex = realToComplex(input) + if (isForward) { + // Apply D, F, S to the input iteration times. + for (i <- 0 until iteration) { + // Element-wise multiplication with randUnit. + val inputMul = inputComplex :* randUnit(i, ::).t + // Discrete Fourier transform. + val inputFFT = fourierTr(inputMul).toArray + // Random shuffle. 
+ val shuffleIndex = randIndex(i, ::).t.toArray + inputComplex = BDV(shuffle(inputFFT, shuffleIndex, false)) + } + } else { + // Apply S^{-1}, F^{-1}, D^{-1} to the input iteration times. + for (i <- iteration - 1 to 0 by -1) { + // Reverse shuffle. + val shuffleIndex = randIndex(i, ::).t.toArray + val shuffleBackArr = shuffle(inputComplex.toArray, shuffleIndex, true) + // Inverse Fourier transform. + val inputIFFT = iFourierTr(BDV(shuffleBackArr)) + // Element-wise divide with randUnit. + inputComplex = inputIFFT :/ randUnit(i, ::).t + } + } + IndexedRow(index, Vectors.fromBreeze(complexToReal(inputComplex))) + } + + /** + * Given [[Int]] k and [[Int]] n, generate a k-by-n [[BDM[Complex]]] such + * that each entry has absolute value 1 and a k-by-n [[BDM[Int]]] such + * that each row is a random permutation of integers from 1 to n. + * + * @param iteration k number of rows for D and S. + * @param nCols the number of columns n. + * @return a k-by-n complex matrix and a k-by-n int matrix. + */ + def generateDS(iteration: Int = 2, nCols: Int): (BDM[Complex], BDM[Int]) = { + val temp = new BDM[Int](iteration, nCols) + + // Random permutation of integers from 1 to n iteration times. + val shuffleIndex = temp(*, ::).map{dt => + shuffle(BDV((0 until nCols).toArray)) + } + // Generate random complex number with absolute value 1. These random + // complex numbers are uniformly distributed over the unit circle. + val randUnit = temp.mapPairs{case ((i, j), x) => + val random = new Random(851342769L + j.toLong) + val randComplex = Complex(random.nextGaussian(), + Random.nextGaussian()) + randComplex / randComplex.abs + } + (randUnit, shuffleIndex) + } + + /** + * Convert a 2n-by-1 [[BDV[Double]]] u to n-by-1 [[BDV[Complex]]] v. The + * odd index entry of u changes to the real part of each entry in v. The + * even index entry of u changes to the imaginary part of each entry in v. + * Please note that "index" here refers to "1-based indexing" rather than + * "0-based indexing." 
+ * + * @param arr 2n-by-1 real vector. + * @return n-by-1 complex vector. + */ + def realToComplex(arr: BDV[Double]): BDV[Complex] = { + // Odd entries transfer to real part. + val odd = arr(0 until arr.length by 2).map(v => v + i * 0) + // Even entries transfer to imaginary part. + val even = arr(1 until arr.length by 2).map(v => i * v) + // Combine real part and imaginary part. + odd + even + } + + /** + * Convert an n-by-1 [[BDV[Complex]]] v to 2n-by-1 [[BDV[Double]]] u. The + * real part of each entry in v changes to the odd index entry of u. + * The imaginary part of each entry in v changes to the even index entry + * of u. Please note that "index" here refers to "1-based indexing" rather + * than "0-based indexing." + * + * @param arr n-by-1 complex vector. + * @return 2n-by-1 real vector. + */ + def complexToReal(arr: BDV[Complex]): BDV[Double] = { + // Filter out the real part. + val reconReal = arr.map(v => v.real) + // Filter out the imaginary part. + val reconImag = arr.map(v => v.imag) + // Concatenate the real and imaginary part. + BDV.horzcat(reconReal, reconImag).t.toDenseVector + } + + // Either generate D and S or take them from the input. + val (randUnit, randIndex) = { + if (isForward) generateDS(iteration, (numCols().toInt + 1) / 2) + else (rUnit, rIndex) + } + + // Apply DFS forwards or backwards to the input IndexedRowMatrix. + val AB = rows.mapPartitions( iter => + if (iter.nonEmpty) { + val temp = iter.toArray + val tempAfter = temp.map{ i => + dfs(iteration, isForward, randUnit, randIndex, + BDV(i.vector.toArray), i.index) + } + Iterator.tabulate(temp.length)(tempAfter(_)) + } else { + Iterator.empty + } + ) + + // Generate the output IndexedRowMatrix with even number of columns. 
+ val n = if (nCols % 2 == 0) nCols else nCols + 1 + (new IndexedRowMatrix(AB, nRows, n), randUnit, randIndex) + } + + /** + * Compute the singular value decomposition of the [[IndexedRowMatrix]] A + * such that A ~ U * S * V' via computing the Gram matrix of A. We (1) + * compute the Gram matrix G = A' * A, (2) apply the eigenvalue decomposition + * on G = V * D * V', (3) compute W = A * V, then the Euclidean norms of the + * columns of W are the singular values of A, and (4) normalizing the columns + * of W yields U such that A = U * S * V', where S is the diagonal matrix of + * singular values. + * + * @return SingularValueDecomposition[U, s, V]. + * @note it will lose half or more of the precision of the arithmetic + * but could accelerate the computation compared to tallSkinnyQR. + */ + @Since("2.0.0") + def computeSVDbyGram(computeU: Boolean = false): + SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + // Compute Gram matrix G of A such that G = A' * A. + val G = computeGramianMatrix().asBreeze.toDenseMatrix + + // Compute the eigenvalue decomposition of G such that G = V * D * V'. + val EigSym(d, vMat) = eigSym(G) + + // Find the effective rank of G. + val eigenRank = { + var i = d.length - 1 + while (i >= 0 && d(i) > 1e-14 * d(d.length - 1)) i = i - 1 + i + 1 + } + + // Calculate W such that W = A * V. + val vMatTruncated = vMat(::, vMat.cols - 1 to eigenRank by -1) + val V = Matrices.dense(vMatTruncated.rows, vMatTruncated.cols, + vMat(::, vMat.cols - 1 to eigenRank by -1).toArray) + val W = multiply(V) + val normW = Statistics.colStats(W.rows.map(_.vector)).normL2.asBreeze. + toDenseVector + + if (computeU) { + // Normalize W to U such that each column of U has norm 1. 
+ val U = W.multiply(Matrices.diag(Vectors.fromBreeze(1.0 / normW))) + SingularValueDecomposition(U, Vectors.fromBreeze(normW), V) + } else { + SingularValueDecomposition(null, Vectors.fromBreeze(normW), V) } - new IndexedRowMatrix(indexedRows, nRows, B.numCols) } /** @@ -199,6 +891,10 @@ class IndexedRowMatrix @Since("1.0.0") ( toRowMatrix().computeGramianMatrix() } + /** + * Convert distributed storage of IndexedRowMatrix into locally stored BDM, whereas + * asBreeze works on matrices stored locally and requires no memcopy. + */ private[mllib] override def toBreeze(): BDM[Double] = { val m = numRows().toInt val n = numCols().toInt diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 78a8810052..b653e98ffd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -21,14 +21,15 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer -import breeze.linalg.{axpy => brzAxpy, inv, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV, - MatrixSingularException, SparseVector => BSV} +import breeze.linalg.{axpy => brzAxpy, svd => brzSvd, DenseMatrix => BDM, + DenseVector => BDV, MatrixSingularException, SparseVector => BSV} import breeze.numerics.{sqrt => brzSqrt} import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, + MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom @@ -181,6 +182,11 @@ class RowMatrix @Since("1.0.0") ( * * @note The conditions that decide which method to use 
internally and the default parameters are * subject to change. + * @note Only the singular values and right singular vectors (not the left + * singular vectors) that computeSVD computes are meaningful when + * using multiple executors/machines. IndexedRowMatrix provides an + * analogous computeSVD function that computes meaningful left + * singular vectors. */ @Since("1.0.0") def computeSVD( @@ -285,10 +291,6 @@ class RowMatrix @Since("1.0.0") ( } val sk = i - if (sk < k) { - logWarning(s"Requested $k singular values but only found $sk nonzeros.") - } - // Warn at the end of the run as well, for increased visibility. if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" @@ -319,6 +321,37 @@ class RowMatrix @Since("1.0.0") ( } } + /** + * Determine the effective rank + * + * @param k number of singular values to keep. We might return less than k if there are + * numerically zero singular values. See rCond. + * @param sigmas singular values of matrix + * @param rCond the reciprocal condition number. All singular values smaller + * than rCond * sigma(0) are treated as zero, where sigma(0) is + * the largest singular value. + * @return a [[Int]] + */ + def determineRank(k: Int, sigmas: BDV[Double], rCond: Double): Int = { + // Determine the effective rank. + val sigma0 = sigmas(0) + val threshold = rCond * sigma0 + var i = 0 + // sigmas might have a length smaller than k, if some Ritz values do not satisfy the convergence + // criterion specified by tol after max number of iterations. + // Thus use i < min(k, sigmas.length) instead of i < k. 
+ if (sigmas.length < k) { + logWarning(s"Requested $k singular values but only found ${sigmas.length} converged.") + } + while (i < math.min(k, sigmas.length) && sigmas(i) >= threshold) { + i += 1 + } + if (i < k) { + logWarning(s"Requested $k singular values but only found $i nonzeros.") + } + i + } + /** * Computes the covariance matrix, treating each row as an observation. * @@ -533,11 +566,58 @@ class RowMatrix @Since("1.0.0") ( * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce * architectures" (see here) * + * @note Only the R (not the Q) in the QR decomposition is meaningful when + * using multiple executors/machines. IndexedRowMatrix provides an + * analogous tallSkinnyQR function that computes a meaningful Q in a QR + * decomposition. + * * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. */ @Since("1.5.0") def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { + /** + * Solve Q*R = A for Q using forward substitution where A = [[RowMatrix]] + * and R is upper-triangular. If the (i,i)th entry of R is close to 0, then + * we set the ith column of Q to 0, as well. + * + * @param R upper-triangular matrix. + * @return Q [[RowMatrix]] such that Q*R = A. + */ + def forwardSolve(R: breeze.linalg.DenseMatrix[Double]): + RowMatrix = { + val m = numRows() + val n = R.cols + val dim = math.min(R.rows, n) + val Bb = rows.context.broadcast(R(0 until dim, 0 until dim).toArray) + + val AB = rows.mapPartitions { iter => + val LHS = Bb.value + val LHSMat = Matrices.dense(dim, dim, LHS).asBreeze + val FNorm = Vectors.norm(Vectors.dense(LHS), 2.0) + iter.map { row => + val RHS = row.asBreeze.toArray + val v = BDV.zeros[Double] (dim) + // We don't use LAPACK here since it will be numerically unstable if + // R is singular. If R is singular, we set the corresponding + // column of Q to 0. 
+ for ( i <- 0 until dim) { + if (math.abs(LHSMat(i, i)) > 1.0e-15 * FNorm) { + var sum = 0.0 + for ( j <- 0 until i) { + sum += LHSMat(j, i) * v(j) + } + v(i) = (RHS(i) - sum) / LHSMat(i, i) + } else { + v(i) = 0.0 + } + } + Vectors.fromBreeze(v) + } + } + new RowMatrix(AB, m, dim) + } + val col = numCols().toInt // split rows horizontally into smaller matrices, and compute QR for each of them val blockQRs = rows.retag(classOf[Vector]).glom().filter(_.length != 0).map { partRows => @@ -559,8 +639,7 @@ class RowMatrix @Since("1.0.0") ( val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) val finalQ = if (computeQ) { try { - val invR = inv(combinedR) - this.multiply(Matrices.fromBreeze(invR)) + forwardSolve(combinedR) } catch { case err: MatrixSingularException => logWarning("R is not invertible and return Q as null") @@ -662,6 +741,10 @@ class RowMatrix @Since("1.0.0") ( new CoordinateMatrix(sims, numCols(), numCols()) } + /** + * Convert distributed storage of RowMatrix into locally stored BDM, whereas + * asBreeze works on matrices stored locally and requires no memcopy. + */ private[mllib] override def toBreeze(): BDM[Double] = { val m = numRows().toInt val n = numCols().toInt diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/shuffle.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/shuffle.scala new file mode 100644 index 0000000000..35d505eb20 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/shuffle.scala @@ -0,0 +1,189 @@ +// scalastyle:off + +package org.apache.spark.mllib.linalg.distributed + +import breeze.generic.UFunc +import scala.collection.generic.CanBuildFrom +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag +import breeze.stats.distributions.Rand +import breeze.linalg.{DenseMatrix, DenseVector} + +/** + * Return the given DenseVector, Array, or DenseMatrix as a shuffled copy + * by using Fisher-Yates shuffle. 
+ * Additionally, can return the given Array as a shuffled copy with the + * corresponding shuffle index information, or return the given Array as a + * shuffled copy using the inverse of the given shuffle index information, + * reversing the shuffle. + * + * @author ktakagaki + * @date 05/12/2014 + * @author huaminli + * @date 08/01/2016 + */ +object shuffle extends UFunc { + + implicit def implShuffle_Arr_eq_Arr[T](implicit ct: ClassTag[T]): + Impl[Array[T], Array[T]] = { + new Impl[Array[T], Array[T]] { + /** + * Shuffle the given [[Array[T]]]. + * @param arr the given array stored as [[Array[T]]] + * @return a shuffled array stored as [[Array[T]]] + */ + def apply(arr: Array[T]): Array[T] = { + // Make a copy of the input. + val tempret = arr.clone() + + // Shuffle tempret via Fisher-Yates method. + var count = tempret.length - 1 + while(count > 0) { + swap(tempret, count, Rand.randInt(count + 1).get()) + count -= 1 + } + tempret + } + + /** + * Swap two elements of [[Array[T]]] with specified indices [[Int]]. + * @param arr the given array stored as [[Array[T]]] + * @param indexA the first given [[Int]] index + * @param indexB the second given [[Int]] index + */ + def swap[T](arr: Array[T], indexA: Int, indexB: Int): Unit = { + val temp = arr(indexA) + arr(indexA) = arr(indexB) + arr(indexB) = temp + } + } + } + + implicit def implShuffle_Arr_Arr_Boolean_eq_Arr[T](implicit ct: ClassTag[T]): + Impl3[Array[T], Array[Int], Boolean, Array[T]] = { + new Impl3[Array[T], Array[Int], Boolean, Array[T]] { + /** + * shuffle the given [[Array[T]]] arr according to the given + * permutation [[Array[Int]]] arrIndex if [[Boolean]] isInverse is + * false; shuffle the given [[Array[T]]] arr according to the inverse of + * the given permutation [[Array[Int]]] arrIndex if [[Boolean]] + * isInverse is true. 
+ * @param arr the given array stored as [[Array[T]]] + * @param arrIndex the given permutation array stored as [[Array[Int]]] + * @param isInverse the indicator whether perform inverse shuffle + * @return a shuffled array stored as [[Array[T]]] + */ + def apply(arr: Array[T], arrIndex: Array[Int], + isInverse: Boolean): Array[T] = { + require(arr.length == arrIndex.length, + "The two input arrays should have the same length!") + // Make a copy of the input. + val tempret = new Array[T](arr.length) + + if (!isInverse) { + // Shuffle tempret via given permutation. + for (i <- arr.indices) { + tempret(i) = arr(arrIndex(i)) + } + } else { + // Inverse shuffle tempret via given permutation. + for (i <- arr.indices) { + tempret(arrIndex(i)) = arr(i) + } + } + tempret + } + } + } + + implicit def implShuffle_Arr_Arr_eq_Arr[T](implicit ct: ClassTag[T]): + Impl2[Array[T], Array[Int], Array[T]] = { + new Impl2[Array[T], Array[Int], Array[T]] { + /** + * Shuffle the given [[Array[T]] arr according to the given + * permutation [[Array[Int]]] arrIndex. + * @param arr the given array stored as [[Array[T]]] + * @param arrIndex the given permutation array stored as [[Array[Int]]] + * @return a shuffled array stored as [[Array[T]]] + */ + def apply(arr: Array[T], arrIndex: Array[Int]): Array[T] = { + // Shuffle the input via given permutation. + shuffle(arr, arrIndex, false) + } + } + } + + implicit def implShuffle_Coll_eq_Coll[Coll, T, CollRes](implicit view: + Coll <:< IndexedSeq[T], cbf: CanBuildFrom[Coll, T, CollRes]) : + Impl[Coll, CollRes] = { + new Impl[Coll, CollRes] { + /** + * Shuffle the given [[Coll]]. + * @param v the given collection stored as [[Coll]] + * @return a shuffled collection stored as [[CollRes]] + */ + override def apply(v: Coll): CollRes = { + // Make a copy of the input. + val builder = cbf(v) + val copy = v.to[ArrayBuffer] + + // Shuffle tempret via Fisher-Yates method. 
+ var count = copy.length - 1 + while (count > 0) { + swap(copy, count, Rand.randInt(count + 1).get()) + count -= 1 + } + builder ++= copy + builder.result() + } + + /** + * Swap two elements of [[Coll]] with specified indices [[Int]]. + * @param arr the given array stored as [[Array[T]]] + * @param indexA the first given [[Int]] index + * @param indexB the second given [[Int]] index + */ + def swap(arr: ArrayBuffer[T], indexA: Int, indexB: Int): Unit = { + val temp = arr(indexA) + arr(indexA) = arr(indexB) + arr(indexB) = temp + } + + } + } + + implicit def implShuffle_DV_eq_DV[T](implicit arrImpl: + Impl[Array[T], Array[T]], ct: ClassTag[T]): + Impl[DenseVector[T], DenseVector[T]] = { + new Impl[DenseVector[T], DenseVector[T]] { + /** + * Shuffle the given [[DenseVector[T]]]. + * @param dv the given vector stored as [[DenseVector[T]]] + * @return a shuffled vector stored as [[DenseVector[T]]] + */ + def apply(dv: DenseVector[T]): DenseVector[T] = + // convert to array and perform the shuffling. + new DenseVector(shuffle(dv.toArray)) + } + } + + implicit def implShuffle_DM_eq_DM[T](implicit arrImpl: + Impl[Array[T], Array[T]], ct: ClassTag[T]): + Impl[DenseMatrix[T], DenseMatrix[T]] = { + new Impl[DenseMatrix[T], DenseMatrix[T]] { + /** + * Shuffle the given [[DenseMatrix[T]]]. + * @param dm the given matrix stored as [[DenseMatrix[T]]] + * @return a shuffled matrix stored as [[DenseMatrix[T]]] + */ + def apply(dm: DenseMatrix[T]): DenseMatrix[T] = { + // convert to array and perform the shuffling. 
+ val rows = dm.rows + val cols = dm.cols + new DenseMatrix(rows, cols, shuffle(dm.toArray)) + } + } + } +} + +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index e0e41f711b..7a714db853 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -241,16 +241,24 @@ object LBFGS extends Logging { val bcW = data.context.broadcast(w) val localGradient = gradient - val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))( - seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = localGradient.compute( - features, label, bcW.value, grad) - (grad, loss + l) - }, - combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - axpy(1.0, grad2, grad1) - (grad1, loss1 + loss2) - }) + val seqOp = (c: (Vector, Double), v: (Double, Vector)) => + (c, v) match { + case ((grad, loss), (label, features)) => + val denseGrad = grad.toDense + val l = localGradient.compute(features, label, bcW.value, denseGrad) + (denseGrad, loss + l) + } + + val combOp = (c1: (Vector, Double), c2: (Vector, Double)) => + (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => + val denseGrad1 = grad1.toDense + val denseGrad2 = grad2.toDense + axpy(1.0, denseGrad2, denseGrad1) + (denseGrad1, loss1 + loss2) + } + + val zeroSparseVector = Vectors.sparse(n, Seq()) + val (gradientSum, lossSum) = data.treeAggregate((zeroSparseVector, 0.0))(seqOp, combOp) // broadcasted model is not needed anymore bcW.destroy() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 3f9bcec427..aacb7921b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.ml.linalg.{DenseMatrix, Vectors} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -136,6 +137,17 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(outputFields.contains("p")) } + test("SPARK-18625 : OneVsRestModel should support setFeaturesCol and setPredictionCol") { + val ova = new OneVsRest().setClassifier(new LogisticRegression) + val ovaModel = ova.fit(dataset) + val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea")) + ovaModel.setFeaturesCol("fea") + ovaModel.setPredictionCol("pred") + val transformedDataset = ovaModel.transform(dataset2) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields === Set("y", "fea", "pred")) + } + test("SPARK-8049: OneVsRest shouldn't output temp columns") { val logReg = new LogisticRegression() .setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 957cf58a68..5262b146b1 100755 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -45,7 +45,7 @@ class StopWordsRemoverSuite .setOutputCol("filtered") val dataSet = Seq( (Seq("test", "test"), Seq("test", "test")), - (Seq("a", "b", "c", "d"), Seq("b", "c")), + (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), (Seq("a", "the", "an"), Seq()), (Seq("A", "The", "AN"), Seq()), (Seq(null), Seq(null)), diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4fab216033..ed24c1e16a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -89,11 +89,14 @@ class GeneralizedLinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log").toDF() - datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput( - intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), - xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01, - family = "poisson", link = "log") - .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF() + datasetPoissonLogWithZero = Seq( + LabeledPoint(0.0, Vectors.dense(18, 1.0)), + LabeledPoint(1.0, Vectors.dense(12, 0.0)), + LabeledPoint(0.0, Vectors.dense(15, 0.0)), + LabeledPoint(0.0, Vectors.dense(13, 2.0)), + LabeledPoint(0.0, Vectors.dense(15, 1.0)), + LabeledPoint(1.0, Vectors.dense(16, 1.0)) + ).toDF() datasetPoissonIdentity = generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), @@ -480,12 +483,12 @@ class GeneralizedLinearRegressionSuite model <- glm(formula, family="poisson", data=data) 
print(as.vector(coef(model))) } - [1] 0.4272661 -0.1565423 - [1] -3.6911354 0.6214301 0.1295814 + [1] -0.0457441 -0.6833928 + [1] 1.8121235 -0.1747493 -0.5815417 */ val expected = Seq( - Vectors.dense(0.0, 0.4272661, -0.1565423), - Vectors.dense(-3.6911354, 0.6214301, 0.1295814)) + Vectors.dense(0.0, -0.0457441, -0.6833928), + Vectors.dense(1.8121235, -0.1747493, -0.5815417)) import GeneralizedLinearRegression._ @@ -708,16 +711,17 @@ class GeneralizedLinearRegressionSuite R code: A <- matrix(c(0, 1, 2, 3, 5, 2, 1, 3), 4, 2) - b <- c(1, 0, 1, 0) - w <- c(1, 2, 3, 4) + b <- c(1, 0.5, 1, 0) + w <- c(1, 2.0, 0.3, 4.7) df <- as.data.frame(cbind(A, b)) */ val datasetWithWeight = Seq( Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), - Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), - Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), - Instance(0.0, 4.0, Vectors.dense(3.0, 3.0)) + Instance(0.5, 2.0, Vectors.dense(1.0, 2.0)), + Instance(1.0, 0.3, Vectors.dense(2.0, 1.0)), + Instance(0.0, 4.7, Vectors.dense(3.0, 3.0)) ).toDF() + /* R code: @@ -725,34 +729,34 @@ class GeneralizedLinearRegressionSuite summary(model) Deviance Residuals: - 1 2 3 4 - 1.273 -1.437 2.533 -1.556 + 1 2 3 4 + 0.2404 0.1965 1.2824 -0.6916 Coefficients: Estimate Std. 
Error z value Pr(>|z|) - V1 -0.30217 0.46242 -0.653 0.513 - V2 -0.04452 0.37124 -0.120 0.905 + x1 -1.6901 1.2764 -1.324 0.185 + x2 0.7059 0.9449 0.747 0.455 (Dispersion parameter for binomial family taken to be 1) - Null deviance: 13.863 on 4 degrees of freedom - Residual deviance: 12.524 on 2 degrees of freedom - AIC: 16.524 + Null deviance: 8.3178 on 4 degrees of freedom + Residual deviance: 2.2193 on 2 degrees of freedom + AIC: 5.9915 Number of Fisher Scoring iterations: 5 residuals(model, type="pearson") 1 2 3 4 - 1.117731 -1.162962 2.395838 -1.189005 + 0.171217 0.197406 2.085864 -0.495332 residuals(model, type="working") 1 2 3 4 - 2.249324 -1.676240 2.913346 -1.353433 + 1.029315 0.281881 15.502768 -1.052203 residuals(model, type="response") - 1 2 3 4 - 0.5554219 -0.4034267 0.6567520 -0.2611382 - */ + 1 2 3 4 + 0.028480 0.069123 0.935495 -0.049613 + */ val trainer = new GeneralizedLinearRegression() .setFamily("binomial") .setWeightCol("weight") @@ -760,21 +764,21 @@ class GeneralizedLinearRegressionSuite val model = trainer.fit(datasetWithWeight) - val coefficientsR = Vectors.dense(Array(-0.30217, -0.04452)) + val coefficientsR = Vectors.dense(Array(-1.690134, 0.705929)) val interceptR = 0.0 - val devianceResidualsR = Array(1.273, -1.437, 2.533, -1.556) - val pearsonResidualsR = Array(1.117731, -1.162962, 2.395838, -1.189005) - val workingResidualsR = Array(2.249324, -1.676240, 2.913346, -1.353433) - val responseResidualsR = Array(0.5554219, -0.4034267, 0.6567520, -0.2611382) - val seCoefR = Array(0.46242, 0.37124) - val tValsR = Array(-0.653, -0.120) - val pValsR = Array(0.513, 0.905) + val devianceResidualsR = Array(0.2404, 0.1965, 1.2824, -0.6916) + val pearsonResidualsR = Array(0.171217, 0.197406, 2.085864, -0.495332) + val workingResidualsR = Array(1.029315, 0.281881, 15.502768, -1.052203) + val responseResidualsR = Array(0.02848, 0.069123, 0.935495, -0.049613) + val seCoefR = Array(1.276417, 0.944934) + val tValsR = Array(-1.324124, 0.747068) + val 
pValsR = Array(0.185462, 0.455023) val dispersionR = 1.0 - val nullDevianceR = 13.863 - val residualDevianceR = 12.524 + val nullDevianceR = 8.3178 + val residualDevianceR = 2.2193 val residualDegreeOfFreedomNullR = 4 val residualDegreeOfFreedomR = 2 - val aicR = 16.524 + val aicR = 5.991537 val summary = model.summary val devianceResiduals = summary.residuals() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 61266f3c78..1539d4d18a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{diag, svd => brzSvd, DenseMatrix => BDM, + DenseVector => BDV, SparseVector => BSV} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.{DenseMatrix, DenseVector, Matrices, Matrix, SparseMatrix, SparseVector, Vectors} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -389,4 +390,46 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val A = AT2.transpose assert(A.toBreeze() === gridBasedMat.toBreeze()) } + + test("Partial SVD") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 0.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 0.0)))) + + val gridBasedMatLowRank = new BlockMatrix(sc.parallelize(blocks, + 
numPartitions), rowPerPart, colPerPart) + val localMat = gridBasedMatLowRank.toBreeze() + val n = gridBasedMatLowRank.numCols().toInt + for (k <- Seq(n, n - 1)) { + for (isGram <- Seq(true, false)) { + for (isRandom <- Seq(true, false)) { + val iteration = 2 + val svdResult = gridBasedMatLowRank.partialSVD(k, sc, + computeU = true, isGram, iteration, isRandom) + val U = svdResult.U + val S = svdResult.s + val V = svdResult.V + val reconstruct = U.toBreeze() * diag(S.asBreeze. + asInstanceOf[BDV[Double]]) * V.transpose.toBreeze() + val diff = localMat - reconstruct + val brzSvd.SVD(_, diffNorm, _) = brzSvd.reduced.apply(diff) + val tol = if (isGram) 5.0e-6 else 5.0e-13 + assert(diffNorm(0) ~== 0.0 absTol tol) + val svdWithoutU = gridBasedMat.partialSVD(k, sc, + computeU = false, isGram, iteration, isRandom) + assert(svdWithoutU.U === null) + } + } + } + } + + test("Spectral norm estimation") { + val norm = gridBasedMat.spectralNormEst(iteration = 20, sc) + val localMat = gridBasedMat.toBreeze() + val brzSvd.SVD(_, localS, _) = brzSvd.reduced.apply(localMat) + assert(norm ~== localS(0) absTol 1e-6) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 99af5fa10d..88df3b0acc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.mllib.linalg.distributed -import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import breeze.linalg.{diag => brzDiag, norm => brzNorm, svd => brzSvd, + DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics._ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors} @@ -112,6 +114,21 @@ class IndexedRowMatrixSuite extends 
SparkFunSuite with MLlibTestSparkContext { assert(localC === expected) } + test("QR Decomposition") { + val A = new IndexedRowMatrix(indexedRows) + val result = A.tallSkinnyQR(true) + val expected = breeze.linalg.qr.reduced(A.toBreeze()) + val calcQ = result.Q + val calcR = result.R + assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) + assert(closeToZero(abs(expected.r) - abs(calcR.asBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(calcQ.multiply(calcR).toBreeze - A.toBreeze())) + // Decomposition without computing Q + val rOnly = A.tallSkinnyQR(computeQ = false) + assert(rOnly.Q == null) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.asBreeze.asInstanceOf[BDM[Double]]))) + } + test("gram") { val A = new IndexedRowMatrix(indexedRows) val G = A.computeGramianMatrix() @@ -153,6 +170,57 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("tallSkinnySVD") { + val mat = new IndexedRowMatrix(indexedRows) + val localMat = mat.toBreeze() + val brzSvd.SVD(localU, localSigma, localVt) = brzSvd(localMat) + val localV: BDM[Double] = localVt.t.toDenseMatrix + val k = 2 + val svd = mat.tallSkinnySVD(k, sc, computeU = true) + val U = svd.U + val s = svd.s + val V = svd.V + assert(U.numRows() === m) + assert(U.numCols() === k) + assert(s.size === k) + assert(V.numRows === n) + assert(V.numCols === k) + assertColumnEqualUpToSign(U.toBreeze(), localU, k) + assertColumnEqualUpToSign(V.asBreeze.asInstanceOf[BDM[Double]], localV, k) + assert(closeToZero(s.asBreeze.asInstanceOf[BDV[Double]] - + localSigma(0 until k))) + + val svdWithoutU = mat.tallSkinnySVD(k, sc, computeU = false) + assert(svdWithoutU.U === null) + + intercept[IllegalArgumentException] { + mat.tallSkinnySVD(k = -1, sc) + } + } + + test("computeSVDbyGram") { + val mat = new IndexedRowMatrix(indexedRows) + val localMat = mat.toBreeze() + val brzSvd.SVD(localU, localSigma, localVt) = brzSvd(localMat) + val localV: BDM[Double] = localVt.t.toDenseMatrix + val 
svd = mat.computeSVDbyGram(computeU = true) + val U = svd.U + val s = svd.s + val V = svd.V + assert(U.numRows() === m) + assert(U.numCols() === n) + assert(s.size === n) + assert(V.numRows === n) + assert(V.numCols === n) + assertColumnEqualUpToSign(U.toBreeze(), localU, n) + assertColumnEqualUpToSign(V.asBreeze.asInstanceOf[BDM[Double]], localV, n) + assert(closeToZero(s.asBreeze.asInstanceOf[BDV[Double]] - + localSigma(0 until n))) + + val svdWithoutU = mat.computeSVDbyGram(computeU = false) + assert(svdWithoutU.U === null) + } + test("similar columns") { val A = new IndexedRowMatrix(indexedRows) val gram = A.computeGramianMatrix().asBreeze.toDenseMatrix @@ -168,5 +236,19 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } + + def closeToZero(v: BDV[Double]): Boolean = { + brzNorm(v, 1.0) < 1e-6 + } + + def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) { + assert(A.rows === B.rows) + for (j <- 0 until k) { + val aj = A(::, j) + val bj = B(::, j) + assert(closeToZero(aj - bj) || closeToZero(aj + bj), + s"The $j-th columns mismatch: $aj and $bj") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 75ae0eb32f..572959200f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -230,6 +230,25 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers (weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02), "The weight differences between LBFGS and GD should be within 2%.") } + + test("SPARK-18471: LBFGS aggregator on empty partitions") { + val regParam = 0 + + val initialWeightsWithIntercept = Vectors.dense(0.0) + val convergenceTol = 
1e-12 + val numIterations = 1 + val dataWithEmptyPartitions = sc.parallelize(Seq((1.0, Vectors.dense(2.0))), 2) + + LBFGS.runLBFGS( + dataWithEmptyPartitions, + gradient, + simpleUpdater, + numCorrections, + convergenceTol, + numIterations, + regParam, + initialWeightsWithIntercept) + } } class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/pom.xml b/pom.xml index c391102d37..a0c44f5ac1 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -145,7 +145,7 @@ 1.7.7 hadoop2 0.7.1 - 1.6.1 + 1.6.2 0.10.2 @@ -179,7 +179,7 @@ 4.5.3 1.1 2.52.0 - 2.8 + 2.6 1.8 1.0.0 @@ -293,6 +293,12 @@ spark-tags_${scala.binary.version} ${project.version} + + org.apache.spark + spark-tags_${scala.binary.version} + ${project.version} + test-jar + com.twitter chill_${scala.binary.version} @@ -557,7 +563,7 @@ io.netty netty - 3.8.0.Final + 3.9.9.Final org.apache.derby @@ -1863,6 +1869,11 @@ + + com.thoughtworks.paranamer + paranamer + ${paranamer.version} + @@ -1909,7 +1920,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.10 + 1.12 net.alchim31.maven @@ -1972,7 +1983,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.5.1 + 3.6.0 ${java.version} ${java.version} @@ -2092,7 +2103,7 @@ org.apache.maven.plugins maven-jar-plugin - 2.6 + 3.0.2 org.apache.maven.plugins @@ -2102,7 +2113,7 @@ org.apache.maven.plugins maven-source-plugin - 2.4 + 3.0.1 true @@ -2137,17 +2148,17 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.3 + 2.10.4 org.codehaus.mojo exec-maven-plugin - 1.4.0 + 1.5.0 org.apache.maven.plugins maven-assembly-plugin - 2.6 + 3.0.0 org.apache.maven.plugins @@ -2580,7 +2591,7 @@ yarn - yarn + resource-managers/yarn common/network-yarn @@ -2588,7 +2599,7 @@ mesos - mesos + resource-managers/mesos diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 77397eab81..de0655b6cb 100644 --- 
a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -22,7 +22,7 @@ import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.core.MissingClassProblem import com.typesafe.tools.mima.core.MissingTypesProblem import com.typesafe.tools.mima.core.ProblemFilters._ -import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact} +import com.typesafe.tools.mima.plugin.MimaKeys.{mimaBinaryIssueFilters, mimaPreviousArtifacts} import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings @@ -92,8 +92,8 @@ object MimaBuild { val project = projectRef.project val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ - Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), - binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) + Seq(mimaPreviousArtifacts := Set(organization % fullId % previousSparkVersion), + mimaBinaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index b113bbf803..2314d7f45c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,9 +34,16 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + // Exclude rules for 2.2.x lazy val v22excludes = v21excludes ++ Seq( // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), + + // [SPARK-18949] [SQL] Add repairTable API to Catalog + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.recoverPartitions"), + + // [SPARK-18537] Add a REST api to spark streaming + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted") ) // Exclude rules for 
2.1.x @@ -90,7 +97,7 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.sourceStatuses"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQuery.id"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.lastProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.recentProgresses"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.recentProgress"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.id"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.get"), @@ -105,7 +112,17 @@ object MimaExcludes { ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_="), // [SPARK-18236] Reduce duplicate objects in Spark UI and HistoryServer - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.TaskInfo.accumulables") + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.TaskInfo.accumulables"), + + // [SPARK-18657] Add StreamingQuery.runId + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.runId"), + + // [SPARK-18694] Add StreamingQuery.explain and exception to Python and fix StreamingQueryException + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryException$"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.startOffset"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.endOffset"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query") ) } @@ -918,7 +935,7 @@ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("2.2") => v22excludes - case v if v.startsWith("2.1") => v22excludes // TODO: Update this when we bump version to 2.2 + case v if v.startsWith("2.1") => v21excludes case v if v.startsWith("2.0") => v20excludes case _ => Seq() } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e3fbe0379f..74edd537f5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -251,13 +251,12 @@ object SparkBuild extends PomBuild { Resolver.file("local", file(Path.userHome.absolutePath + "/.ivy2/local"))(Resolver.ivyStylePatterns) ), externalResolvers := resolvers.value, - otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), - publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { - (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) - }, + otherResolvers := SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))).value, + publishLocalConfiguration in MavenCompile := + new PublishConfiguration(None, "dotM2", packagedArtifacts.value, Seq(), ivyLoggingLevel.value), publishMavenStyle in MavenCompile := true, - publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), - publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn, + publishLocal in MavenCompile := publishTask(publishLocalConfiguration in MavenCompile, deliverLocal).value, + publishLocalBoth := Seq(publishLocal in MavenCompile, publishLocal).dependOn.value, javacOptions in (Compile, doc) ++= { val versionParts = 
System.getProperty("java.version").split("[+.\\-]+", 3) @@ -431,7 +430,8 @@ object SparkBuild extends PomBuild { val packages :: className :: otherArgs = spaceDelimited(" [args]").parsed.toList val scalaRun = (runner in run).value val classpath = (fullClasspath in Runtime).value - val args = Seq("--packages", packages, "--class", className, (Keys.`package` in Compile in "core").value.getCanonicalPath) ++ otherArgs + val args = Seq("--packages", packages, "--class", className, (Keys.`package` in Compile in LocalProject("core")) + .value.getCanonicalPath) ++ otherArgs println(args) scalaRun.run("org.apache.spark.deploy.SparkSubmit", classpath.map(_.data), args, streams.value.log) }, @@ -443,7 +443,7 @@ object SparkBuild extends PomBuild { } ))(assembly) - enable(Seq(sparkShell := sparkShell in "assembly"))(spark) + enable(Seq(sparkShell := sparkShell in LocalProject("assembly")))(spark) // TODO: move this to its upstream project. override def projectDefinitions(baseDirectory: File): Seq[Project] = { @@ -512,9 +512,9 @@ object OldDeps { lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) - lazy val allPreviousArtifactKeys = Def.settingDyn[Seq[Option[ModuleID]]] { + lazy val allPreviousArtifactKeys = Def.settingDyn[Seq[Set[ModuleID]]] { SparkBuild.mimaProjects - .map { project => MimaKeys.previousArtifact in project } + .map { project => MimaKeys.mimaPreviousArtifacts in project } .map(k => Def.setting(k.value)) .join } @@ -568,9 +568,9 @@ object Hive { javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. 
- scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => + scalacOptions := (scalacOptions map { currentOpts: Seq[String] => currentOpts.filterNot(_ == "-deprecation") - }, + }).value, initialCommands in console := """ |import org.apache.spark.SparkContext @@ -608,17 +608,18 @@ object Assembly { sys.props.get("hadoop.version") .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, - jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-0-8-assembly") || mName.contains("streaming-kafka-0-10-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + jarName in assembly := { + if (moduleName.value.contains("streaming-flume-assembly") + || moduleName.value.contains("streaming-kafka-0-8-assembly") + || moduleName.value.contains("streaming-kafka-0-10-assembly") + || moduleName.value.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-0-8-assembly/pom.xml) - s"${mName}-${v}.jar" + s"${moduleName.value}-${version.value}.jar" } else { - s"${mName}-${v}-hadoop${hv}.jar" + s"${moduleName.value}-${version.value}-hadoop${hadoopVersion.value}.jar" } }, - jarName in (Test, assembly) <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - s"${mName}-test-${v}.jar" - }, + jarName in (Test, assembly) := s"${moduleName.value}-test-${version.value}.jar", mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard @@ -639,13 +640,13 @@ object PySparkAssembly { // Use a resource generator to copy all .py files from python/pyspark into a managed directory // to be included in the assembly. We can't just add "python/" to the assembly's resource dir // list since that will copy unneeded / unwanted files. 
- resourceGenerators in Compile <+= resourceManaged in Compile map { outDir: File => + resourceGenerators in Compile += Def.macroValueI(resourceManaged in Compile map { outDir: File => val src = new File(BuildCommons.sparkHome, "python/pyspark") val zipFile = new File(BuildCommons.sparkHome , "python/lib/pyspark.zip") zipFile.delete() zipRecursive(src, zipFile) Seq[File]() - } + }).value ) private def zipRecursive(source: File, destZipFile: File) = { @@ -771,7 +772,7 @@ object Unidoc { object CopyDependencies { val copyDeps = TaskKey[Unit]("copyDeps", "Copies needed dependencies to the build directory.") - val destPath = (crossTarget in Compile) / "jars" + val destPath = (crossTarget in Compile) { _ / "jars"} lazy val settings = Seq( copyDeps := { @@ -791,7 +792,7 @@ object CopyDependencies { } }, crossTarget in (Compile, packageBin) := destPath.value, - packageBin in Compile <<= (packageBin in Compile).dependsOn(copyDeps) + packageBin in Compile := (packageBin in Compile).dependsOn(copyDeps).value ) } @@ -823,7 +824,8 @@ object TestSettings { // launched by the tests have access to the correct test-time classpath. envVars in Test ++= Map( "SPARK_DIST_CLASSPATH" -> - (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + (fullClasspath in Test).value.files.map(_.getAbsolutePath) + .mkString(File.pathSeparator).stripSuffix(File.pathSeparator), "SPARK_PREPEND_CLASSES" -> "1", "SPARK_SCALA_VERSION" -> scalaBinaryVersion, "SPARK_TESTING" -> "1", @@ -862,7 +864,7 @@ object TestSettings { // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. 
- resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => + resourceGenerators in Test += Def.macroValueI(resourceManaged in Test map { outDir: File => var dir = new File(testTempDir) if (!dir.isDirectory()) { // Because File.mkdirs() can fail if multiple callers are trying to create the same @@ -880,7 +882,7 @@ object TestSettings { } } Seq[File]() - }, + }).value, concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), // Remove certain packages from Scaladoc scalacOptions in (Compile, doc) := Seq( diff --git a/project/build.properties b/project/build.properties index 1e38156e0b..d339865ab9 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.11 +sbt.version=0.13.13 diff --git a/project/plugins.sbt b/project/plugins.sbt index 76597d2729..84d1239990 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,12 +1,12 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") -addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "4.0.0") +addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "5.0.1") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.11") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.12") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -16,9 +16,9 @@ addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") addSbtPlugin("io.spray" % "sbt-revolver" % "0.8.0") -libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" +libraryDependencies += "org.ow2.asm" % "asm" % "5.1" -libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" +libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.1" addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") diff --git a/python/MANIFEST.in 
b/python/MANIFEST.in index bbcce1baa4..40f1fb2f1e 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -17,6 +17,8 @@ global-exclude *.py[cod] __pycache__ .DS_Store recursive-include deps/jars *.jar graft deps/bin +recursive-include deps/data *.data *.txt +recursive-include deps/licenses *.txt recursive-include deps/examples *.py recursive-include lib *.zip include README.md diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 5f93586a48..9331e74eed 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -34,6 +34,8 @@ Access files shipped with jobs. - :class:`StorageLevel`: Finer-grained cache persistence levels. + - :class:`TaskContext`: + Information about the current running task, available on the workers and experimental. """ @@ -49,6 +51,7 @@ from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer from pyspark.status import * +from pyspark.taskcontext import TaskContext from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ @@ -106,5 +109,5 @@ def wrapper(*args, **kwargs): __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", - "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", + "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext", ] diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1d62b32534..62c31431b5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -165,8 +165,8 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav typeConverter=TypeConverters.toListFloat) handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. 
" + - "Options are skip (filter out rows with invalid values), " + - "error (throw an error), or keep (keep invalid values in a special " + + "Options are 'skip' (filter out rows with invalid values), " + + "'error' (throw an error), or 'keep' (keep invalid values in a special " + "additional bucket).", typeConverter=TypeConverters.toString) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9e05da89af..b384b2b507 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -135,12 +135,11 @@ def _load_from_socket(port, serializer): break if not sock: raise Exception("could not open socket") - try: - rf = sock.makefile("rb", 65536) - for item in serializer.load_stream(rf): - yield item - finally: - sock.close() + # The RDD materialization time is unpredictable, if we set a timeout for socket reading + # operation, it will very possibly fail. See SPARK-18281. + sock.settimeout(None) + # The socket will be automatically closed when garbage-collected. + return serializer.load_stream(sock.makefile("rb", 65536)) def ignore_unicode_prefix(f): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2a1326947f..c4f2f08cb4 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -61,7 +61,7 @@ if sys.version < '3': import cPickle as pickle protocol = 2 - from itertools import izip as zip + from itertools import izip as zip, imap as map else: import pickle protocol = 3 @@ -96,7 +96,12 @@ def load_stream(self, stream): raise NotImplementedError def _load_stream_without_unbatching(self, stream): - return self.load_stream(stream) + """ + Return an iterator of deserialized batches (lists) of objects from the input stream. + If the serializer does not operate on batches the default implementation returns an + iterator of single element lists. 
+ """ + return map(lambda x: [x], self.load_stream(stream)) # Note: our notion of "equality" is that output generated by # equal serializers can be deserialized using the same serializer. @@ -278,50 +283,57 @@ def __repr__(self): return "AutoBatchedSerializer(%s)" % self.serializer -class CartesianDeserializer(FramedSerializer): +class CartesianDeserializer(Serializer): """ Deserializes the JavaRDD cartesian() of two PythonRDDs. + Due to pyspark batching we cannot simply use the result of the Java RDD cartesian, + we additionally need to do the cartesian within each pair of batches. """ def __init__(self, key_ser, val_ser): - FramedSerializer.__init__(self) self.key_ser = key_ser self.val_ser = val_ser - def prepare_keys_values(self, stream): - key_stream = self.key_ser._load_stream_without_unbatching(stream) - val_stream = self.val_ser._load_stream_without_unbatching(stream) - key_is_batched = isinstance(self.key_ser, BatchedSerializer) - val_is_batched = isinstance(self.val_ser, BatchedSerializer) - for (keys, vals) in zip(key_stream, val_stream): - keys = keys if key_is_batched else [keys] - vals = vals if val_is_batched else [vals] - yield (keys, vals) + def _load_stream_without_unbatching(self, stream): + key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) + val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) + for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): + # for correctness with repeated cartesian/zip this must be returned as one batch + yield product(key_batch, val_batch) def load_stream(self, stream): - for (keys, vals) in self.prepare_keys_values(stream): - for pair in product(keys, vals): - yield pair + return chain.from_iterable(self._load_stream_without_unbatching(stream)) def __repr__(self): return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) -class PairDeserializer(CartesianDeserializer): +class PairDeserializer(Serializer): """ Deserializes the JavaRDD 
zip() of two PythonRDDs. + Due to pyspark batching we cannot simply use the result of the Java RDD zip, + we additionally need to do the zip within each pair of batches. """ + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def _load_stream_without_unbatching(self, stream): + key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) + val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) + for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): + if len(key_batch) != len(val_batch): + raise ValueError("Can not deserialize PairRDD with different number of items" + " in batches: (%d, %d)" % (len(key_batch), len(val_batch))) + # for correctness with repeated cartesian/zip this must be returned as one batch + yield zip(key_batch, val_batch) + def load_stream(self, stream): - for (keys, vals) in self.prepare_keys_values(stream): - if len(keys) != len(vals): - raise ValueError("Can not deserialize RDD with different number of items" - " in pair: (%d, %d)" % (len(keys), len(vals))) - for pair in zip(keys, vals): - yield pair + return chain.from_iterable(self._load_stream_without_unbatching(stream)) def __repr__(self): return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index a36d02e0db..30c7a3fe4f 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -258,6 +258,11 @@ def refreshTable(self, tableName): """Invalidate and refresh all the cached metadata of the given table.""" self._jcatalog.refreshTable(tableName) + @since('2.1.1') + def recoverPartitions(self, tableName): + """Recover all the partitions of the given table and update the catalog.""" + self._jcatalog.recoverPartitions(tableName) + def _reset(self): """(Internal use only) Drop all existing databases (except "default"), tables, partitions and functions, and set the current database to "default". 
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 84f01d3d9a..5014299ad2 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -28,8 +28,10 @@ from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.types import * +from pyspark.sql.utils import StreamingQueryException __all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] @@ -50,14 +52,29 @@ def __init__(self, jsq): @property @since(2.0) def id(self): - """The id of the streaming query. + """Returns the unique id of this query that persists across restarts from checkpoint data. + That is, this id is generated when a query is started for the first time, and + will be the same every time it is restarted from checkpoint data. + There can only be one query with the same id active in a Spark cluster. + Also see, `runId`. """ return self._jsq.id().toString() + @property + @since(2.1) + def runId(self): + """Returns the unique id of this query that does not persist across restarts. That is, every + query that is started (or restarted from checkpoint) will have a different runId. + """ + return self._jsq.runId().toString() + @property @since(2.0) def name(self): - """The name of the streaming query. This name is unique across all active queries. + """Returns the user-specified name of the query, or null if not specified. + This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter` + as `dataframe.writeStream.queryName("query").start()`. + This name, if set, must be unique across all active queries. """ return self._jsq.name() @@ -98,21 +115,26 @@ def status(self): @property @since(2.1) - def recentProgresses(self): + def recentProgress(self): """Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. 
The number of progress updates retained for each stream is configured by Spark session - configuration `spark.sql.streaming.numRecentProgresses`. + configuration `spark.sql.streaming.numRecentProgressUpdates`. """ - return [json.loads(p.json()) for p in self._jsq.recentProgresses()] + return [json.loads(p.json()) for p in self._jsq.recentProgress()] @property @since(2.1) def lastProgress(self): """ - Returns the most recent :class:`StreamingQueryProgress` update of this streaming query. + Returns the most recent :class:`StreamingQueryProgress` update of this streaming query or + None if there were no progress updates :return: a map """ - return json.loads(self._jsq.lastProgress().json()) + lastProgress = self._jsq.lastProgress() + if lastProgress: + return json.loads(lastProgress.json()) + else: + return None @since(2.0) def processAllAvailable(self): @@ -132,6 +154,45 @@ def stop(self): """ self._jsq.stop() + @since(2.1) + def explain(self, extended=False): + """Prints the (logical and physical) plans to the console for debugging purpose. + + :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. + + >>> sq = sdf.writeStream.format('memory').queryName('query_explain').start() + >>> sq.processAllAvailable() # Wait a bit to generate the runtime plans. + >>> sq.explain() + == Physical Plan == + ... + >>> sq.explain(True) + == Parsed Logical Plan == + ... + == Analyzed Logical Plan == + ... + == Optimized Logical Plan == + ... + == Physical Plan == + ... + >>> sq.stop() + """ + # Cannot call `_jsq.explain(...)` because it will print in the JVM process. + # We should print it in the Python process. + print(self._jsq.explainInternal(extended)) + + @since(2.1) + def exception(self): + """ + :return: the StreamingQueryException if the query was terminated by an exception, or None. 
+ """ + if self._jsq.exception().isDefined(): + je = self._jsq.exception().get() + msg = je.toString().split(': ', 1)[1] # Drop the Java StreamingQueryException type info + stackTrace = '\n\t at '.join(map(lambda x: x.toString(), je.getStackTrace())) + return StreamingQueryException(msg, stackTrace) + else: + return None + class StreamingQueryManager(object): """A class to manage all the :class:`StreamingQuery` StreamingQueries active. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b7b2a5923c..18fd68ec5e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -50,7 +50,7 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests -from pyspark.sql.functions import UserDefinedFunction, sha2 +from pyspark.sql.functions import UserDefinedFunction, sha2, lit from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -360,6 +360,15 @@ def test_broadcast_in_udf(self): [res] = self.spark.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_udf_with_filter_function(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a < 2, BooleanType()) + sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) + self.assertEqual(sel.collect(), [Row(key=1, value='1')]) + def test_udf_with_aggregate_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql.functions import udf, col, sum @@ -384,6 +393,26 @@ def test_udf_in_generate(self): row = df.select(explode(f(*df))).groupBy().sum().first() self.assertEqual(row[0], 10) + df = self.spark.range(3) + res = df.select("id", 
explode(f(df.id))).collect() + self.assertEqual(res[0][0], 1) + self.assertEqual(res[0][1], 0) + self.assertEqual(res[1][0], 2) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 2) + self.assertEqual(res[2][1], 1) + + range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType())) + res = df.select("id", explode(range_udf(df.id))).collect() + self.assertEqual(res[0][0], 0) + self.assertEqual(res[0][1], -1) + self.assertEqual(res[1][0], 0) + self.assertEqual(res[1][1], 0) + self.assertEqual(res[2][0], 1) + self.assertEqual(res[2][1], 0) + self.assertEqual(res[3][0], 1) + self.assertEqual(res[3][1], 1) + def test_udf_with_order_by_and_limit(self): from pyspark.sql.functions import udf my_copy = udf(lambda x: x, IntegerType()) @@ -392,6 +421,14 @@ def test_udf_with_order_by_and_limit(self): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_udf_with_input_file_name(self): + from pyspark.sql.functions import udf, input_file_name + from pyspark.sql.types import StringType + sourceFile = udf(lambda path: path, StringType()) + filePath = "python/test_support/sql/people1.json" + row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() + self.assertTrue(row[0].find("people1.json") != -1) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) @@ -1028,7 +1065,8 @@ def test_stream_read_options_overwrite(self): self.assertEqual(df.schema.simpleString(), "struct") def test_stream_save_options(self): - df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \ + .withColumn('id', lit(1)) for q in self.spark._wrapped.streams.active: q.stop() tmpPath = tempfile.mkdtemp() @@ -1037,7 +1075,7 @@ def test_stream_save_options(self): out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') q = 
df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').outputMode('append').option('path', out).start() + .format('parquet').partitionBy('id').outputMode('append').option('path', out).start() try: self.assertEqual(q.name, 'this_query') self.assertTrue(q.isActive) @@ -1091,16 +1129,32 @@ def test_stream_status_and_progress(self): self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - q = df.writeStream \ + + def func(x): + time.sleep(1) + return x + + from pyspark.sql.functions import col, udf + sleep_udf = udf(func) + + # Use "sleep_udf" to delay the progress update so that we can test `lastProgress` when there + # were no updates. + q = df.select(sleep_udf(col("value")).alias('value')).writeStream \ .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) try: + # "lastProgress" will return None in most cases. However, as it may be flaky when + # Jenkins is very slow, we don't assert it. If there is something wrong, "lastProgress" + # may throw error with a high chance and make this test flaky, so we should still be + # able to detect broken codes. 
+ q.lastProgress + q.processAllAvailable() lastProgress = q.lastProgress - recentProgresses = q.recentProgresses + recentProgress = q.recentProgress status = q.status self.assertEqual(lastProgress['name'], q.name) self.assertEqual(lastProgress['id'], q.id) - self.assertTrue(any(p == lastProgress for p in recentProgresses)) + self.assertTrue(any(p == lastProgress for p in recentProgress)) self.assertTrue( "message" in status and "isDataAvailable" in status and @@ -1137,6 +1191,35 @@ def test_stream_await_termination(self): q.stop() shutil.rmtree(tmpPath) + def test_stream_exception(self): + sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + sq = sdf.writeStream.format('memory').queryName('query_explain').start() + try: + sq.processAllAvailable() + self.assertEqual(sq.exception(), None) + finally: + sq.stop() + + from pyspark.sql.functions import col, udf + from pyspark.sql.utils import StreamingQueryException + bad_udf = udf(lambda x: 1 / 0) + sq = sdf.select(bad_udf(col("value")))\ + .writeStream\ + .format('memory')\ + .queryName('this_query')\ + .start() + try: + # Process some data to fail the query + sq.processAllAvailable() + self.fail("bad udf should fail the query") + except StreamingQueryException as e: + # This is expected + self.assertTrue("ZeroDivisionError" in e.desc) + finally: + sq.stop() + self.assertTrue(type(sq.exception()) is StreamingQueryException) + self.assertTrue("ZeroDivisionError" in sq.exception().desc) + def test_query_manager_await_termination(self): df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') for q in self.spark._wrapped.streams.active: @@ -1980,6 +2063,41 @@ def assert_runs_only_one_job_stage_and_task(job_group_name, f): # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect()) + @unittest.skipIf(sys.version_info < (3, 3), "Unittest < 3.3 
doesn't support mocking") + def test_unbounded_frames(self): + from unittest.mock import patch + from pyspark.sql import functions as F + from pyspark.sql import window + import importlib + + df = self.spark.range(0, 3) + + def rows_frame_match(): + return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( + F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize)) + ).columns[0] + + def range_frame_match(): + return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select( + F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize)) + ).columns[0] + + with patch("sys.maxsize", 2 ** 31 - 1): + importlib.reload(window) + self.assertTrue(rows_frame_match()) + self.assertTrue(range_frame_match()) + + with patch("sys.maxsize", 2 ** 63 - 1): + importlib.reload(window) + self.assertTrue(rows_frame_match()) + self.assertTrue(range_frame_match()) + + with patch("sys.maxsize", 2 ** 127 - 1): + importlib.reload(window) + self.assertTrue(rows_frame_match()) + self.assertTrue(range_frame_match()) + + importlib.reload(window) if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index c345e623f1..7ce27f9b10 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -49,6 +49,8 @@ class Window(object): _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 + _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) + _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) unboundedPreceding = _JAVA_MIN_LONG @@ -98,9 +100,9 @@ def rowsBetween(start, end): The frame is unbounded if this is ``Window.unboundedFollowing``, or any value greater than or equal to 9223372036854775807. 
""" - if start <= Window._JAVA_MIN_LONG: + if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding - if end >= Window._JAVA_MAX_LONG: + if end >= Window._FOLLOWING_THRESHOLD: end = Window.unboundedFollowing sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end) @@ -123,14 +125,14 @@ def rangeBetween(start, end): :param start: boundary start, inclusive. The frame is unbounded if this is ``Window.unboundedPreceding``, or - any value less than or equal to -9223372036854775808. + any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. The frame is unbounded if this is ``Window.unboundedFollowing``, or - any value greater than or equal to 9223372036854775807. + any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._JAVA_MIN_LONG: + if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding - if end >= Window._JAVA_MAX_LONG: + if end >= Window._FOLLOWING_THRESHOLD: end = Window.unboundedFollowing sc = SparkContext._active_spark_context jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end) @@ -185,14 +187,14 @@ def rowsBetween(self, start, end): :param start: boundary start, inclusive. The frame is unbounded if this is ``Window.unboundedPreceding``, or - any value less than or equal to -9223372036854775808. + any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. The frame is unbounded if this is ``Window.unboundedFollowing``, or - any value greater than or equal to 9223372036854775807. + any value greater than or equal to min(sys.maxsize, 9223372036854775807). 
""" - if start <= Window._JAVA_MIN_LONG: + if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding - if end >= Window._JAVA_MAX_LONG: + if end >= Window._FOLLOWING_THRESHOLD: end = Window.unboundedFollowing return WindowSpec(self._jspec.rowsBetween(start, end)) @@ -211,14 +213,14 @@ def rangeBetween(self, start, end): :param start: boundary start, inclusive. The frame is unbounded if this is ``Window.unboundedPreceding``, or - any value less than or equal to -9223372036854775808. + any value less than or equal to max(-sys.maxsize, -9223372036854775808). :param end: boundary end, inclusive. The frame is unbounded if this is ``Window.unboundedFollowing``, or - any value greater than or equal to 9223372036854775807. + any value greater than or equal to min(sys.maxsize, 9223372036854775807). """ - if start <= Window._JAVA_MIN_LONG: + if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding - if end >= Window._JAVA_MAX_LONG: + if end >= Window._FOLLOWING_THRESHOLD: end = Window.unboundedFollowing return WindowSpec(self._jspec.rangeBetween(start, end)) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py new file mode 100644 index 0000000000..e5218d9e75 --- /dev/null +++ b/python/pyspark/taskcontext.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + + +class TaskContext(object): + + """ + .. note:: Experimental + + Contextual information about a task which can be read or mutated during + execution. To access the TaskContext for a running task, use: + L{TaskContext.get()}. + """ + + _taskContext = None + + _attemptNumber = None + _partitionId = None + _stageId = None + _taskAttemptId = None + + def __new__(cls): + """Even if users construct TaskContext instead of using get, give them the singleton.""" + taskContext = cls._taskContext + if taskContext is not None: + return taskContext + cls._taskContext = taskContext = object.__new__(cls) + return taskContext + + def __init__(self): + """Construct a TaskContext, use get instead""" + pass + + @classmethod + def _getOrCreate(cls): + """Internal function to get or create global TaskContext.""" + if cls._taskContext is None: + cls._taskContext = TaskContext() + return cls._taskContext + + @classmethod + def get(cls): + """ + Return the currently active TaskContext. This can be called inside of + user functions to access contextual information about running tasks. + + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. + """ + return cls._taskContext + + def stageId(self): + """The ID of the stage that this task belongs to.""" + return self._stageId + + def partitionId(self): + """ + The ID of the RDD partition that is computed by this task. + """ + return self._partitionId + + def attemptNumber(self): + """ + How many times this task has been attempted. The first task attempt will be assigned + attemptNumber = 0, and subsequent attempts will have increasing attempt numbers. 
+ """ + return self._attemptNumber + + def taskAttemptId(self): + """ + An ID that is unique to this task attempt (within the same SparkContext, no two task + attempts will share the same attempt ID). 
This is roughly equivalent to Hadoop's + TaskAttemptID. + """ + return self._taskAttemptId diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index ab4bef8329..c383d9ab67 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -69,6 +69,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler +from pyspark.taskcontext import TaskContext _have_scipy = False _have_numpy = False @@ -478,6 +479,70 @@ def func(x): self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) +class TaskContextTests(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + # Allow retries even though they are normally disabled in local mode + self.sc = SparkContext('local[4, 2]', class_name) + + def test_stage_id(self): + """Test the stage ids are available and incrementing as expected.""" + rdd = self.sc.parallelize(range(10)) + stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0] + # Test using the constructor directly rather than the get() + stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0] + self.assertEqual(stage1 + 1, stage2) + self.assertEqual(stage1 + 2, stage3) + self.assertEqual(stage2 + 1, stage3) + + def test_partition_id(self): + """Test the partition id.""" + rdd1 = self.sc.parallelize(range(10), 1) + rdd2 = self.sc.parallelize(range(10), 2) + pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect() + pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect() + self.assertEqual(0, pids1[0]) + self.assertEqual(0, pids1[9]) + self.assertEqual(0, pids2[0]) + self.assertEqual(1, pids2[9]) + + def test_attempt_number(self): + """Verify the attempt numbers are correctly reported.""" + rdd = self.sc.parallelize(range(10)) + # Verify a simple job with no failures + 
attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect() + map(lambda attempt: self.assertEqual(0, attempt), attempt_numbers) + + def fail_on_first(x): + """Fail on the first attempt so we get a positive attempt number""" + tc = TaskContext.get() + attempt_number = tc.attemptNumber() + partition_id = tc.partitionId() + attempt_id = tc.taskAttemptId() + if attempt_number == 0 and partition_id == 0: + raise Exception("Failing on first attempt") + else: + return [x, partition_id, attempt_number, attempt_id] + result = rdd.map(fail_on_first).collect() + # We should re-submit the first partition to it but other partitions should be attempt 0 + self.assertEqual([0, 0, 1], result[0][0:3]) + self.assertEqual([9, 3, 0], result[9][0:3]) + first_partition = filter(lambda x: x[1] == 0, result) + map(lambda x: self.assertEqual(1, x[2]), first_partition) + other_partitions = filter(lambda x: x[1] != 0, result) + map(lambda x: self.assertEqual(0, x[2]), other_partitions) + # The task attempt id should be different + self.assertTrue(result[0][3] != result[9][3]) + + def test_tc_on_driver(self): + """Verify that getting the TaskContext on the driver returns None.""" + tc = TaskContext.get() + self.assertTrue(tc is None) + + class RDDTests(ReusedPySparkTestCase): def test_range(self): @@ -502,6 +567,18 @@ def test_sum(self): self.assertEqual(0, self.sc.emptyRDD().sum()) self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum()) + def test_to_localiterator(self): + from time import sleep + rdd = self.sc.parallelize([1, 2, 3]) + it = rdd.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it)) + + rdd2 = rdd.repartition(1000) + it2 = rdd2.toLocalIterator() + sleep(5) + self.assertEqual([1, 2, 3], sorted(it2)) + def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" 
@@ -548,6 +625,24 @@ def test_cartesian_on_textfile(self): self.assertEqual(u"Hello World!", x.strip()) self.assertEqual(u"Hello World!", y.strip()) + def test_cartesian_chaining(self): + # Tests for SPARK-16589 + rdd = self.sc.parallelize(range(10), 2) + self.assertSetEqual( + set(rdd.cartesian(rdd).cartesian(rdd).collect()), + set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.cartesian(rdd)).collect()), + set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.zip(rdd)).collect()), + set([(x, (y, y)) for x in range(10) for y in range(10)]) + ) + def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0918282953..25ee475c7f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,6 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer @@ -125,6 +126,11 @@ def main(infile, outfile): ("%d.%d" % sys.version_info[:2], version)) # initialize global state + taskContext = TaskContext._getOrCreate() + taskContext._stageId = read_int(infile) + taskContext._partitionId = read_int(infile) + taskContext._attemptNumber = read_int(infile) + taskContext._taskAttemptId = read_long(infile) shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() diff --git a/python/setup.py b/python/setup.py index 625aea0407..bc2eb4ce9d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -69,10 +69,14 @@ EXAMPLES_PATH = os.path.join(SPARK_HOME, 
"examples/src/main/python") SCRIPTS_PATH = os.path.join(SPARK_HOME, "bin") +DATA_PATH = os.path.join(SPARK_HOME, "data") +LICENSES_PATH = os.path.join(SPARK_HOME, "licenses") + SCRIPTS_TARGET = os.path.join(TEMP_PATH, "bin") JARS_TARGET = os.path.join(TEMP_PATH, "jars") EXAMPLES_TARGET = os.path.join(TEMP_PATH, "examples") - +DATA_TARGET = os.path.join(TEMP_PATH, "data") +LICENSES_TARGET = os.path.join(TEMP_PATH, "licenses") # Check and see if we are under the spark path in which case we need to build the symlink farm. # This is important because we only want to build the symlink farm while under Spark otherwise we @@ -114,11 +118,15 @@ def _supports_symlinks(): os.symlink(JARS_PATH, JARS_TARGET) os.symlink(SCRIPTS_PATH, SCRIPTS_TARGET) os.symlink(EXAMPLES_PATH, EXAMPLES_TARGET) + os.symlink(DATA_PATH, DATA_TARGET) + os.symlink(LICENSES_PATH, LICENSES_TARGET) else: # For windows fall back to the slower copytree copytree(JARS_PATH, JARS_TARGET) copytree(SCRIPTS_PATH, SCRIPTS_TARGET) copytree(EXAMPLES_PATH, EXAMPLES_TARGET) + copytree(DATA_PATH, DATA_TARGET) + copytree(LICENSES_PATH, LICENSES_TARGET) else: # If we are not inside of SPARK_HOME verify we have the required symlink farm if not os.path.exists(JARS_TARGET): @@ -161,18 +169,24 @@ def _supports_symlinks(): 'pyspark.jars', 'pyspark.python.pyspark', 'pyspark.python.lib', + 'pyspark.data', + 'pyspark.licenses', 'pyspark.examples.src.main.python'], include_package_data=True, package_dir={ 'pyspark.jars': 'deps/jars', 'pyspark.bin': 'deps/bin', 'pyspark.python.lib': 'lib', + 'pyspark.data': 'deps/data', + 'pyspark.licenses': 'deps/licenses', 'pyspark.examples.src.main.python': 'deps/examples', }, package_data={ 'pyspark.jars': ['*.jar'], 'pyspark.bin': ['*'], 'pyspark.python.lib': ['*.zip'], + 'pyspark.data': ['*.txt', '*.data'], + 'pyspark.licenses': ['*.txt'], 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', @@ -202,8 +216,12 @@ def 
_supports_symlinks(): os.remove(os.path.join(TEMP_PATH, "jars")) os.remove(os.path.join(TEMP_PATH, "bin")) os.remove(os.path.join(TEMP_PATH, "examples")) + os.remove(os.path.join(TEMP_PATH, "data")) + os.remove(os.path.join(TEMP_PATH, "licenses")) else: rmtree(os.path.join(TEMP_PATH, "jars")) rmtree(os.path.join(TEMP_PATH, "bin")) rmtree(os.path.join(TEMP_PATH, "examples")) + rmtree(os.path.join(TEMP_PATH, "data")) + rmtree(os.path.join(TEMP_PATH, "licenses")) os.rmdir(TEMP_PATH) diff --git a/repl/pom.xml b/repl/pom.xml index 73493e600e..a256ae3b84 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml @@ -92,6 +92,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.xbean xbean-asm5-shaded diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 3d622d42f4..6d274bddb7 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -24,7 +24,6 @@ import java.nio.charset.StandardCharsets import java.nio.file.{Paths, StandardOpenOption} import java.util -import scala.concurrent.duration._ import scala.io.Source import scala.language.implicitConversions @@ -34,8 +33,6 @@ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Interruptor -import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar import org.apache.spark._ @@ -61,7 +58,7 @@ class ExecutorClassLoaderSuite super.beforeAll() tempDir1 = Utils.createTempDir() tempDir2 = Utils.createTempDir() - url1 = "file://" + tempDir1 + url1 = tempDir1.toURI.toURL.toString urls2 
= List(tempDir2.toURI.toURL).toArray childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) parentResourceNames.foreach { x => @@ -118,8 +115,14 @@ class ExecutorClassLoaderSuite val resourceName: String = parentResourceNames.head val is = classLoader.getResourceAsStream(resourceName) assert(is != null, s"Resource $resourceName not found") - val content = Source.fromInputStream(is, "UTF-8").getLines().next() - assert(content.contains("resource"), "File doesn't contain 'resource'") + + val bufferedSource = Source.fromInputStream(is, "UTF-8") + Utils.tryWithSafeFinally { + val content = bufferedSource.getLines().next() + assert(content.contains("resource"), "File doesn't contain 'resource'") + } { + bufferedSource.close() + } } test("resources from parent") { @@ -128,8 +131,14 @@ class ExecutorClassLoaderSuite val resourceName: String = parentResourceNames.head val resources: util.Enumeration[URL] = classLoader.getResources(resourceName) assert(resources.hasMoreElements, s"Resource $resourceName not found") - val fileReader = Source.fromInputStream(resources.nextElement().openStream()).bufferedReader() - assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + + val bufferedSource = Source.fromInputStream(resources.nextElement().openStream()) + Utils.tryWithSafeFinally { + val fileReader = bufferedSource.bufferedReader() + assert(fileReader.readLine().contains("resource"), "File doesn't contain 'resource'") + } { + bufferedSource.close() + } } test("fetch classes using Spark's RpcEnv") { diff --git a/mesos/pom.xml b/resource-managers/mesos/pom.xml similarity index 97% rename from mesos/pom.xml rename to resource-managers/mesos/pom.xml index 57cc26a4cc..c0a8f9a344 100644 --- a/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,8 +20,8 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT - ../pom.xml + 2.2.0-SNAPSHOT + ../../pom.xml spark-mesos_2.11 diff --git 
a/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager similarity index 100% rename from mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager rename to resource-managers/mesos/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala similarity index 100% rename from 
mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala diff --git a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala similarity index 100% rename from 
mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala diff --git a/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala diff --git 
a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala diff --git 
a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala similarity index 99% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 9cb6023704..1d742fefbb 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -80,6 +80,8 @@ trait MesosSchedulerUtils extends Logging { frameworkId.foreach { id => fwInfoBuilder.setId(FrameworkID.newBuilder().setValue(id).build()) } + fwInfoBuilder.setHostname(Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse( + conf.get(DRIVER_HOST_ADDRESS))) conf.getOption("spark.mesos.principal").foreach { principal => fwInfoBuilder.setPrincipal(principal) credBuilder.setPrincipal(principal) diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala similarity index 100% rename from mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala rename to resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala rename to 
resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManagerSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala rename to 
resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchDataSuite.scala diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala similarity index 100% rename from mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala rename to resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala diff --git a/yarn/pom.xml 
b/resource-managers/yarn/pom.xml similarity index 98% rename from yarn/pom.xml rename to resource-managers/yarn/pom.xml index 64ff845b5a..f090d2427d 100644 --- a/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,8 +20,8 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT - ../pom.xml + 2.2.0-SNAPSHOT + ../../pom.xml spark-yarn_2.11 @@ -54,6 +54,8 @@ org.apache.spark spark-tags_${scala.binary.version} + test-jar + test org.apache.hadoop diff --git a/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider similarity index 100% rename from yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider rename to resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider diff --git a/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager similarity index 100% rename from yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager rename to resource-managers/yarn/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala similarity index 99% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 0378ef4fac..f79c66b9ff 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -692,11 +692,11 @@ private[spark] class ApplicationMaster( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => + case r: RequestExecutors => Option(allocator) match { case Some(a) => - if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, - localityAwareTasks, hostToLocalTaskCount)) { + if (a.requestTotalExecutorsWithPreferredLocalities(r.requestedTotal, + r.localityAwareTasks, r.hostToLocalTaskCount, r.nodeBlacklist)) { resetAllocatorInterval() } context.reply(true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala 
b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala similarity index 96% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 0b66d1cf08..e498932e51 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -114,6 +114,8 @@ private[yarn] class YarnAllocator( @volatile private var targetNumExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + private var currentNodeBlacklist = Set.empty[String] 
+ // Executor loss reason requests that are pending - maps from executor ID for inquiry to a // list of requesters that should be responded to once we find out why the given executor // was lost. @@ -217,18 +219,35 @@ private[yarn] class YarnAllocator( * @param localityAwareTasks number of locality aware tasks to be used as container placement hint * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as * container placement hint. + * @param nodeBlacklist a set of blacklisted nodes, which is passed in to avoid allocating new + * containers on them. It will be used to update the application master's + * blacklist. * @return Whether the new requested total is different than the old value. */ def requestTotalExecutorsWithPreferredLocalities( requestedTotal: Int, localityAwareTasks: Int, - hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized { + hostToLocalTaskCount: Map[String, Int], + nodeBlacklist: Set[String]): Boolean = synchronized { this.numLocalityAwareTasks = localityAwareTasks this.hostToLocalTaskCounts = hostToLocalTaskCount if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal + + // Update blacklist information to YARN ResourceManager for this application, + // in order to avoid allocating new Containers on the problematic nodes. 
+ val blacklistAdditions = nodeBlacklist -- currentNodeBlacklist + val blacklistRemovals = currentNodeBlacklist -- nodeBlacklist + if (blacklistAdditions.nonEmpty) { + logInfo(s"adding nodes to YARN application master's blacklist: $blacklistAdditions") + } + if (blacklistRemovals.nonEmpty) { + logInfo(s"removing nodes from YARN application master's blacklist: $blacklistRemovals") + } + amClient.updateBlacklist(blacklistAdditions.toList.asJava, blacklistRemovals.toList.asJava) + currentNodeBlacklist = nodeBlacklist true } else { false diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala rename to 
resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/AMCredentialRenewer.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManager.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HBaseCredentialProvider.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala similarity index 88% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala index 8d06d735ba..ebb176bc95 100644 --- 
a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProvider.scala @@ -72,21 +72,22 @@ private[security] class HDFSCredentialProvider extends ServiceCredentialProvider // We cannot use the tokens generated with renewer yarn. Trying to renew // those will fail with an access control issue. So create new tokens with the logged in // user as renewer. - sparkConf.get(PRINCIPAL).map { renewer => + sparkConf.get(PRINCIPAL).flatMap { renewer => val creds = new Credentials() nnsToAccess(hadoopConf, sparkConf).foreach { dst => val dstFs = dst.getFileSystem(hadoopConf) dstFs.addDelegationTokens(renewer, creds) } - val t = creds.getAllTokens.asScala - .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) - .head - val newExpiration = t.renew(hadoopConf) - val identifier = new DelegationTokenIdentifier() - identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) - val interval = newExpiration - identifier.getIssueDate - logInfo(s"Renewal Interval is $interval") - interval + val hdfsToken = creds.getAllTokens.asScala + .find(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + hdfsToken.map { t => + val newExpiration = t.renew(hadoopConf) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal Interval is $interval") + interval + } } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala rename to 
resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/HiveCredentialProvider.scala diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerExtensionService.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala 
b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala similarity index 100% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala rename to resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala similarity index 95% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala rename to 
resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 2f9ea1911f..cbc6e60e83 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -121,13 +121,21 @@ private[spark] abstract class YarnSchedulerBackend( } } + private[cluster] def prepareRequestExecutors(requestedTotal: Int): RequestExecutors = { + val nodeBlacklist: Set[String] = scheduler.nodeBlacklist() + // For locality preferences, ignore preferences for nodes that are blacklisted + val filteredHostToLocalTaskCount = + hostToLocalTaskCount.filter { case (k, v) => !nodeBlacklist.contains(k) } + RequestExecutors(requestedTotal, localityAwareTasks, filteredHostToLocalTaskCount, + nodeBlacklist) + } + /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. 
*/ override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { - yarnSchedulerEndpointRef.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + yarnSchedulerEndpointRef.ask[Boolean](prepareRequestExecutors(requestedTotal)) } /** diff --git a/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider similarity index 100% rename from yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider rename to resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider diff --git a/yarn/src/test/resources/log4j.properties b/resource-managers/yarn/src/test/resources/log4j.properties similarity index 100% rename from yarn/src/test/resources/log4j.properties rename to resource-managers/yarn/src/test/resources/log4j.properties diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala similarity index 91% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 994dc75d34..fcc0594cf6 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy.yarn import java.util.{Arrays, List => JList} +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.CommonConfigurationKeysPublic import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ @@ -90,7 +92,9 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter override def equals(other: Any): Boolean = false } - def createAllocator(maxExecutors: Int = 5): YarnAllocator = { + def createAllocator( + maxExecutors: Int = 5, + rmClient: AMRMClient[ContainerRequest] = rmClient): YarnAllocator = { val args = Array( "--jar", "somejar.jar", 
"--class", "SomeClass") @@ -202,7 +206,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (4) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty, Set.empty) handler.updateResourceRequests() handler.getPendingAllocate.size should be (3) @@ -213,7 +217,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty, Set.empty) handler.updateResourceRequests() handler.getPendingAllocate.size should be (1) } @@ -224,7 +228,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getPendingAllocate.size should be (4) - handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty, Set.empty) handler.updateResourceRequests() handler.getPendingAllocate.size should be (3) @@ -234,7 +238,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (2) - handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty, Set.empty) handler.updateResourceRequests() handler.getPendingAllocate.size should be (0) handler.getNumExecutorsRunning should be (2) @@ -250,7 +254,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val container2 = createContainer("host2") 
handler.handleAllocatedContainers(Array(container1, container2)) - handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty, Set.empty) handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } val statuses = Seq(container1, container2).map { c => @@ -272,7 +276,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val container2 = createContainer("host2") handler.handleAllocatedContainers(Array(container1, container2)) - handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set.empty) val statuses = Seq(container1, container2).map { c => ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) @@ -286,6 +290,21 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumUnexpectedContainerRelease should be (2) } + test("blacklisted nodes reflected in amClient requests") { + // Internally we track the set of blacklisted nodes, but yarn wants us to send *changes* + // to the blacklist. This makes sure we are sending the right updates. 
+ val mockAmClient = mock(classOf[AMRMClient[ContainerRequest]]) + val handler = createAllocator(4, mockAmClient) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map(), Set("hostA")) + verify(mockAmClient).updateBlacklist(Seq("hostA").asJava, Seq[String]().asJava) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map(), Set("hostA", "hostB")) + verify(mockAmClient).updateBlacklist(Seq("hostB").asJava, Seq[String]().asJava) + + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map(), Set()) + verify(mockAmClient).updateBlacklist(Seq[String]().asJava, Seq("hostA", "hostB").asJava) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala rename to 
resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/ConfigurableCredentialManagerSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/HDFSCredentialProviderSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala 
b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/ExtensionServiceIntegrationSuite.scala diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/SimpleExtensionService.scala diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala rename to 
resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationAttemptId.scala diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala similarity index 100% rename from yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala rename to resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/StubApplicationId.scala diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 48333851ef..1f48d71cc7 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -200,7 +200,6 @@ This file is divided into 3 sections: // scalastyle:off awaitresult Await.result(...) // scalastyle:on awaitresult - If your codes use ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead. ]]> diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 82a5a85317..765c92b8d3 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -56,6 +56,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.spark spark-unsafe_${scala.binary.version} diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 075c73d7a3..a34087cb6c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -89,6 +89,8 @@ statement SET TBLPROPERTIES tablePropertyList #setTableProperties | ALTER (TABLE | VIEW) tableIdentifier UNSET TBLPROPERTIES (IF EXISTS)? 
tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? identifier colType colPosition? #changeColumn | ALTER TABLE tableIdentifier (partitionSpec)? SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe | ALTER TABLE tableIdentifier (partitionSpec)? @@ -120,8 +122,10 @@ statement (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain - | SHOW TABLES EXTENDED? ((FROM | IN) db=identifier)? - (LIKE? pattern=STRING)? partitionSpec? #showTables + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE? pattern=STRING)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) db=identifier)? + LIKE pattern=STRING partitionSpec? #showTable | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties @@ -192,7 +196,6 @@ unsupportedHiveNativeCommands | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=ADD kw4=COLUMNS - | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CHANGE kw4=COLUMN? | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? 
kw3=REPLACE kw4=COLUMNS | kw1=START kw2=TRANSACTION | kw1=COMMIT @@ -578,6 +581,10 @@ intervalValue | STRING ; +colPosition + : FIRST | AFTER identifier + ; + dataType : complex=ARRAY '<' dataType '>' #complexDataType | complex=MAP '<' dataType ',' dataType '>' #complexDataType @@ -669,7 +676,7 @@ number nonReserved : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED @@ -759,6 +766,7 @@ PRECEDING: 'PRECEDING'; FOLLOWING: 'FOLLOWING'; CURRENT: 'CURRENT'; FIRST: 'FIRST'; +AFTER: 'AFTER'; LAST: 'LAST'; ROW: 'ROW'; WITH: 'WITH'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java index 49a18df2c7..cf0579fd36 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -19,7 +19,7 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.InternalOutputModes; +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes; /** * :: Experimental :: @@ -46,7 +46,7 @@ public static OutputMode Append() { /** * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written - * to the sink every time these is some updates. This output mode can only be used in queries + * to the sink every time there are some updates. This output mode can only be used in queries * that contain aggregations. 
* * @since 2.0.0 @@ -54,4 +54,14 @@ public static OutputMode Append() { public static OutputMode Complete() { return InternalOutputModes.Complete$.MODULE$; } + + /** + * OutputMode in which only the rows that were updated in the streaming DataFrame/Dataset will + * be written to the sink every time there are some updates. + * + * @since 2.1.1 + */ + public static OutputMode Update() { + return InternalOutputModes.Update$.MODULE$; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index f498e071b5..256f64e320 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{DataType, Decimal, StructType} /** - * An abstract class for row used internal in Spark SQL, which only contain the columns as + * An abstract class for row used internally in Spark SQL, which only contains the columns as * internal types. 
*/ abstract class InternalRow extends SpecializedGetters with Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7e8e4dab72..8b53d988cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.beans.{Introspector, PropertyDescriptor} import java.lang.{Iterable => JIterable} +import java.lang.reflect.Type import java.util.{Iterator => JIterator, List => JList, Map => JMap} import scala.language.existentials @@ -54,12 +55,21 @@ object JavaTypeInference { inferDataType(TypeToken.of(beanClass)) } + /** + * Infers the corresponding SQL data type of a Java type. + * @param beanType Java type + * @return (SQL data type, nullable) + */ + private[sql] def inferDataType(beanType: Type): (DataType, Boolean) = { + inferDataType(TypeToken.of(beanType)) + } + /** * Infers the corresponding SQL data type of a Java type. 
* @param typeToken Java type * @return (SQL data type, nullable) */ - private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6e20096901..ad218cf88d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -342,7 +342,7 @@ object ScalaReflection extends ScalaReflection { StaticInvoke( ArrayBasedMapData.getClass, - ObjectType(classOf[Map[_, _]]), + ObjectType(classOf[scala.collection.immutable.Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8faf0eda54..73e92066b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -463,14 +463,15 @@ class Analyzer( .toAggregateExpression() , "__pivot_" + a.sql)() } - val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg) + val groupByExprsAttr = groupByExprs.map(_.toAttribute) + val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) val pivotAggAttribute = pivotAggs.map(_.toAttribute) val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => Alias(ExtractValue(pivotAtt, Literal(i), resolver), 
outputName(value, aggregate))() } } - Project(groupByExprs ++ pivotOutputs, secondAgg) + Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { @@ -1011,24 +1012,24 @@ class Analyzer( private def pullOutCorrelatedPredicates(sub: LogicalPlan): (LogicalPlan, Seq[Expression]) = { val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, Seq[Expression]] - /** Make sure a plans' subtree does not contain a tagged predicate. */ - def failOnOuterReferenceInSubTree(p: LogicalPlan, msg: String): Unit = { - if (p.collect(predicateMap).nonEmpty) { - failAnalysis(s"Accessing outer query column is not allowed in $msg: $p") + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (p.collectFirst(predicateMap).nonEmpty) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") } } - /** Helper function for locating outer references. */ + // Helper function for locating outer references. def containsOuter(e: Expression): Boolean = { e.find(_.isInstanceOf[OuterReference]).isDefined } - /** Make sure a plans' expressions do not contain a tagged predicate. */ + // Make sure a plan's expressions do not contain outer references def failOnOuterReference(p: LogicalPlan): Unit = { if (p.expressions.exists(containsOuter)) { failAnalysis( "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + - s"clauses: $p") + s"clauses:\n$p") } } @@ -1077,10 +1078,51 @@ class Analyzer( // Simplify the predicates before pulling them out. val transformed = BooleanSimplification(sub) transformUp { - // WARNING: - // Only Filter can host correlated expressions at this time - // Anyone adding a new "case" below needs to add the call to - // "failOnOuterReference" to disallow correlated expressions in it. 
+ + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handlings. These operators are + // Project, Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other word, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. + + // Category 1: + // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case p: BroadcastHint => + p + case p: Distinct => + p + case p: LeafNode => + p + case p: Repartition => + p + case p: SubqueryAlias => + p + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case p: Sort => + failOnOuterReference(p) + p + case p: RepartitionByExpression => + failOnOuterReference(p) + p + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) @@ -1102,14 +1144,24 @@ class Analyzer( predicateMap += child -> xs child } + + // Project cannot host any correlated expressions + // but can be anywhere in a correlated subquery. 
case p @ Project(expressions, child) => failOnOuterReference(p) + val referencesToAdd = missingReferences(p) if (referencesToAdd.nonEmpty) { Project(expressions ++ referencesToAdd, child) } else { p } + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. case a @ Aggregate(grouping, expressions, child) => failOnOuterReference(a) failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) @@ -1120,48 +1172,55 @@ class Analyzer( } else { a } - case w : Window => - failOnOuterReference(w) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, w) - w - case j @ Join(left, _, RightOuter, _) => - failOnOuterReference(j) - failOnOuterReferenceInSubTree(left, "a RIGHT OUTER JOIN") - j - // SPARK-18578: Do not allow any correlated predicate - // in a Full (Outer) Join operator and its descendants - case j @ Join(_, _, FullOuter, _) => - failOnOuterReferenceInSubTree(j, "a FULL OUTER JOIN") - j - case j @ Join(_, right, jt, _) if !jt.isInstanceOf[InnerLike] => - failOnOuterReference(j) - failOnOuterReferenceInSubTree(right, "a LEFT (OUTER) JOIN") + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. + case _: InnerLike => + failOnOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. + // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. 
+ case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. + case _ => + failOnOuterReferenceInSubTree(j) + } j - case u: Union => - failOnOuterReferenceInSubTree(u, "a UNION") - u - case s: SetOperation => - failOnOuterReferenceInSubTree(s.right, "an INTERSECT/EXCEPT") - s - case e: Expand => - failOnOuterReferenceInSubTree(e, "an EXPAND") - e - case l : LocalLimit => - failOnOuterReferenceInSubTree(l, "a LIMIT") - l - // Since LIMIT is represented as GlobalLimit(, (LocalLimit (, child)) - // and we are walking bottom up, we will fail on LocalLimit before - // reaching GlobalLimit. - // The code below is just a safety net. - case g : GlobalLimit => - failOnOuterReferenceInSubTree(g, "a LIMIT") - g - case s : Sample => - failOnOuterReferenceInSubTree(s, "a TABLESAMPLE") - s - case p => + + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case p @ Generate(generator, true, _, _, _, _) => failOnOuterReference(p) p + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. 
+ case p => + failOnOuterReferenceInSubTree(p) + p } (transformed, predicateMap.values.flatten.toSeq) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 235a79973d..aa77a6efef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -124,6 +124,10 @@ trait CheckAnalysis extends PredicateHelper { s"Scalar subquery must return only one column, but got ${query.output.size}") case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => + + // Collect the columns from the subquery for further checking. + var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) + def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates which contain exactly one aggregate expressions. @@ -136,24 +140,35 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis("The output of a correlated scalar subquery must be aggregated") } - // SPARK-18504: block cases where GROUP BY columns - // are not part of the correlated columns - val groupByCols = ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references)) - val predicateCols = ExpressionSet.apply(conditions.flatMap(_.references)) - val invalidCols = groupByCols.diff(predicateCols) + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. 
+ val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols // GROUP BY columns must be a subset of columns in the predicates if (invalidCols.nonEmpty) { failAnalysis( - "a GROUP BY clause in a scalar correlated subquery " + + "A GROUP BY clause in a scalar correlated subquery " + "cannot contain non-correlated columns: " + invalidCols.mkString(",")) } } - // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. + // Skip subquery aliases added by the Analyzer and the SQLBuilder. + // For projects, do the necessary mapping and skip to its child. def cleanQuery(p: LogicalPlan): LogicalPlan = p match { case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => cleanQuery(p.child) + case p: Project => + // SPARK-18814: Map any aliases to their AttributeReference children + // for the checking in the Aggregate operators below this Project. + subqueryColumns = subqueryColumns.map { + xs => p.projectList.collectFirst { + case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => + child + }.getOrElse(xs) + } + + cleanQuery(p.child) case child => child } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e41f1cad93..2b214c3c9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -252,6 +252,7 @@ object FunctionRegistry { expression[Percentile]("percentile"), expression[Skewness]("skewness"), expression[ApproximatePercentile]("percentile_approx"), + expression[ApproximatePercentile]("approx_percentile"), expression[StddevSamp]("std"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), @@ -371,6 
+372,8 @@ object FunctionRegistry { expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), expression[InputFileName]("input_file_name"), + expression[InputFileBlockStart]("input_file_block_start"), + expression[InputFileBlockLength]("input_file_block_length"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), expression[CurrentDatabase]("current_database"), expression[CallMethodViaReflection]("reflect"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6662a9e974..cd73f9c897 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -673,48 +673,69 @@ object TypeCoercion { * If the expression has an incompatible type that cannot be implicitly cast, return None. */ def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = { - val inType = e.dataType + implicitCast(e.dataType, expectedType).map { dt => + if (dt == e.dataType) e else Cast(e, dt) + } + } + private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = { // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope. // We wrap immediately an Option after this. - @Nullable val ret: Expression = (inType, expectedType) match { - + @Nullable val ret: DataType = (inType, expectedType) match { // If the expected type is already a parent of the input type, no need to cast. 
- case _ if expectedType.acceptsType(inType) => e + case _ if expectedType.acceptsType(inType) => inType // Cast null type (usually from null literals) into target types - case (NullType, target) => Cast(e, target.defaultConcreteType) + case (NullType, target) => target.defaultConcreteType // If the function accepts any numeric type and the input is a string, we follow the hive // convention and cast that input into a double - case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) + case (StringType, NumericType) => NumericType.defaultConcreteType // Implicit cast among numeric types. When we reach here, input type is not acceptable. // If input is a numeric type but not decimal, and we expect a decimal type, // cast the input to decimal. - case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d)) + case (d: NumericType, DecimalType) => DecimalType.forType(d) // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long - case (_: NumericType, target: NumericType) => Cast(e, target) + case (_: NumericType, target: NumericType) => target // Implicit cast between date time types - case (DateType, TimestampType) => Cast(e, TimestampType) - case (TimestampType, DateType) => Cast(e, DateType) + case (DateType, TimestampType) => TimestampType + case (TimestampType, DateType) => DateType // Implicit cast from/to string - case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT) - case (StringType, target: NumericType) => Cast(e, target) - case (StringType, DateType) => Cast(e, DateType) - case (StringType, TimestampType) => Cast(e, TimestampType) - case (StringType, BinaryType) => Cast(e, BinaryType) + case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT + case (StringType, target: NumericType) => target + case (StringType, DateType) => DateType + case (StringType, TimestampType) => TimestampType + case (StringType, BinaryType) => BinaryType // Cast any atomic type to string. 
- case (any: AtomicType, StringType) if any != StringType => Cast(e, StringType) + case (any: AtomicType, StringType) if any != StringType => StringType // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. - case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull + case (_, TypeCollection(types)) => + types.flatMap(implicitCast(inType, _)).headOption.orNull + + // Implicit cast between array types. + // + // Compare the nullabilities of the from type and the to type, check whether the cast of + // the nullability is resolvable by the following rules: + // 1. If the nullability of the to type is true, the cast is always allowed; + // 2. If the nullability of the to type is false, and the nullability of the from type is + // true, the cast is never allowed; + // 3. If the nullabilities of both the from type and the to type are false, the cast is + // allowed only when Cast.forceNullable(fromType, toType) is false. 
+ case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) => + implicitCast(fromType, toType).map(ArrayType(_, true)).orNull + + case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null + + case (ArrayType(fromType, false), ArrayType(toType: DataType, false)) + if !Cast.forceNullable(fromType, toType) => + implicitCast(fromType, toType).map(ArrayType(_, false)).orNull - // Else, just return the same input expression case _ => null } Option(ret) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c054fcbef3..053c8eb617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.{AnalysisException, InternalOutputModes} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.streaming.OutputMode /** @@ -95,6 +97,16 @@ object UnsupportedOperationChecker { // Operations that cannot exists anywhere in a streaming plan subPlan match { + case Aggregate(_, aggregateExpressions, child) => + val distinctAggExprs = aggregateExpressions.flatMap { expr => + expr.collect { case ae: AggregateExpression if ae.isDistinct => ae } + } + throwErrorIf( + child.isStreaming && distinctAggExprs.nonEmpty, + "Distinct aggregations are not supported on streaming DataFrames/Datasets, unless " + + "it is on aggregated 
DataFrame/Dataset in Complete output mode. Consider using " + + "approximate distinct aggregation (e.g. approx_count_distinct() instead of count()).") + case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") @@ -143,7 +155,7 @@ object UnsupportedOperationChecker { throwError("Union between streaming and batch DataFrames/Datasets is not supported") case Except(left, right) if right.isStreaming => - throwError("Except with a streaming DataFrame/Dataset on the right is not supported") + throwError("Except on a streaming DataFrame/Dataset on the right is not supported") case Intersect(left, right) if left.isStreaming && right.isStreaming => throwError("Intersect between two streaming DataFrames/Datasets is not supported") @@ -154,9 +166,9 @@ object UnsupportedOperationChecker { case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => throwError("Limits are not supported on streaming DataFrames/Datasets") - case Sort(_, _, _) | SortPartitions(_, _) if !containsCompleteData(subPlan) => + case Sort(_, _, _) if !containsCompleteData(subPlan) => throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" + - "aggregated DataFrame/Dataset in Complete mode") + "aggregated DataFrame/Dataset in Complete output mode") case Sample(_, _, _, _, child) if child.isStreaming => throwError("Sampling is not supported on streaming DataFrames/Datasets") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 259008f183..5233699fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression /** - * 
Interface for the system catalog (of columns, partitions, tables, and databases). + * Interface for the system catalog (of functions, partitions, tables, and databases). * * This is only used for non-temporary items, and implementations must be thread-safe as they * can be accessed in multiple threads. This is an external catalog because it is expected to @@ -114,13 +114,26 @@ abstract class ExternalCatalog { def listTables(db: String, pattern: String): Seq[String] + /** + * Loads data into a table. + * + * @param isSrcLocal Whether the source data is local, as defined by the "LOAD DATA LOCAL" + * HiveQL command. + */ def loadTable( db: String, table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean): Unit + holdDDLTime: Boolean, + isSrcLocal: Boolean): Unit + /** + * Loads data into a partition. + * + * @param isSrcLocal Whether the source data is local, as defined by the "LOAD DATA LOCAL" + * HiveQL command. + */ def loadPartition( db: String, table: String, @@ -128,7 +141,8 @@ abstract class ExternalCatalog { partition: TablePartitionSpec, isOverwrite: Boolean, holdDDLTime: Boolean, - inheritTableSpecs: Boolean): Unit + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit def loadDynamicPartitions( db: String, @@ -189,15 +203,37 @@ abstract class ExternalCatalog { table: String, spec: TablePartitionSpec): Option[CatalogTablePartition] + /** + * List the names of all partitions that belong to the specified table, assuming it exists. + * + * For a table with partition columns p1, p2, p3, each partition name is formatted as + * `p1=v1/p2=v2/p3=v3`. Each partition column name and value is an escaped path name, and can be + * decoded with the `ExternalCatalogUtils.unescapePathName` method. + * + * The returned sequence is sorted as strings. + * + * A partial partition spec may optionally be provided to filter the partitions returned, as + * described in the `listPartitions` method. 
+ * + * @param db database name + * @param table table name + * @param partialSpec partition spec + */ + def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] + /** * List the metadata of all partitions that belong to the specified table, assuming it exists. * * A partial partition spec may optionally be provided to filter the partitions returned. * For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'), * then a partial spec of (a='1') will return the first two only. + * * @param db database name * @param table table name - * @param partialSpec partition spec + * @param partialSpec partition spec */ def listPartitions( db: String, @@ -210,7 +246,7 @@ abstract class ExternalCatalog { * * @param db database name * @param table table name - * @param predicates partition-pruning predicates + * @param predicates partition-pruning predicates */ def listPartitionsByFilter( db: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 817c1ab688..4331841fbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec object ExternalCatalogUtils { @@ -133,4 +135,39 @@ object CatalogUtils { case o => o } } + + def normalizePartCols( + tableName: String, + tableCols: Seq[String], + partCols: Seq[String], + resolver: Resolver): Seq[String] = { + 
partCols.map(normalizeColumnName(tableName, tableCols, _, "partition", resolver)) + } + + def normalizeBucketSpec( + tableName: String, + tableCols: Seq[String], + bucketSpec: BucketSpec, + resolver: Resolver): BucketSpec = { + val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec + val normalizedBucketCols = bucketColumnNames.map { colName => + normalizeColumnName(tableName, tableCols, colName, "bucket", resolver) + } + val normalizedSortCols = sortColumnNames.map { colName => + normalizeColumnName(tableName, tableCols, colName, "sort", resolver) + } + BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols) + } + + private def normalizeColumnName( + tableName: String, + tableCols: Seq[String], + colName: String, + colType: String, + resolver: Resolver): String = { + tableCols.find(resolver(_, colName)).getOrElse { + throw new AnalysisException(s"$colType column $colName is not defined in table $tableName, " + + s"defined table columns are: ${tableCols.mkString(", ")}") + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 880a7a0dc4..816e4af2df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.StringUtils @@ -311,7 +312,8 @@ class InMemoryCatalog( table: String, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean): 
Unit = { + holdDDLTime: Boolean, + isSrcLocal: Boolean): Unit = { throw new UnsupportedOperationException("loadTable is not implemented") } @@ -322,7 +324,8 @@ class InMemoryCatalog( partition: TablePartitionSpec, isOverwrite: Boolean, holdDDLTime: Boolean, - inheritTableSpecs: Boolean): Unit = { + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { throw new UnsupportedOperationException("loadPartition is not implemented.") } @@ -488,6 +491,19 @@ class InMemoryCatalog( } } + override def listPartitionNames( + db: String, + table: String, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = synchronized { + val partitionColumnNames = getTable(db, table).partitionColumnNames + + listPartitions(db, table, partialSpec).map { partition => + partitionColumnNames.map { name => + escapePathName(name) + "=" + escapePathName(partition.spec(name)) + }.mkString("/") + }.sorted + } + override def listPartitions( db: String, table: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index da3a2079f4..e996a836fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -311,12 +311,13 @@ class SessionCatalog( name: TableIdentifier, loadPath: String, isOverwrite: Boolean, - holdDDLTime: Boolean): Unit = { + holdDDLTime: Boolean, + isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime) + externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime, isSrcLocal) } /** @@ -330,13 +331,14 @@ class SessionCatalog( partition: TablePartitionSpec, 
isOverwrite: Boolean, holdDDLTime: Boolean, - inheritTableSpecs: Boolean): Unit = { + inheritTableSpecs: Boolean, + isSrcLocal: Boolean): Unit = { val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) externalCatalog.loadPartition( - db, table, loadPath, partition, isOverwrite, holdDDLTime, inheritTableSpecs) + db, table, loadPath, partition, isOverwrite, holdDDLTime, inheritTableSpecs, isSrcLocal) } def defaultTablePath(tableIdent: TableIdentifier): String = { @@ -748,6 +750,26 @@ class SessionCatalog( externalCatalog.getPartition(db, table, spec) } + /** + * List the names of all partitions that belong to the specified table, assuming it exists. + * + * A partial partition spec may optionally be provided to filter the partitions returned. + * For instance, if there exist partitions (a='1', b='2'), (a='1', b='3') and (a='2', b='4'), + * then a partial spec of (a='1') will return the first two only. + */ + def listPartitionNames( + tableName: TableIdentifier, + partialSpec: Option[TablePartitionSpec] = None): Seq[String] = { + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) + val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + partialSpec.foreach { spec => + requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + } + externalCatalog.listPartitionNames(db, table, partialSpec) + } + /** * List the metadata of all partitions that belong to the specified table, assuming it exists. 
* @@ -762,6 +784,9 @@ class SessionCatalog( val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) + partialSpec.foreach { spec => + requirePartialMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) + } externalCatalog.listPartitions(db, table, partialSpec) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index d2a1af0800..5b5378c09e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -133,6 +133,16 @@ case class BucketSpec( if (numBuckets <= 0) { throw new AnalysisException(s"Expected positive number of buckets, but got `$numBuckets`.") } + + override def toString: String = { + val bucketString = s"bucket columns: [${bucketColumnNames.mkString(", ")}]" + val sortString = if (sortColumnNames.nonEmpty) { + s", sort columns: [${sortColumnNames.mkString(", ")}]" + } else { + "" + } + s"$numBuckets buckets, $bucketString$sortString" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e901683be6..66e52ca68a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, OverwriteOptions(overwrite), false) + Map.empty, logicalPlan, overwrite, false) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4db1ae6faa..741730e3e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,9 +89,7 @@ object Cast { case _ => false } - private def resolvableNullability(from: Boolean, to: Boolean) = !from || to - - private def forceNullable(from: DataType, to: DataType) = (from, to) match { + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -110,6 +108,8 @@ object Cast { case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } + + private def resolvableNullability(from: Boolean, to: Boolean) = !from || to } /** Cast the child expression to the target data type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index b8e2b67b2f..6c246a5663 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable /** * This class is used to compute equality of (sub)expression trees. 
Expressions can be added @@ -72,7 +73,10 @@ class EquivalentExpressions { root: Expression, ignoreLeaf: Boolean = true, skipReferenceToExpressions: Boolean = true): Unit = { - val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) || + // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the + // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. + root.find(_.isInstanceOf[LambdaVariable]).isDefined // There are some special expressions that we should not recurse into children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 221f830aa8..b93a5d0b7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -70,9 +70,9 @@ abstract class Expression extends TreeNode[Expression] { * children. * * Note that this means that an expression should be considered as non-deterministic if: - * - if it relies on some mutable internal state, or - * - if it relies on some implicit input that is not part of the children expression list. - * - if it has non-deterministic child or children. + * - it relies on some mutable internal state, or + * - it relies on some implicit input that is not part of the children expression list. + * - it has non-deterministic child or children. * * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. * By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala deleted file mode 100644 index d412336699..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.rdd.InputFileNameHolder -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.types.{DataType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -/** - * Expression that returns the name of the current file being read. 
- */ -@ExpressionDescription( - usage = "_FUNC_() - Returns the name of the current file being read if available.") -case class InputFileName() extends LeafExpression with Nondeterministic { - - override def nullable: Boolean = false - - override def dataType: DataType = StringType - - override def prettyName: String = "input_file_name" - - override protected def initializeInternal(partitionIndex: Int): Unit = {} - - override protected def evalInternal(input: InternalRow): UTF8String = { - InputFileNameHolder.getInputFileName() - } - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false") - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index c2cd8951bc..0e71442f89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -61,9 +61,9 @@ import org.apache.spark.sql.types._ """, extended = """ Examples: - > SELECT percentile_approx(10.0, array(0.5, 0.4, 0.1), 100); + > SELECT _FUNC_(10.0, array(0.5, 0.4, 0.1), 100); [10.0,10.0,10.0] - > SELECT percentile_approx(10.0, 0.5, 100); + > SELECT _FUNC_(10.0, 0.5, 100); 10.0 """) case class ApproximatePercentile( @@ -86,23 +86,16 @@ case class ApproximatePercentile( private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int] override def inputTypes: Seq[AbstractDataType] = { - Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType) + Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType) } // Mark as lazy so that percentageExpression is not 
evaluated during tree transformation. - private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = { - (percentageExpression.dataType, percentageExpression.eval()) match { + private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = + percentageExpression.eval() match { // Rule ImplicitTypeCasts can cast other numeric types to double - case (_, num: Double) => (false, Array(num)) - case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toObjectArray(baseType) - (true, numericArray.map { x => - baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) - }) - case other => - throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") + case num: Double => (false, Array(num)) + case arrayData: ArrayData => (true, arrayData.toDoubleArray()) } - } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() @@ -162,7 +155,7 @@ case class ApproximatePercentile( override def nullable: Boolean = true override def dataType: DataType = { - if (returnPercentileArray) ArrayType(DoubleType) else DoubleType + if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType } override def prettyName: String = "percentile_approx" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index b51b55313e..2f68195dd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -77,15 +77,9 @@ case class Percentile( private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType] @transient - private lazy val percentages = - (percentageExpression.dataType, percentageExpression.eval()) 
match { - case (_, num: Double) => Seq(num) - case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => - val numericArray = arrayData.toObjectArray(baseType) - numericArray.map { x => - baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq - case other => - throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages") + private lazy val percentages = percentageExpression.eval() match { + case num: Double => Seq(num) + case arrayData: ArrayData => arrayData.toDoubleArray().toSeq } override def children: Seq[Expression] = child :: percentageExpression :: Nil @@ -99,7 +93,7 @@ case class Percentile( } override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match { - case _: ArrayType => Seq(NumericType, ArrayType) + case _: ArrayType => Seq(NumericType, ArrayType(DoubleType)) case _ => Seq(NumericType, DoubleType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index afc190e697..bacedec1ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -64,19 +64,75 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val trueEval = trueValue.genCode(ctx) val falseEval = falseValue.genCode(ctx) - ev.copy(code = s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.value}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.value} = ${trueEval.value}; - } else { - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.value} = ${falseEval.value}; - }""") + // place generated code of condition, true value and 
false value in separate methods if + // their code combined is large + val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length + val generatedCode = if (combinedLength > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + + val (condFuncName, condGlobalIsNull, condGlobalValue) = + createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr") + val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = + createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr") + val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = + createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr") + s""" + $condFuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$condGlobalIsNull && $condGlobalValue) { + $trueFuncName(${ctx.INPUT_ROW}); + ${ev.isNull} = $trueGlobalIsNull; + ${ev.value} = $trueGlobalValue; + } else { + $falseFuncName(${ctx.INPUT_ROW}); + ${ev.isNull} = $falseGlobalIsNull; + ${ev.value} = $falseGlobalValue; + } + """ + } + else { + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.value}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.value} = ${trueEval.value}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.value} = ${falseEval.value}; + } + """ + } + + ev.copy(code = generatedCode) + } + + private def createAndAddFunction( + ctx: CodegenContext, + ev: ExprCode, + dataType: DataType, + baseFuncName: String): (String, String, String) = { + val globalIsNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalValue = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(dataType), globalValue, + 
s"$globalValue = ${ctx.defaultValue(dataType)};") + val funcName = ctx.freshName(baseFuncName) + val funcBody = + s""" + |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { + | ${ev.code.trim} + | $globalIsNull = ${ev.isNull}; + | $globalValue = ${ev.value}; + |} + """.stripMargin + ctx.addNewFunction(funcName, funcBody) + (funcName, globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala new file mode 100644 index 0000000000..7a8edabed1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.InputFileBlockHolder +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String + + +@ExpressionDescription( + usage = "_FUNC_() - Returns the name of the file being read, or empty string if not available.") +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = false + + override def dataType: DataType = StringType + + override def prettyName: String = "input_file_name" + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + InputFileBlockHolder.getInputFilePath + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + s"$className.getInputFilePath();", isNull = "false") + } +} + + +@ExpressionDescription( + usage = "_FUNC_() - Returns the start offset of the block being read, or -1 if not available.") +case class InputFileBlockStart() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def prettyName: String = "input_file_block_start" + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): Long = { + InputFileBlockHolder.getStartOffset + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + s"$className.getStartOffset();", isNull 
= "false") + } +} + + +@ExpressionDescription( + usage = "_FUNC_() - Returns the length of the block being read, or -1 if not available.") +case class InputFileBlockLength() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def prettyName: String = "input_file_block_length" + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): Long = { + InputFileBlockHolder.getLength + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + s"$className.getLength();", isNull = "false") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index a8aa1e7255..fc323693a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -930,7 +930,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp /** * Asserts that input values of a non-nullable child expression are not null. * - * Note that there are cases where `child.nullable == true`, while we still needs to add this + * Note that there are cases where `child.nullable == true`, while we still need to add this * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all * non-null `s`, `s.i` can't be null. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index e476cb11a3..03e27ba934 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -119,47 +119,33 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - def makeRootConverter(dataType: DataType): ValueConverter = dataType match { - case st: StructType => - val elementConverter = makeConverter(st) - val fieldConverters = st.map(_.dataType).map(makeConverter) - (parser: JsonParser) => parseJsonToken(parser, dataType) { - case START_OBJECT => convertObject(parser, st, fieldConverters) - // SPARK-3308: support reading top level JSON arrays and take every element - // in such an array as a row - // - // For example, we support, the JSON data as below: - // - // [{"a":"str_a_1"}] - // [{"a":"str_a_2"}, {"b":"str_b_3"}] - // - // resulting in: - // - // List([str_a_1,null]) - // List([str_a_2,null], [null,str_b_3]) - // - case START_ARRAY => convertArray(parser, elementConverter) - } - - case ArrayType(st: StructType, _) => - val elementConverter = makeConverter(st) - val fieldConverters = st.map(_.dataType).map(makeConverter) - (parser: JsonParser) => parseJsonToken(parser, dataType) { - // the business end of SPARK-3308: - // when an object is found but an array is requested just wrap it in a list. - // This is being wrapped in `JacksonParser.parse`. 
- case START_OBJECT => convertObject(parser, st, fieldConverters) - case START_ARRAY => convertArray(parser, elementConverter) - } - - case _ => makeConverter(dataType) + private def makeRootConverter(st: StructType): ValueConverter = { + val elementConverter = makeConverter(st) + val fieldConverters = st.map(_.dataType).map(makeConverter) + (parser: JsonParser) => parseJsonToken(parser, st) { + case START_OBJECT => convertObject(parser, st, fieldConverters) + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + // + // For example, we support, the JSON data as below: + // + // [{"a":"str_a_1"}] + // [{"a":"str_a_2"}, {"b":"str_b_3"}] + // + // resulting in: + // + // List([str_a_1,null]) + // List([str_a_2,null], [null,str_b_3]) + // + case START_ARRAY => convertArray(parser, elementConverter) + } } /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. */ - private def makeConverter(dataType: DataType): ValueConverter = dataType match { + private[sql] def makeConverter(dataType: DataType): ValueConverter = dataType match { case BooleanType => (parser: JsonParser) => parseJsonToken(parser, dataType) { case VALUE_TRUE => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 37f0c8ed19..dfd66aac2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -796,7 +796,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case _: Distinct => true case _: Generate => true case _: Pivot => true - case _: RedistributeData => true + case _: RepartitionByExpression => true case _: Repartition => true case _: ScriptTransformation => true case _: Sort => 
true @@ -932,7 +932,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _: InnerLike | LeftSemi => + case _: InnerLike | LeftSemi => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6958398e03..949ccdcb45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -489,7 +489,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: AppendColumns => true case _: AppendColumnsWithObject => true case _: BroadcastHint => true - case _: RedistributeData => true + case _: RepartitionByExpression => true case _: Repartition => true case _: Sort => true case _: TypedFilter => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7b8badcf8c..3969fdb0ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -177,15 +177,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } - val overwrite = ctx.OVERWRITE != null - val staticPartitionKeys: Map[String, String] = - partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get)) InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty), + ctx.OVERWRITE != null, ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 890865d177..91633f5124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -75,7 +75,8 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) } override lazy val statistics = - Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length) + Statistics(sizeInBytes = + (output.map(n => BigInt(n.dataType.defaultSize))).sum * data.length) def toSQL(inlineTableName: String): String = { require(data.nonEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 79865609cb..465fbab571 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -194,11 +194,12 @@ object ColumnStat extends Logging { val numNonNulls = if (col.nullable) Count(col) else Count(one) val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls)) val numNulls = Subtract(Count(one), numNonNulls) + val defaultSize = Literal(col.dataType.defaultSize, LongType) def 
fixedLenTypeStruct(castType: DataType) = { // For fixed width types, avg size should be the same as max size. - val avgSize = Literal(col.dataType.defaultSize, LongType) - struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize) + struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, defaultSize, + defaultSize) } col.dataType match { @@ -213,7 +214,9 @@ object ColumnStat extends Logging { val nullLit = Literal(null, col.dataType) struct( ndv, nullLit, nullLit, numNulls, - Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType)) + // Set avg/max size to default size if all the values are null or there is no value. + Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)), + Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize))) case _ => throw new AnalysisException("Analyzing column statistics is not supported for column " + s"${col.name} of data type: ${col.dataType}.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7aaefc8529..d583fa31b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.catalog.CatalogTypes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -93,13 +91,13 @@ case class Generate( override def producedAttributes: AttributeSet = 
AttributeSet(generatorOutput) - def output: Seq[Attribute] = { - val qualified = qualifier.map(q => - // prepend the new qualifier to the existed one - generatorOutput.map(a => a.withQualifier(Some(q))) - ).getOrElse(generatorOutput) + def qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q => + // prepend the new qualifier to the existed one + generatorOutput.map(a => a.withQualifier(Some(q))) + }.getOrElse(generatorOutput) - if (join) child.output ++ qualified else qualified + def output: Seq[Attribute] = { + if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput } } @@ -346,22 +344,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } -/** - * Options for writing new data into a table. - * - * @param enabled whether to overwrite existing data in the table. - * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions - * that match this partial partition spec. If empty, all partitions - * will be overwritten. - */ -case class OverwriteOptions( - enabled: Boolean, - staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) { - if (staticPartitionKeys.nonEmpty) { - assert(enabled, "Overwrite must be enabled when specifying specific partitions.") - } -} - /** * Insert some data into a table. 
* @@ -382,14 +364,14 @@ case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], child: LogicalPlan, - overwrite: OverwriteOptions, + overwrite: Boolean, ifNotExists: Boolean) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty - assert(overwrite.enabled || !ifNotExists) + assert(overwrite || !ifNotExists) assert(partition.values.forall(_.nonEmpty) || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && table.resolved @@ -411,7 +393,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) s"CTE $cteAliases" } - override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2) + override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2) } case class WithWindowDefinition( @@ -782,6 +764,28 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) override def output: Seq[Attribute] = child.output } +/** + * This method repartitions data using [[Expression]]s into `numPartitions`, and receives + * information about the number of partitions during execution. Used when a specific ordering or + * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * `coalesce` and `repartition`. + * If `numPartitions` is not specified, the number of partitions will be the number set by + * `spark.sql.shuffle.partitions`. + */ +case class RepartitionByExpression( + partitionExpressions: Seq[Expression], + child: LogicalPlan, + numPartitions: Option[Int] = None) extends UnaryNode { + + numPartitions match { + case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.") + case None => // Ok + } + + override def maxRows: Option[Long] = child.maxRows + override def output: Seq[Attribute] = child.output +} + /** * A relation with one row. This is used in "SELECT ..." without a from clause. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala deleted file mode 100644 index 28cbce8748..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder} - -/** - * Performs a physical redistribution of the data. Used when the consumer of the query - * result have expectations about the distribution and ordering of partitioned input data. - */ -abstract class RedistributeData extends UnaryNode { - override def output: Seq[Attribute] = child.output -} - -case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) - extends RedistributeData - -/** - * This method repartitions data using [[Expression]]s into `numPartitions`, and receives - * information about the number of partitions during execution. 
Used when a specific ordering or - * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like - * `coalesce` and `repartition`. - * If `numPartitions` is not specified, the number of partitions will be the number set by - * `spark.sql.shuffle.partitions`. - */ -case class RepartitionByExpression( - partitionExpressions: Seq[Expression], - child: LogicalPlan, - numPartitions: Option[Int] = None) extends RedistributeData { - numPartitions match { - case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.") - case None => // Ok - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala similarity index 97% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala index 594c41c2c7..915f4a9e25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalOutputModes.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.catalyst.streaming import org.apache.spark.sql.streaming.OutputMode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ea8d8fef7b..8cc16d662b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID import scala.collection.Map -import scala.collection.mutable.Stack import scala.reflect.ClassTag import org.apache.commons.lang3.ClassUtils @@ -28,12 +27,9 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.ScalaReflection._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -493,25 +489,43 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a string representation of the nodes in this tree, where each operator is numbered. - * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. + * The numbers can be used with [[TreeNode.apply]] to easily access specific subtrees. + * + * The numbers are based on depth-first traversal of the tree (with innerChildren traversed first + * before children). 
*/ def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") /** - * Returns the tree node at the specified number. + * Returns the tree node at the specified number, used primarily for interactive debugging. + * Numbers for each node can be found in the [[numberedTreeString]]. + * + * Note that this cannot return BaseType because logical plan's plan node might return + * physical plan for innerChildren, e.g. in-memory relation logical plan node has a reference + * to the physical plan node it is referencing. + */ + def apply(number: Int): TreeNode[_] = getNodeNumbered(new MutableInt(number)).orNull + + /** + * Returns the tree node at the specified number, used primarily for interactive debugging. * Numbers for each node can be found in the [[numberedTreeString]]. + * + * This is a variant of [[apply]] that returns the node as BaseType (if the type matches). */ - def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number)) + def p(number: Int): BaseType = apply(number).asInstanceOf[BaseType] - protected def getNodeNumbered(number: MutableInt): BaseType = { + private def getNodeNumbered(number: MutableInt): Option[TreeNode[_]] = { if (number.i < 0) { - null.asInstanceOf[BaseType] + None } else if (number.i == 0) { - this + Some(this) } else { number.i -= 1 - children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType]) + // Note that this traversal order must be the same as numberedTreeString. + innerChildren.map(_.getNodeNumbered(number)).find(_ != None).getOrElse { + children.map(_.getNodeNumbered(number)).find(_ != None).flatten + } } } @@ -527,6 +541,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at * depth `i + 1` is the last child of its own parent node. 
The depth of the root node is 0, and * `lastChildren` for the root node should be empty. + * + * Note that this traversal (numbering) order must be the same as [[getNodeNumbered]]. */ def generateTreeString( depth: Int, @@ -534,19 +550,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder: StringBuilder, verbose: Boolean, prefix: String = ""): StringBuilder = { + if (depth > 0) { lastChildren.init.foreach { isLast => - val prefixFragment = if (isLast) " " else ": " - builder.append(prefixFragment) + builder.append(if (isLast) " " else ": ") } - - val branch = if (lastChildren.last) "+- " else ":- " - builder.append(branch) + builder.append(if (lastChildren.last) "+- " else ":- ") } builder.append(prefix) - val headline = if (verbose) verboseString else simpleString - builder.append(headline) + builder.append(if (verbose) verboseString else simpleString) builder.append("\n") if (innerChildren.nonEmpty) { @@ -557,9 +570,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } if (children.nonEmpty) { - children.init.foreach( - _.generateTreeString(depth + 1, lastChildren :+ false, builder, verbose, prefix)) - children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, verbose, prefix) + children.init.foreach(_.generateTreeString( + depth + 1, lastChildren :+ false, builder, verbose, prefix)) + children.last.generateTreeString( + depth + 1, lastChildren :+ true, builder, verbose, prefix) } builder diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index d409271fbc..98efba199a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -78,10 +78,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT ("containsNull" -> containsNull) /** - * The default size 
of a value of the ArrayType is 100 * the default size of the element type. - * (We assume that there are 100 elements). + * The default size of a value of the ArrayType is the default size of the element type. + * We assume that there is only 1 element on average in an array. See SPARK-18853. */ - override def defaultSize: Int = 100 * elementType.defaultSize + override def defaultSize: Int = 1 * elementType.defaultSize override def simpleString: String = s"array<${elementType.simpleString}>" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index fbf3a61786..6691b81dce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -56,10 +56,10 @@ case class MapType( /** * The default size of a value of the MapType is - * 100 * (the default size of the key type + the default size of the value type). - * (We assume that there are 100 elements). + * (the default size of the key type + the default size of the value type). + * We assume that there is only 1 element on average in a map. See SPARK-18853. 
*/ - override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) + override def defaultSize: Int = 1 * (keyType.defaultSize + valueType.defaultSize) override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 8c1faea239..96aff37a4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -542,7 +542,7 @@ class AnalysisErrorSuite extends AnalysisTest { Filter(EqualTo(OuterReference(a), b), LocalRelation(b))) ), LocalRelation(a)) - assertAnalysisError(plan4, "Accessing outer query column is not allowed in a LIMIT" :: Nil) + assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) val plan5 = Filter( Exists( @@ -551,6 +551,6 @@ class AnalysisErrorSuite extends AnalysisTest { ), LocalRelation(a)) assertAnalysisError(plan5, - "Accessing outer query column is not allowed in a TABLESAMPLE" :: Nil) + "Accessing outer query column is not allowed in" :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 590c9d5e84..dbb1e3e70f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -57,14 +57,43 @@ class TypeCoercionSuite extends PlanTest { // scalastyle:on line.size.limit private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { - val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, 
from), to) - assert(got.map(_.dataType) == Option(expected), + // Check default value + val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to) + assert(DataType.equalsIgnoreCompatibleNullability( + castDefault.map(_.dataType).getOrElse(null), expected), + s"Failed to cast $from to $to") + + // Check null value + val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to) + assert(DataType.equalsIgnoreCaseAndNullability( + castNull.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") } private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { - val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) - assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got") + // Check default value + val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to) + assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") + + // Check null value + val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to) + assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") + } + + private def default(dataType: DataType): Expression = dataType match { + case ArrayType(internalType: DataType, _) => + CreateArray(Seq(Literal.default(internalType))) + case MapType(keyDataType: DataType, valueDataType: DataType, _) => + CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType))) + case _ => Literal.default(dataType) + } + + private def createNull(dataType: DataType): Expression = dataType match { + case ArrayType(internalType: DataType, _) => + CreateArray(Seq(Literal.create(null, internalType))) + case MapType(keyDataType: DataType, valueDataType: DataType, _) => + CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType))) + case _ => Literal.create(null, dataType) } val integralTypes: Seq[DataType] = @@ -196,7 
+225,13 @@ class TypeCoercionSuite extends PlanTest { test("implicit type cast - ArrayType(StringType)") { val checkedType = ArrayType(StringType) - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + val nonCastableTypes = + complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) + checkTypeCasting(checkedType, + castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_))) + nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _)) + shouldNotCast(ArrayType(DoubleType, containsNull = false), + ArrayType(LongType, containsNull = false)) shouldNotCast(checkedType, DecimalType) shouldNotCast(checkedType, NumericType) shouldNotCast(checkedType, IntegralType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index ff1bb126f4..d2c0f8cc9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -27,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.IntegerType @@ -98,6 +98,19 @@ 
class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Update, expectedMsgs = Seq("multiple streaming aggregations")) + // Aggregation: Distinct aggregates not supported on streaming relation + val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c")) + assertSupportedInStreamingPlan( + "distinct aggregate - aggregate on batch relation", + Aggregate(Nil, distinctAggExprs, batchRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "distinct aggregate - aggregate on streaming relation", + Aggregate(Nil, distinctAggExprs, streamRelation), + outputMode = Complete, + expectedMsgs = Seq("distinct aggregation")) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", @@ -200,7 +213,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Other unary operations - testUnaryOperatorInStreamingPlan("sort partitions", SortPartitions(Nil, _), expectedMsg = "sort") testUnaryOperatorInStreamingPlan( "sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling") testUnaryOperatorInStreamingPlan( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index 3b39f420af..00e663c324 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -346,6 +346,31 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(new Path(partitionLocation) == defaultPartitionLocation) } + test("list partition names") { + val catalog = newBasicCatalog() + val newPart = CatalogTablePartition(Map("a" -> "1", "b" -> "%="), storageFormat) + catalog.createPartitions("db2", "tbl2", Seq(newPart), ignoreIfExists = false) + + val partitionNames = 
catalog.listPartitionNames("db2", "tbl2") + assert(partitionNames == Seq("a=1/b=%25%3D", "a=1/b=2", "a=3/b=4")) + } + + test("list partition names with partial partition spec") { + val catalog = newBasicCatalog() + val newPart = CatalogTablePartition(Map("a" -> "1", "b" -> "%="), storageFormat) + catalog.createPartitions("db2", "tbl2", Seq(newPart), ignoreIfExists = false) + + val partitionNames1 = catalog.listPartitionNames("db2", "tbl2", Some(Map("a" -> "1"))) + assert(partitionNames1 == Seq("a=1/b=%25%3D", "a=1/b=2")) + + // Partial partition specs including "weird" partition values should use the unescaped values + val partitionNames2 = catalog.listPartitionNames("db2", "tbl2", Some(Map("b" -> "%="))) + assert(partitionNames2 == Seq("a=1/b=%25%3D")) + + val partitionNames3 = catalog.listPartitionNames("db2", "tbl2", Some(Map("b" -> "%25%3D"))) + assert(partitionNames3.isEmpty) + } + test("list partitions with partial partition spec") { val catalog = newBasicCatalog() val parts = catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "1"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index f9c4b2687b..5cc772d8e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -878,6 +878,31 @@ class SessionCatalogSuite extends SparkFunSuite { "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) } + test("list partition names") { + val catalog = new SessionCatalog(newBasicCatalog()) + val expectedPartitionNames = Seq("a=1/b=2", "a=3/b=4") + assert(catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2"))) == + expectedPartitionNames) + // List partition names without explicitly specifying database + catalog.setCurrentDatabase("db2") + 
assert(catalog.listPartitionNames(TableIdentifier("tbl2")) == expectedPartitionNames) + } + + test("list partition names with partial partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert( + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))) == + Seq("a=1/b=2")) + } + + test("list partition names with invalid partial partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.listPartitionNames(TableIdentifier("tbl2", Some("db2")), + Some(Map("unknown" -> "unknown"))) + } + } + test("list partitions") { val catalog = new SessionCatalog(newBasicCatalog()) assert(catalogPartitionsEqual( @@ -887,6 +912,20 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalogPartitionsEqual(catalog.listPartitions(TableIdentifier("tbl2")), part1, part2)) } + test("list partitions with partial partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + assert(catalogPartitionsEqual( + catalog.listPartitions(TableIdentifier("tbl2", Some("db2")), Some(Map("a" -> "1"))), part1)) + } + + test("list partitions with invalid partial partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[AnalysisException] { + catalog.listPartitions( + TableIdentifier("tbl2", Some("db2")), Some(Map("unknown" -> "unknown"))) + } + } + test("list partitions when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[NoSuchDatabaseException] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 0cb201e4da..ee5d1f6373 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -97,6 +97,22 
@@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual(0) == cases) } + test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") { + var strExpr: Expression = Literal("abc") + for (_ <- 1 to 150) { + strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8") + } + + val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)) + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(null).toSeq(expressions.map(_.dataType)) + val expected = Seq(UTF8String.fromString("abc")) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") { val length = 5000 val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 8e4327c788..f408ba99d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -180,16 +180,7 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable( - table("s"), partition, plan, - OverwriteOptions( - overwrite, - if (overwrite && partition.nonEmpty) { - partition.map(kv => (kv._1, kv._2.get)) - } else { - Map.empty - }), - ifNotExists) + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -205,9 +196,9 @@ class PlanParserSuite extends PlanTest { val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select 
* limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), OverwriteOptions(false), ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), false, ifNotExists = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, OverwriteOptions(false), ifNotExists = false))) + table("u"), Map.empty, plan2, false, ifNotExists = false))) } test ("insert with if not exists") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index b8ab9a9963..12d2c00dc9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -253,7 +253,7 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeJsonRepr(structType) def checkDefaultSize(dataType: DataType, expectedDefaultSize: Int): Unit = { - test(s"Check the default size of ${dataType}") { + test(s"Check the default size of $dataType") { assert(dataType.defaultSize === expectedDefaultSize) } } @@ -272,18 +272,18 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 20) checkDefaultSize(BinaryType, 100) - checkDefaultSize(ArrayType(DoubleType, true), 800) - checkDefaultSize(ArrayType(StringType, false), 2000) - checkDefaultSize(MapType(IntegerType, StringType, true), 2400) - checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 80400) - checkDefaultSize(structType, 812) + checkDefaultSize(ArrayType(DoubleType, true), 8) + checkDefaultSize(ArrayType(StringType, false), 20) + checkDefaultSize(MapType(IntegerType, StringType, true), 24) + checkDefaultSize(MapType(IntegerType, ArrayType(DoubleType), false), 12) + checkDefaultSize(structType, 20) def checkEqualsIgnoreCompatibleNullability( from: DataType, to: DataType, expected: Boolean): Unit = { val testName = - 
s"equalsIgnoreCompatibleNullability: (from: ${from}, to: ${to})" + s"equalsIgnoreCompatibleNullability: (from: $from, to: $to)" test(testName) { assert(DataType.equalsIgnoreCompatibleNullability(from, to) === expected) } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 7da77158ff..b8aa698090 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -74,6 +74,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.parquet parquet-column diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 086547c793..730a4ae8d5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -69,6 +69,16 @@ protected void append(InternalRow row) { currentRows.add(row); } + /** + * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]]. + * + * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. + * This interface is mainly used to limit the number of input rows. + */ + protected boolean stopEarly() { + return false; + } + /** * Returns whether `processNext()` should stop processing next row from `input` or not. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 184c5a1129..28820681cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -128,6 +128,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * + * @since 2.2.0 + */ + def fill(value: Long): DataFrame = fill(value, df.columns) + + /** + * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * @since 1.3.1 */ def fill(value: Double): DataFrame = fill(value, df.columns) @@ -139,6 +145,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(value: String): DataFrame = fill(value, df.columns) + /** + * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. + * If a specified column is not a numeric column, it is ignored. + * + * @since 2.2.0 + */ + def fill(value: Long, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. * If a specified column is not a numeric column, it is ignored. @@ -147,24 +161,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { */ def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) + /** + * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified + * numeric columns. If a specified column is not a numeric column, it is ignored. + * + * @since 2.2.0 + */ + def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols) + /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified * numeric columns. 
If a specified column is not a numeric column, it is ignored. * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): DataFrame = { - val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val projections = df.schema.fields.map { f => - // Only fill if the column is part of the cols list. - if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => columnEquals(f.name, col))) { - fillCol[Double](f, value) - } else { - df.col(f.name) - } - } - df.select(projections : _*) - } + def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols) + /** * Returns a new `DataFrame` that replaces null values in specified string columns. @@ -180,18 +192,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): DataFrame = { - val columnEquals = df.sparkSession.sessionState.analyzer.resolver - val projections = df.schema.fields.map { f => - // Only fill if the column is part of the cols list. - if (f.dataType.isInstanceOf[StringType] && cols.exists(col => columnEquals(f.name, col))) { - fillCol[String](f, value) - } else { - df.col(f.name) - } - } - df.select(projections : _*) - } + def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols) /** * Returns a new `DataFrame` that replaces null values. @@ -210,7 +211,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq) + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fillMap(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values. 
@@ -230,7 +231,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq) /** * Replaces values matching keys in `replacement` map with the corresponding values. @@ -368,7 +369,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { df.select(projections : _*) } - private def fill0(values: Seq[(String, Any)]): DataFrame = { + private def fillMap(values: Seq[(String, Any)]): DataFrame = { // Error handling values.foreach { case (colName, replaceValue) => // Check column name exists @@ -435,4 +436,38 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { case v => throw new IllegalArgumentException( s"Unsupported value type ${v.getClass.getName} ($v).") } + + /** + * Returns a new `DataFrame` that replaces null or NaN values in the specified + * numeric or string columns. A specified column whose type does not match the type + * of `value` (numeric or string) is ignored. + */ + private def fillValue[T](value: T, cols: Seq[String]): DataFrame = { + // When T is Long or Double, the fill applies to every NumericType column, + // whatever its concrete numeric type. For example: + // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b") + // input.na.fill(3.1) + // the result is (3,164.3), not (null, 164.3) + val targetType = value match { + case _: Double | _: Long => NumericType + case _: String => StringType + case _ => throw new IllegalArgumentException( + s"Unsupported value type ${value.getClass.getName} ($value).") + } + + val columnEquals = df.sparkSession.sessionState.analyzer.resolver + val projections = df.schema.fields.map { f => + val typeMatches = (targetType, f.dataType) match { + case (NumericType, dt) => dt.isInstanceOf[NumericType] + case (StringType, dt) => dt == StringType + } + // Only fill if the column is part of the cols list. 
+ if (typeMatches && cols.exists(col => columnEquals(f.name, col))) { + fillCol[T](f, value) + } else { + df.col(f.name) + } + } + df.select(projections : _*) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index fa8e8cb985..9c5660a378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,9 +25,9 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions} -import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils} -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, HadoopFsRelation} +import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.types.StructType /** @@ -150,7 +150,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * predicates on the partitioned columns. In order for partitioning to work well, the number * of distinct values in each column should typically be less than tens of thousands. * - * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) since Spark 2.1.0. * * @since 1.4.0 */ @@ -164,7 +164,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * Buckets the output by the given columns. If specified, the output is laid out on the file * system similar to Hive's bucketing scheme. 
* - * This is applicable for Parquet, JSON and ORC. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) since Spark 2.1.0. * * @since 2.0 */ @@ -178,7 +178,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { /** * Sorts the output in each bucket by the given columns. * - * This is applicable for Parquet, JSON and ORC. + * This is applicable for all file-based data sources (e.g. Parquet, JSON) since Spark 2.1.0. * * @since 2.0 */ @@ -214,6 +214,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { dataSource.write(mode, df) } + /** * Inserts the content of the `DataFrame` to the specified table. It requires that * the schema of the `DataFrame` is the same as the schema of the table. @@ -259,7 +260,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], child = df.logicalPlan, - overwrite = OverwriteOptions(mode == SaveMode.Overwrite), + overwrite = mode == SaveMode.Overwrite, ifNotExists = false)).toRdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 133f633212..29397b1340 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1871,21 +1871,7 @@ class Dataset[T] private[sql]( * Returns a new Dataset by adding a column with metadata. 
*/ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { - val resolver = sparkSession.sessionState.analyzer.resolver - val output = queryExecution.analyzed.output - val shouldReplace = output.exists(f => resolver(f.name, colName)) - if (shouldReplace) { - val columns = output.map { field => - if (resolver(field.name, colName)) { - col.as(colName, metadata) - } else { - Column(field) - } - } - select(columns : _*) - } else { - select(Column("*"), col.as(colName, metadata)) - } + withColumn(colName, col.as(colName, metadata)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6554359806..1a7fd689a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -747,7 +747,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ def tableNames(): Array[String] = { - sparkSession.catalog.listTables().collect().map(_.name) + tableNames(sparkSession.catalog.currentDatabase) } /** @@ -757,7 +757,7 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @since 1.3.0 */ def tableNames(databaseName: String): Array[String] = { - sparkSession.catalog.listTables(databaseName).collect().map(_.name) + sessionState.catalog.listTables(databaseName).map(_.table).toArray } //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 73d16d8a10..872a78b578 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -74,6 +74,19 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = Encoders.STRING + /** @since 2.2.0 */ + implicit 
def newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = Encoders.DECIMAL + + /** @since 2.2.0 */ + implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = ExpressionEncoder() + + /** @since 2.2.0 */ + implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE + + /** @since 2.2.0 */ + implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP + + // Boxed primitives /** @since 2.0.0 */ @@ -141,7 +154,7 @@ abstract class SQLImplicits { implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() /** @since 1.6.1 */ - implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder() + implicit def newByteArrayEncoder: Encoder[Array[Byte]] = Encoders.BINARY /** @since 1.6.1 */ implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 08d74ac018..f3dde480ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -618,6 +618,22 @@ class SparkSession private( @InterfaceStability.Evolving def readStream: DataStreamReader = new DataStreamReader(self) + /** + * Executes some code block and prints to stdout the time taken to execute the block. This is + * available in Scala only and is used primarily for interactive testing and debugging. 
+ * + * @since 2.1.0 + */ + @InterfaceStability.Stable + def time[T](f: => T): T = { + val start = System.nanoTime() + val ret = f + val end = System.nanoTime() + // scalastyle:off println + println(s"Time taken: ${(end - start) / 1000 / 1000} ms") + // scalastyle:on println + ret + } // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index c8be89c646..d94185b390 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -23,8 +23,6 @@ import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import com.google.common.reflect.TypeToken - import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ @@ -446,7 +444,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val udfReturnType = udfInterfaces(0).getActualTypeArguments.last var returnType = returnDataType if (returnType == null) { - returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1 + returnType = JavaTypeInference.inferDataType(udfReturnType)._1 } udfInterfaces(0).getActualTypeArguments.length match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 9de6510c63..e56c33e4b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types._ private[sql] object SQLUtils extends Logging { - SerDe.registerSqlSerDe((readSqlObject, 
writeSqlObject)) + SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject) private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = { sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") @@ -158,7 +158,7 @@ private[sql] object SQLUtils extends Logging { val dis = new DataInputStream(bis) val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => - doConversion(SerDe.readObject(dis), schema.fields(i).dataType) + doConversion(SerDe.readObject(dis, jvmObjectTracker = null), schema.fields(i).dataType) }) } @@ -167,7 +167,7 @@ private[sql] object SQLUtils extends Logging { val dos = new DataOutputStream(bos) val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray - SerDe.writeObject(dos, cols) + SerDe.writeObject(dos, cols, jvmObjectTracker = null) bos.toByteArray() } @@ -247,7 +247,7 @@ private[sql] object SQLUtils extends Logging { dataType match { case 's' => // Read StructType for DataFrame - val fields = SerDe.readList(dis).asInstanceOf[Array[Object]] + val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]] Row.fromSeq(fields) case _ => null } @@ -258,8 +258,8 @@ private[sql] object SQLUtils extends Logging { // Handle struct type in DataFrame case v: GenericRowWithSchema => dos.writeByte('s') - SerDe.writeObject(dos, v.schema.fieldNames) - SerDe.writeObject(dos, v.values) + SerDe.writeObject(dos, v.schema.fieldNames, jvmObjectTracker = null) + SerDe.writeObject(dos, v.values, jvmObjectTracker = null) true case _ => false @@ -276,11 +276,12 @@ private[sql] object SQLUtils extends Logging { } def getTableNames(sparkSession: SparkSession, databaseName: String): Array[String] = { - databaseName match { - case n: String if n != null && n.trim.nonEmpty => - sparkSession.catalog.listTables(n).collect().map(_.name) + val db = databaseName match { + case _ if databaseName != null && databaseName.trim.nonEmpty => + databaseName case _ => - 
sparkSession.catalog.listTables().collect().map(_.name) + sparkSession.catalog.currentDatabase } + sparkSession.sessionState.catalog.listTables(db).map(_.table).toArray } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index aecdda1c36..6b061f8ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -300,6 +300,13 @@ abstract class Catalog { */ def dropGlobalTempView(viewName: String): Boolean + /** + * Recover all the partitions in the directory of a table and update the catalog. + * + * @since 2.1.1 + */ + def recoverPartitions(tableName: String): Unit + /** * Returns true if the table is currently cached in-memory. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index e485b52b43..7616164397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -136,7 +136,7 @@ case class RowDataSourceScanExec( * @param outputSchema Output schema of the scan. * @param partitionFilters Predicates to use for partition pruning. * @param dataFilters Data source filters to use for filtering data within partitions. - * @param metastoreTableIdentifier + * @param metastoreTableIdentifier identifier for the table in the metastore. 
*/ case class FileSourceScanExec( @transient relation: HadoopFsRelation, @@ -147,10 +147,10 @@ case class FileSourceScanExec( override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec { - val supportsBatch = relation.fileFormat.supportBatch( + val supportsBatch: Boolean = relation.fileFormat.supportBatch( relation.sparkSession, StructType.fromAttributes(output)) - val needsUnsafeRowConversion = if (relation.fileFormat.isInstanceOf[ParquetSource]) { + val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) { SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled } else { false @@ -516,7 +516,6 @@ case class FileSourceScanExec( } // Assign files to partitions using "First Fit Decreasing" (FFD) - // TODO: consider adding a slop factor here? splitFiles.foreach { file => if (currentSize + file.length > maxSplitBytes) { closePartition() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f80214af43..04b16af4ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -51,17 +51,26 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param output the output attributes of this node, which constructed in analysis phase, - * and we can not change it, as the parent node bound with it already. + * @param generatorOutput the qualified output attributes of the generator of this node, which + * constructed in analysis phase, and we can not change it, as the + * parent node bound with it already. 
*/ case class GenerateExec( generator: Generator, join: Boolean, outer: Boolean, - output: Seq[Attribute], + generatorOutput: Seq[Attribute], child: SparkPlan) extends UnaryExecNode with CodegenSupport { + override def output: Seq[Attribute] = { + if (join) { + child.output ++ generatorOutput + } else { + generatorOutput + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 5f0c264416..862ee05392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -33,10 +33,6 @@ private final class ShuffledRowRDDPartition( val startPreShufflePartitionIndex: Int, val endPreShufflePartitionIndex: Int) extends Partition { override val index: Int = postShufflePartitionIndex - - override def hashCode(): Int = postShufflePartitionIndex - - override def equals(other: Any): Boolean = super.equals(other) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 4400174e92..14a983e43b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** * Concrete parser for Spark SQL statements. 
@@ -126,23 +126,33 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a [[ShowTablesCommand]] logical plan. * Example SQL : * {{{ - * SHOW TABLES [EXTENDED] [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards'] - * [PARTITION(partition_spec)]; + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; * }}} */ override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + ShowTablesCommand( + Option(ctx.db).map(_.getText), + Option(ctx.pattern).map(string), + isExtended = false) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + * Example SQL : + * {{{ + * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards' + * [PARTITION(partition_spec)]; + * }}} + */ + override def visitShowTable(ctx: ShowTableContext): LogicalPlan = withOrigin(ctx) { if (ctx.partitionSpec != null) { - operationNotAllowed("SHOW TABLES [EXTENDED] ... PARTITION", ctx) - } - if (ctx.EXTENDED != null && ctx.pattern == null) { - throw new AnalysisException( - s"SHOW TABLES EXTENDED must have identifier_with_wildcards specified.") + operationNotAllowed("SHOW TABLE EXTENDED ... PARTITION", ctx) } ShowTablesCommand( Option(ctx.db).map(_.getText), Option(ctx.pattern).map(string), - ctx.EXTENDED != null) + isExtended = true) } /** @@ -876,6 +886,33 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { visitLocationSpec(ctx.locationSpec)) } + /** + * Create a [[AlterTableChangeColumnCommand]] command. 
+ * + * For example: + * {{{ + * ALTER TABLE table [PARTITION partition_spec] + * CHANGE [COLUMN] column_old_name column_new_name column_dataType [COMMENT column_comment] + * [FIRST | AFTER column_name]; + * }}} + */ + override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) { + if (ctx.partitionSpec != null) { + operationNotAllowed("ALTER TABLE table PARTITION partition_spec CHANGE COLUMN", ctx) + } + + if (ctx.colPosition != null) { + operationNotAllowed( + "ALTER TABLE table [PARTITION partition_spec] CHANGE COLUMN ... FIRST | AFTER otherCol", + ctx) + } + + AlterTableChangeColumnCommand( + tableName = visitTableIdentifier(ctx.tableIdentifier), + columnName = ctx.identifier.getText, + newColumn = visitColType(ctx.colType)) + } + /** * Create location string. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2308ae8a6c..ba82ec156e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -115,7 +115,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ private def canBroadcast(plan: LogicalPlan): Boolean = { plan.statistics.isBroadcastable || - plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold + (plan.statistics.sizeInBytes >= 0 && + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) } /** @@ -375,10 +376,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } - case logical.SortPartitions(sortExprs, child) => - // This sort only sorts tuples within a partition. Its requiredDistribution will be - // an UnspecifiedDistribution. 
- execution.SortExec(sortExprs, global = false, child = planLater(child)) :: Nil case logical.Sort(sortExprs, global, child) => execution.SortExec(sortExprs, global, planLater(child)) :: Nil case logical.Project(projectList, child) => @@ -403,7 +400,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.UnionExec(unionChildren.map(planLater)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.GenerateExec( - generator, join = join, outer = outer, g.output, planLater(child)) :: Nil + generator, join = join, outer = outer, g.qualifiedGeneratorOutput, + planLater(child)) :: Nil case logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil case r: logical.Range => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 516b9d5444..2ead8f6baa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -241,7 +241,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val row = ctx.freshName("row") s""" - | while ($input.hasNext()) { + | while ($input.hasNext() && !stopEarly()) { | InternalRow $row = (InternalRow) $input.next(); | ${consume(ctx, null, row).trim} | if (shouldStop()) return; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index e6f1de5cb0..fb90799534 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -578,7 +578,7 @@ case class 
SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { } override def executeCollect(): Array[InternalRow] = { - ThreadUtils.awaitResultInForkJoinSafely(relationFuture, Duration.Inf) + ThreadUtils.awaitResult(relationFuture, Duration.Inf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 56bd5c1891..03cc04659b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.execution.SparkPlan @@ -64,7 +63,7 @@ case class InMemoryRelation( val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override protected def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 422700c891..81c20475a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -18,13 +18,11 @@ package org.apache.spark.sql.execution.command import 
org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.sources.BaseRelation /** * A command used to create a data source table. @@ -58,13 +56,21 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo // Create the relation to validate the arguments before writing the metadata to the metastore, // and infer the table schema and partition if users didn't specify schema in CREATE TABLE. val pathOption = table.storage.locationUri.map("path" -> _) + // Fill in some default table options from the session conf + val tableWithDefaultOptions = table.copy( + identifier = table.identifier.copy( + database = Some( + table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase))), + tracksPartitionsInCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions) val dataSource: BaseRelation = DataSource( sparkSession = sparkSession, userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), + partitionColumns = table.partitionColumnNames, className = table.provider.get, bucketSpec = table.bucketSpec, - options = table.storage.properties ++ pathOption).resolveRelation() + options = table.storage.properties ++ pathOption, + catalogTable = Some(tableWithDefaultOptions)).resolveRelation() dataSource match { case fs: HadoopFsRelation => @@ -135,8 +141,9 @@ case class CreateDataSourceTableAsSelectCommand( val tableName = tableIdentWithDB.unquotedString var createMetastoreTable = false - var existingSchema = Option.empty[StructType] - if 
(sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) { + // We may need to reorder the columns of the query to match the existing table. + var reorderedColumns = Option.empty[Seq[NamedExpression]] + if (sessionState.catalog.tableExists(tableIdentWithDB)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -149,39 +156,76 @@ case class CreateDataSourceTableAsSelectCommand( // Since the table already exists and the save mode is Ignore, we will just return. return Seq.empty[Row] case SaveMode.Append => + val existingTable = sessionState.catalog.getTableMetadata(tableIdentWithDB) + + if (existingTable.provider.get == DDLUtils.HIVE_PROVIDER) { + throw new AnalysisException(s"Saving data in the Hive serde table $tableName is " + + "not supported yet. Please use the insertInto() API as an alternative.") + } + // Check if the specified data source match the data source of the existing table. - val existingProvider = DataSource.lookupDataSource(provider) + val existingProvider = DataSource.lookupDataSource(existingTable.provider.get) + val specifiedProvider = DataSource.lookupDataSource(table.provider.get) // TODO: Check that options from the resolved relation match the relation that we are // inserting into (i.e. using the same compression). + if (existingProvider != specifiedProvider) { + throw new AnalysisException(s"The format of the existing table $tableName is " + + s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " + + s"`${specifiedProvider.getSimpleName}`.") + } - // Pass a table identifier with database part, so that `lookupRelation` won't get temp - // views unexpectedly. 
- EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match { - case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => - // check if the file formats match - l.relation match { - case r: HadoopFsRelation if r.fileFormat.getClass != existingProvider => - throw new AnalysisException( - s"The file format of the existing table $tableName is " + - s"`${r.fileFormat.getClass.getName}`. It doesn't match the specified " + - s"format `$provider`") - case _ => - } - if (query.schema.size != l.schema.size) { - throw new AnalysisException( - s"The column number of the existing schema[${l.schema}] " + - s"doesn't match the data schema[${query.schema}]'s") - } - existingSchema = Some(l.schema) - case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => - existingSchema = Some(s.metadata.schema) - case c: CatalogRelation if c.catalogTable.provider == Some(DDLUtils.HIVE_PROVIDER) => - throw new AnalysisException("Saving data in the Hive serde table " + - s"${c.catalogTable.identifier} is not supported yet. 
Please use the " + - "insertInto() API as an alternative..") - case o => - throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") + if (query.schema.length != existingTable.schema.length) { + throw new AnalysisException( + s"The column number of the existing table $tableName" + + s"(${existingTable.schema.catalogString}) doesn't match the data schema" + + s"(${query.schema.catalogString})") } + + val resolver = sessionState.conf.resolver + val tableCols = existingTable.schema.map(_.name) + + reorderedColumns = Some(existingTable.schema.map { f => + query.resolve(Seq(f.name), resolver).getOrElse { + val inputColumns = query.schema.map(_.name).mkString(", ") + throw new AnalysisException( + s"cannot resolve '${f.name}' given input columns: [$inputColumns]") + } + }) + + // In `AnalyzeCreateTable`, we verified the consistency between the user-specified table + // definition(partition columns, bucketing) and the SELECT query, here we also need to + // verify the consistency between the user-specified table definition and the existing + // table definition. + + // Check if the specified partition columns match the existing table. + val specifiedPartCols = CatalogUtils.normalizePartCols( + tableName, tableCols, table.partitionColumnNames, resolver) + if (specifiedPartCols != existingTable.partitionColumnNames) { + throw new AnalysisException( + s""" + |Specified partitioning does not match that of the existing table $tableName. + |Specified partition columns: [${specifiedPartCols.mkString(", ")}] + |Existing partition columns: [${existingTable.partitionColumnNames.mkString(", ")}] + """.stripMargin) + } + + // Check if the specified bucketing matches the existing table. 
+ val specifiedBucketSpec = table.bucketSpec.map { bucketSpec => + CatalogUtils.normalizeBucketSpec(tableName, tableCols, bucketSpec, resolver) + } + if (specifiedBucketSpec != existingTable.bucketSpec) { + val specifiedBucketString = + specifiedBucketSpec.map(_.toString).getOrElse("not bucketed") + val existingBucketString = + existingTable.bucketSpec.map(_.toString).getOrElse("not bucketed") + throw new AnalysisException( + s""" + |Specified bucketing does not match that of the existing table $tableName. + |Specified bucketing: $specifiedBucketString + |Existing bucketing: $existingBucketString + """.stripMargin) + } + case SaveMode.Overwrite => sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false) // Need to create the table again. @@ -193,9 +237,9 @@ case class CreateDataSourceTableAsSelectCommand( } val data = Dataset.ofRows(sparkSession, query) - val df = existingSchema match { - // If we are inserting into an existing table, just use the existing schema. - case Some(s) => data.selectExpr(s.fieldNames: _*) + val df = reorderedColumns match { + // Reorder the columns of the query to match the existing table. + case Some(cols) => data.select(cols.map(Column(_)): _*) case None => data } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index c62c14200c..522158b641 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -274,6 +274,77 @@ case class AlterTableUnsetPropertiesCommand( } + +/** + * A command to change the column for a table, only support changing the comment of a non-partition + * column for now. 
+ * + * The syntax of using this command in SQL is: + * {{{ + * ALTER TABLE table_identifier + * CHANGE [COLUMN] column_old_name column_new_name column_dataType [COMMENT column_comment] + * [FIRST | AFTER column_name]; + * }}} + */ +case class AlterTableChangeColumnCommand( + tableName: TableIdentifier, + columnName: String, + newColumn: StructField) extends RunnableCommand { + + // TODO: support change column name/dataType/metadata/position. + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + val resolver = sparkSession.sessionState.conf.resolver + DDLUtils.verifyAlterTableType(catalog, table, isView = false) + + // Find the origin column from schema by column name. + val originColumn = findColumnByName(table.schema, columnName, resolver) + // Throw an AnalysisException if the column name/dataType is changed. + if (!columnEqual(originColumn, newColumn, resolver)) { + throw new AnalysisException( + "ALTER TABLE CHANGE COLUMN is not supported for changing column " + + s"'${originColumn.name}' with type '${originColumn.dataType}' to " + + s"'${newColumn.name}' with type '${newColumn.dataType}'") + } + + val newSchema = table.schema.fields.map { field => + if (field.name == originColumn.name) { + // Create a new column from the origin column with the new comment. + addComment(field, newColumn.getComment) + } else { + field + } + } + val newTable = table.copy(schema = StructType(newSchema)) + catalog.alterTable(newTable) + + Seq.empty[Row] + } + + // Find the origin column from schema by column name, throw an AnalysisException if the column + // reference is invalid. 
+ private def findColumnByName( + schema: StructType, name: String, resolver: Resolver): StructField = { + schema.fields.collectFirst { + case field if resolver(field.name, name) => field + }.getOrElse(throw new AnalysisException( + s"Invalid column reference '$name', table schema is '${schema}'")) + } + + // Add the comment to a column, if comment is empty, return the original column. + private def addComment(column: StructField, comment: Option[String]): StructField = { + comment.map(column.withComment(_)).getOrElse(column) + } + + // Compare a [[StructField]] to another, return true if they have the same column + // name(by resolver) and dataType. + private def columnEqual( + field: StructField, other: StructField, resolver: Resolver): Boolean = { + resolver(field.name, other.name) && field.dataType == other.dataType + } +} + /** * A command that sets the serde class and/or serde properties of a table/view. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index dc0720d78d..012b6ea4c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -203,7 +203,7 @@ case class LoadDataCommand( throw new AnalysisException(s"LOAD DATA target table $tableIdentwithDB is partitioned, " + s"but number of columns in provided partition spec (${partition.get.size}) " + s"do not match number of partitioned columns in table " + - s"(s${targetTable.partitionColumnNames.size})") + s"(${targetTable.partitionColumnNames.size})") } partition.get.keys.foreach { colName => if (!targetTable.partitionColumnNames.contains(colName)) { @@ -297,13 +297,15 @@ case class LoadDataCommand( partition.get, isOverwrite, holdDDLTime = false, - inheritTableSpecs = true) + inheritTableSpecs = true, + isSrcLocal = isLocal) } else { catalog.loadTable( targetTable.identifier, 
loadPath.toString, isOverwrite, - holdDDLTime = false) + holdDDLTime = false, + isSrcLocal = isLocal) } Seq.empty[Row] } @@ -590,7 +592,8 @@ case class DescribeTableCommand( * If a databaseName is not given, the current database will be used. * The syntax of using this command in SQL is: * {{{ - * SHOW TABLES [EXTENDED] [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * SHOW TABLE EXTENDED [(IN|FROM) database_name] LIKE 'identifier_with_wildcards'; * }}} */ case class ShowTablesCommand( @@ -598,8 +601,8 @@ case class ShowTablesCommand( tableIdentifierPattern: Option[String], isExtended: Boolean = false) extends RunnableCommand { - // The result of SHOW TABLES has three basic columns: database, tableName and isTemporary. - // If `isExtended` is true, append column `information` to the output columns. + // The result of SHOW TABLES/SHOW TABLE has three basic columns: database, tableName and + // isTemporary. If `isExtended` is true, append column `information` to the output columns. 
override val output: Seq[Attribute] = { val tableExtendedInfo = if (isExtended) { AttributeReference("information", StringType, nullable = false)() :: Nil @@ -729,13 +732,6 @@ case class ShowPartitionsCommand( AttributeReference("partition", StringType, nullable = false)() :: Nil } - private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = { - partColNames.map { name => - ExternalCatalogUtils.escapePathName(name) + "=" + - ExternalCatalogUtils.escapePathName(spec(name)) - }.mkString(File.separator) - } - override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) @@ -772,10 +768,7 @@ case class ShowPartitionsCommand( } } - val partNames = catalog.listPartitions(tableName, spec).map { p => - getPartName(p.spec, table.partitionColumnNames) - } - + val partNames = catalog.listPartitionNames(tableName, spec) partNames.map(Row(_)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4ad91dcceb..1235a4b12f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession @@ -37,14 +38,15 @@ class CatalogFileIndex( val table: CatalogTable, override val sizeInBytes: Long) extends FileIndex { - protected val hadoopConf = sparkSession.sessionState.newHadoopConf + protected val hadoopConf: Configuration = sparkSession.sessionState.newHadoopConf() - private val fileStatusCache = FileStatusCache.newCache(sparkSession) + /** Globally shared (not exclusive to this table) cache for file 
statuses to speed up listing. */ + private val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) assert(table.identifier.database.isDefined, "The table identifier must be qualified in CatalogFileIndex") - private val baseLocation = table.storage.locationUri + private val baseLocation: Option[String] = table.storage.locationUri override def partitionSchema: StructType = table.partitionSchema @@ -76,7 +78,8 @@ class CatalogFileIndex( new PrunedInMemoryFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec) } else { - new InMemoryFileIndex(sparkSession, rootPaths, table.storage.properties, None) + new InMemoryFileIndex( + sparkSession, rootPaths, table.storage.properties, partitionSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index ccfc759c8f..ac3f0688bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -84,7 +84,7 @@ case class DataSource( case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) lazy val providingClass: Class[_] = DataSource.lookupDataSource(className) - lazy val sourceInfo = sourceSchema() + lazy val sourceInfo: SourceInfo = sourceSchema() private val caseInsensitiveOptions = new CaseInsensitiveMap(options) /** @@ -132,7 +132,7 @@ case class DataSource( }.toArray new InMemoryFileIndex(sparkSession, globbedPaths, options, None) } - val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) { + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource val resolved = tempFileIndex.partitionSchema.map { partitionField => @@ -278,7 +278,7 
@@ case class DataSource( throw new IllegalArgumentException("'path' is not specified") }) if (outputMode != OutputMode.Append) { - throw new IllegalArgumentException( + throw new AnalysisException( s"Data source $className does not support $outputMode output mode") } new FileStreamSink(sparkSession, path, fileFormat, partitionColumns, caseInsensitiveOptions) @@ -388,10 +388,11 @@ case class DataSource( val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes new CatalogFileIndex( sparkSession, catalogTable.get, - catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) + catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) } else { new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(partitionSchema)) } @@ -478,7 +479,7 @@ case class DataSource( val plan = InsertIntoHadoopFsRelationCommand( outputPath = outputPath, - staticPartitionKeys = Map.empty, + staticPartitions = Map.empty, customPartitionLocations = Map.empty, partitionColumns = columns, bucketSpec = bucketSpec, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 4468dc58e4..61f0d43f24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} import 
org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation} @@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} @@ -100,7 +99,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), "_staticPart")()) + Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -128,61 +127,75 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { projectList } + /** + * Returns true if the [[InsertIntoTable]] plan has already been preprocessed by analyzer rule + * [[PreprocessTableInsertion]]. It is important that this rule([[DataSourceAnalysis]]) has to + * be run after [[PreprocessTableInsertion]], to normalize the column names in partition spec and + * fix the schema mismatch by adding Cast. 
+ */ + private def hasBeenPreprocessed( + tableOutput: Seq[Attribute], + partSchema: StructType, + partSpec: Map[String, Option[String]], + query: LogicalPlan): Boolean = { + val partColNames = partSchema.map(_.name).toSet + query.resolved && partSpec.keys.forall(partColNames.contains) && { + val staticPartCols = partSpec.filter(_._2.isDefined).keySet + val expectedColumns = tableOutput.filterNot(a => staticPartCols.contains(a.name)) + expectedColumns.toStructType.sameType(query.schema) + } + } + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // If the InsertIntoTable command is for a partitioned HadoopFsRelation and - // the user has specified static partitions, we add a Project operator on top of the query - // to include those constant column values in the query result. - // - // Example: - // Let's say that we have a table "t", which is created by - // CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c) - // The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3" - // will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3". - // - // Basically, we will put those partition columns having a assigned value back - // to the SELECT clause. The output of the SELECT clause is organized as - // normal_columns static_partitioning_columns dynamic_partitioning_columns. - // static_partitioning_columns are partitioning columns having assigned - // values in the PARTITION clause (e.g. b in the above example). - // dynamic_partitioning_columns are partitioning columns that do not assigned - // values in the PARTITION clause (e.g. c in the above example). 
- case insert @ logical.InsertIntoTable( - relation @ LogicalRelation(t: HadoopFsRelation, _, _), parts, query, overwrite, false) - if query.resolved && parts.exists(_._2.isDefined) => - - val projectList = convertStaticPartitions( - sourceAttributes = query.output, - providedPartitions = parts, - targetAttributes = relation.output, - targetPartitionSchema = t.partitionSchema) - - // We will remove all assigned values to static partitions because they have been - // moved to the projectList. - insert.copy(partition = parts.map(p => (p._1, None)), child = Project(projectList, query)) - - - case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, table), part, query, overwrite, false) - if query.resolved && t.schema.sameType(query.schema) => - - // Sanity checks + case InsertIntoTable( + l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false) + if hasBeenPreprocessed(l.output, t.partitionSchema, parts, query) => + + // If the InsertIntoTable command is for a partitioned HadoopFsRelation and + // the user has specified static partitions, we add a Project operator on top of the query + // to include those constant column values in the query result. + // + // Example: + // Let's say that we have a table "t", which is created by + // CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c) + // The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3" + // will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3". + // + // Basically, we will put those partition columns having a assigned value back + // to the SELECT clause. The output of the SELECT clause is organized as + // normal_columns static_partitioning_columns dynamic_partitioning_columns. + // static_partitioning_columns are partitioning columns having assigned + // values in the PARTITION clause (e.g. b in the above example). 
+ // dynamic_partitioning_columns are partitioning columns that do not assigned + // values in the PARTITION clause (e.g. c in the above example). + val actualQuery = if (parts.exists(_._2.isDefined)) { + val projectList = convertStaticPartitions( + sourceAttributes = query.output, + providedPartitions = parts, + targetAttributes = l.output, + targetPartitionSchema = t.partitionSchema) + Project(projectList, query) + } else { + query + } + + // Sanity check if (t.location.rootPaths.size != 1) { - throw new AnalysisException( - "Can only write data to relations with a single path.") + throw new AnalysisException("Can only write data to relations with a single path.") } val outputPath = t.location.rootPaths.head - val inputPaths = query.collect { + val inputPaths = actualQuery.collect { case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths }.flatten - val mode = if (overwrite.enabled) SaveMode.Overwrite else SaveMode.Append - if (overwrite.enabled && inputPaths.contains(outputPath)) { + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + if (overwrite && inputPaths.contains(outputPath)) { throw new AnalysisException( "Cannot overwrite a path that is also being read from.") } - val partitionSchema = query.resolve( + val partitionSchema = actualQuery.resolve( t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) val partitionsTrackedByCatalog = t.sparkSession.sessionState.conf.manageFilesourcePartitions && @@ -192,11 +205,13 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get } + // When partitions are tracked by the catalog, compute all custom partition locations that // may be relevant to the insertion job. 
if (partitionsTrackedByCatalog) { val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions( - l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys)) + l.catalogTable.get.identifier, Some(staticPartitions)) initialMatchingPartitions = matchingPartitions.map(_.spec) customPartitionLocations = getCustomPartitionLocations( t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions) @@ -212,7 +227,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(t.sparkSession) } - if (overwrite.enabled) { + if (overwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( @@ -225,24 +240,16 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { t.location.refresh() } - val staticPartitionKeys: TablePartitionSpec = if (overwrite.enabled) { - overwrite.staticPartitionKeys.map { case (k, v) => - (partitionSchema.map(_.name).find(_.equalsIgnoreCase(k)).get, v) - } - } else { - Map.empty - } - val insertCmd = InsertIntoHadoopFsRelationCommand( outputPath, - staticPartitionKeys, + staticPartitions, customPartitionLocations, partitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, t.options, - query, + actualQuery, mode, table) @@ -305,7 +312,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) + case i @ InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) if DDLUtils.isDatasourceTable(s.metadata) => i.copy(table = readDataSourceTable(sparkSession, s)) @@ -351,7 +358,7 @@ object DataSourceStrategy extends Strategy with Logging { Map.empty, None) :: Nil - case i @ logical.InsertIntoTable(l @ LogicalRelation(t: 
InsertableRelation, _, _), + case InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), part, query, overwrite, false) if part.isEmpty => ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 4f4aaaa502..6784ee243c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -148,7 +148,8 @@ trait FileFormat { * The base class file format that is based on text file. */ abstract class TextBasedFileFormat extends FileFormat { - private var codecFactory: CompressionCodecFactory = null + private var codecFactory: CompressionCodecFactory = _ + override def isSplitable( sparkSession: SparkSession, options: Map[String, String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index d560ad5709..1eb4541e2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -31,13 +31,12 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -47,6 +46,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { + /** + * Max number of files a single task writes out due to file size. In most cases the number of + * files written should be very small. This is just a safeguard against some really bad + * settings, e.g. maxRecordsPerFile = 1. + */ + private val MAX_FILE_COUNTER = 1000 * 1000 + /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String]) @@ -61,7 +67,8 @@ object FileFormatWriter extends Logging { val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], val path: String, - val customPartitionLocations: Map[TablePartitionSpec, String]) + val customPartitionLocations: Map[TablePartitionSpec, String], + val maxRecordsPerFile: Long) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -116,7 +123,10 @@ object FileFormatWriter extends Logging { nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, path = outputSpec.outputPath, - customPartitionLocations = outputSpec.customPartitionLocations) + customPartitionLocations = outputSpec.customPartitionLocations, + maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) + ) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { //
This call shouldn't be put into the `try` block below because it only initializes and @@ -225,32 +235,49 @@ object FileFormatWriter extends Logging { taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol) extends ExecuteWriteTask { - private[this] var outputWriter: OutputWriter = { + private[this] var currentWriter: OutputWriter = _ + + private def newOutputWriter(fileCounter: Int): Unit = { + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) val tmpFilePath = committer.newTaskTempFile( taskAttemptContext, None, - description.outputWriterFactory.getFileExtension(taskAttemptContext)) + f"-c$fileCounter%03d" + ext) - val outputWriter = description.outputWriterFactory.newInstance( + currentWriter = description.outputWriterFactory.newInstance( path = tmpFilePath, dataSchema = description.nonPartitionColumns.toStructType, context = taskAttemptContext) - outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) - outputWriter + currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType) } override def execute(iter: Iterator[InternalRow]): Set[String] = { + var fileCounter = 0 + var recordsInFile: Long = 0L + newOutputWriter(fileCounter) while (iter.hasNext) { + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + recordsInFile = 0 + releaseResources() + newOutputWriter(fileCounter) + } + val internalRow = iter.next() - outputWriter.writeInternal(internalRow) + currentWriter.writeInternal(internalRow) + recordsInFile += 1 } + releaseResources() Set.empty } override def releaseResources(): Unit = { - if (outputWriter != null) { - outputWriter.close() - outputWriter = null + if (currentWriter != null) { + currentWriter.close() + currentWriter = null } } } @@ -300,8 +327,15 @@ object FileFormatWriter extends 
Logging { * Open and returns a new OutputWriter given a partition key and optional bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param key values for fields consisting of partition keys for the current row + * @param partString a function that projects the partition values into a string + * @param fileCounter the number of files that have been written in the past for this specific + * partition. This is used to limit the max number of records written for a + * single file. The value should start from 0. */ - private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = { + private def newOutputWriter( + key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = { val partDir = if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0)) @@ -311,7 +345,10 @@ object FileFormatWriter extends Logging { } else { "" } - val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) + + // This must be in a form that matches our bucketing format. See BucketingUtils.
+ val ext = f"$bucketId.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) val customPath = partDir match { case Some(dir) => @@ -324,12 +361,12 @@ object FileFormatWriter extends Logging { } else { committer.newTaskTempFile(taskAttemptContext, partDir, ext) } - val newWriter = description.outputWriterFactory.newInstance( + + currentWriter = description.outputWriterFactory.newInstance( path = path, dataSchema = description.nonPartitionColumns.toStructType, context = taskAttemptContext) - newWriter.initConverter(description.nonPartitionColumns.toStructType) - newWriter + currentWriter.initConverter(description.nonPartitionColumns.toStructType) } override def execute(iter: Iterator[InternalRow]): Set[String] = { @@ -349,7 +386,7 @@ object FileFormatWriter extends Logging { description.nonPartitionColumns, description.allColumns) // Returns the partition path given a partition key. - val getPartitionString = UnsafeProjection.create( + val getPartitionStringFunc = UnsafeProjection.create( Seq(Concat(partitionStringExpression)), description.partitionColumns) // Sorts the data before write, so that we only need one writer at the same time. @@ -366,7 +403,6 @@ object FileFormatWriter extends Logging { val currentRow = iter.next() sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) } - logInfo(s"Sorting complete. Writing out partition files one at a time.") val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) { identity @@ -379,30 +415,43 @@ object FileFormatWriter extends Logging { val sortedIterator = sorter.sortedIterator() // If anything below fails, we should abort the task. 
+ var recordsInFile: Long = 0L + var fileCounter = 0 var currentKey: UnsafeRow = null val updatedPartitions = mutable.Set[String]() while (sortedIterator.next()) { val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } + // See a new key - write to a new partition (new file). currentKey = nextKey.copy() logDebug(s"Writing partition: $currentKey") - currentWriter = newOutputWriter(currentKey, getPartitionString) - val partitionPath = getPartitionString(currentKey).getString(0) + recordsInFile = 0 + fileCounter = 0 + + releaseResources() + newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) + val partitionPath = getPartitionStringFunc(currentKey).getString(0) if (partitionPath.nonEmpty) { updatedPartitions.add(partitionPath) } + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. 
+ recordsInFile = 0 + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + releaseResources() + newOutputWriter(currentKey, getPartitionStringFunc, fileCounter) } + currentWriter.writeInternal(sortedIterator.getValue) + recordsInFile += 1 } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } + releaseResources() updatedPartitions.toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 89944570df..dced536136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -21,9 +21,9 @@ import java.io.IOException import scala.collection.mutable -import org.apache.spark.{Partition => RDDPartition, TaskContext} +import org.apache.spark.{Partition => RDDPartition, TaskContext, TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.{InputFileNameHolder, RDD} +import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.vectorized.ColumnarBatch @@ -99,7 +99,15 @@ class FileScanRDD( private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null - def hasNext: Boolean = (currentIterator != null && currentIterator.hasNext) || nextIterator() + def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. 
+ if (context.isInterrupted()) { + throw new TaskKilledException + } + (currentIterator != null && currentIterator.hasNext) || nextIterator() + } def next(): Object = { val nextElement = currentIterator.next() // TODO: we should have a better separation of row based and batch based scan, so that we @@ -121,7 +129,8 @@ class FileScanRDD( if (files.hasNext) { currentFile = files.next() logInfo(s"Reading File $currentFile") - InputFileNameHolder.setInputFileName(currentFile.filePath) + // Sets InputFileBlockHolder for the file block's information + InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) try { if (ignoreCorruptFiles) { @@ -138,6 +147,7 @@ class FileScanRDD( } } catch { case e: IOException => + logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) finished = true null } @@ -149,6 +159,9 @@ class FileScanRDD( currentIterator = readFunction(currentFile) } } catch { + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: $currentFile", e) + currentIterator = Iterator.empty case e: java.io.FileNotFoundException => throw new java.io.FileNotFoundException( e.getMessage + "\n" + @@ -162,7 +175,7 @@ class FileScanRDD( hasNext } else { currentFile = null - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() false } } @@ -170,7 +183,7 @@ class FileScanRDD( override def close(): Unit = { updateBytesRead() updateBytesReadWithFileSize() - InputFileNameHolder.unsetInputFileName() + InputFileBlockHolder.unset() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 55ca4f1106..ead3233202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ 
-86,7 +86,7 @@ object FileSourceStrategy extends Strategy with Logging { val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty) // Predicates with both partition keys and attributes need to be evaluated after the scan. - val afterScanFilters = filterSet -- partitionKeyFilters + val afterScanFilters = filterSet -- partitionKeyFilters.filter(_.references.nonEmpty) logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}") val filterAttributes = AttributeSet(afterScanFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala index 7c2e6fd04d..5d97558633 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileStatusCache.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources -import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ @@ -26,9 +25,38 @@ import com.google.common.cache._ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.internal.Logging -import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession -import org.apache.spark.util.{SerializableConfiguration, SizeEstimator} +import org.apache.spark.util.SizeEstimator + + +/** + * Use [[FileStatusCache.getOrCreate()]] to construct a globally shared file status cache. + */ +object FileStatusCache { + private var sharedCache: SharedInMemoryCache = _ + + /** + * @return a new FileStatusCache based on session configuration. Cache memory quota is + * shared across all clients. 
+ */ + def getOrCreate(session: SparkSession): FileStatusCache = synchronized { + if (session.sqlContext.conf.manageFilesourcePartitions && + session.sqlContext.conf.filesourcePartitionFileCacheSize > 0) { + if (sharedCache == null) { + sharedCache = new SharedInMemoryCache( + session.sqlContext.conf.filesourcePartitionFileCacheSize) + } + sharedCache.createForNewClient() + } else { + NoopCache + } + } + + def resetForTesting(): Unit = synchronized { + sharedCache = null + } +} + /** * A cache of the leaf files of partition directories. We cache these files in order to speed @@ -55,32 +83,6 @@ abstract class FileStatusCache { def invalidateAll(): Unit } -object FileStatusCache { - private var sharedCache: SharedInMemoryCache = null - - /** - * @return a new FileStatusCache based on session configuration. Cache memory quota is - * shared across all clients. - */ - def newCache(session: SparkSession): FileStatusCache = { - synchronized { - if (session.sqlContext.conf.manageFilesourcePartitions && - session.sqlContext.conf.filesourcePartitionFileCacheSize > 0) { - if (sharedCache == null) { - sharedCache = new SharedInMemoryCache( - session.sqlContext.conf.filesourcePartitionFileCacheSize) - } - sharedCache.getForNewClient() - } else { - NoopCache - } - } - } - - def resetForTesting(): Unit = synchronized { - sharedCache = null - } -} /** * An implementation that caches partition file statuses in memory. 
@@ -88,7 +90,6 @@ object FileStatusCache { * @param maxSizeInBytes max allowable cache size before entries start getting evicted */ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { - import FileStatusCache._ // Opaque object that uniquely identifies a shared cache user private type ClientId = Object @@ -102,8 +103,9 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { (SizeEstimator.estimate(key) + SizeEstimator.estimate(value)).toInt }}) .removalListener(new RemovalListener[(ClientId, Path), Array[FileStatus]]() { - override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) = { - if (removed.getCause() == RemovalCause.SIZE && + override def onRemoval(removed: RemovalNotification[(ClientId, Path), Array[FileStatus]]) + : Unit = { + if (removed.getCause == RemovalCause.SIZE && warnedAboutEviction.compareAndSet(false, true)) { logWarning( "Evicting cached table partition metadata from memory due to size constraints " + @@ -112,13 +114,13 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { } }}) .maximumWeight(maxSizeInBytes) - .build() + .build[(ClientId, Path), Array[FileStatus]]() /** * @return a FileStatusCache that does not share any entries with any other client, but does * share memory resources for the purpose of cache eviction. 
*/ - def getForNewClient(): FileStatusCache = new FileStatusCache { + def createForNewClient(): FileStatusCache = new FileStatusCache { val clientId = new Object() override def getLeafFiles(path: Path): Option[Array[FileStatus]] = { @@ -126,7 +128,7 @@ private class SharedInMemoryCache(maxSizeInBytes: Long) extends Logging { } override def putLeafFiles(path: Path, leafFiles: Array[FileStatus]): Unit = { - cache.put((clientId, path), leafFiles.toArray) + cache.put((clientId, path), leafFiles) } override def invalidateAll(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala index 014abd454f..9a08524476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelation.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.mutable + import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.execution.FileRelation import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{StructField, StructType} /** @@ -49,10 +51,16 @@ case class HadoopFsRelation( override def sqlContext: SQLContext = sparkSession.sqlContext val schema: StructType = { - val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSchema.filterNot { column => - dataSchemaColumnNames.contains(column.name.toLowerCase) - }) + val getColName: (StructField => String) = + if (sparkSession.sessionState.conf.caseSensitiveAnalysis) _.name else _.name.toLowerCase + val overlappedPartCols = mutable.Map.empty[String, StructField] + partitionSchema.foreach { partitionField => + if 
(dataSchema.exists(getColName(_) == getColName(partitionField))) { + overlappedPartCols += getColName(partitionField) -> partitionField + } + } + StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++ + partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f)))) } def partitionSchemaOption: Option[StructType] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 2eba1e9986..b2ff68a833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -30,7 +30,7 @@ import org.apache.spark.sql.sources.InsertableRelation case class InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, - overwrite: OverwriteOptions) + overwrite: Boolean) extends RunnableCommand { override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) @@ -40,7 +40,7 @@ case class InsertIntoDataSourceCommand( val data = Dataset.ofRows(sparkSession, query) // Apply the schema of the existing table to the new data. val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) - relation.insert(df, overwrite.enabled) + relation.insert(df, overwrite) // Invalidate the cache. 
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a9bde903b3..53c884c22b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -33,10 +33,10 @@ import org.apache.spark.sql.execution.command.RunnableCommand * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. * Writing to dynamic partitions is also supported. * - * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of - * partition overwrites: when the spec is empty, all partitions are - * overwritten. When it covers a prefix of the partition keys, only - * partitions matching the prefix are overwritten. + * @param staticPartitions partial partitioning spec for write. This defines the scope of partition + * overwrites: when the spec is empty, all partitions are overwritten. + * When it covers a prefix of the partition keys, only partitions matching + * the prefix are overwritten. * @param customPartitionLocations mapping of partition specs to their custom locations. 
The * caller should guarantee that exactly those table partitions * falling under the specified static partition keys are contained @@ -44,7 +44,7 @@ import org.apache.spark.sql.execution.command.RunnableCommand */ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, - staticPartitionKeys: TablePartitionSpec, + staticPartitions: TablePartitionSpec, customPartitionLocations: Map[TablePartitionSpec, String], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -122,9 +122,9 @@ case class InsertIntoHadoopFsRelationCommand( * locations are also cleared based on the custom locations map given to this class. */ private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = { - val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) { + val staticPartitionPrefix = if (staticPartitions.nonEmpty) { "/" + partitionColumns.flatMap { p => - staticPartitionKeys.get(p.name) match { + staticPartitions.get(p.name) match { case Some(value) => Some(escapePathName(p.name) + "=" + escapePathName(value)) case None => @@ -143,7 +143,7 @@ case class InsertIntoHadoopFsRelationCommand( // now clear all custom partition locations (e.g. 
/custom/dir/where/foo=2/bar=4) for ((spec, customLoc) <- customPartitionLocations) { assert( - (staticPartitionKeys.toSet -- spec).isEmpty, + (staticPartitions.toSet -- spec).isEmpty, "Custom partition location did not match static partitioning keys") val path = new Path(customLoc) if (fs.exists(path) && !fs.delete(path, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index bf9f318780..bc290702dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -244,13 +244,22 @@ object PartitioningUtils { /** * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec - * for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`. + * for that fragment as a `TablePartitionSpec`, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`. */ def parsePathFragment(pathFragment: String): TablePartitionSpec = { + parsePathFragmentAsSeq(pathFragment).toMap + } + + /** + * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec + * for that fragment as a `Seq[(String, String)]`, e.g. + * `Seq(("fieldOne", "1"), ("fieldTwo", "2"))`. 
+ */ + def parsePathFragmentAsSeq(pathFragment: String): Seq[(String, String)] = { pathFragment.split("/").map { kv => val pair = kv.split("=", 2) (unescapePathName(pair(0)), unescapePathName(pair(1))) - }.toMap + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index a3691158ee..b0feaeb84e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -27,10 +27,12 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.text.TextFileFormat +import org.apache.spark.sql.functions.{length, trim} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -52,17 +54,20 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - val csvOptions = new CSVOptions(options) + require(files.nonEmpty, "Cannot infer schema from an empty set of files") - // TODO: Move filtering. 
- val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString) - val rdd = baseRdd(sparkSession, csvOptions, paths) - val firstLine = findFirstLine(csvOptions, rdd) + val csvOptions = new CSVOptions(options) + val paths = files.map(_.getPath.toString) + val lines: Dataset[String] = readText(sparkSession, csvOptions, paths) + val firstLine: String = findFirstLine(csvOptions, lines) val firstRow = new CsvReader(csvOptions).parseLine(firstLine) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) - val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths) + val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer( + lines, + firstLine = if (csvOptions.headerFlag) firstLine else null, + params = csvOptions) val schema = if (csvOptions.inferSchemaFlag) { CSVInferSchema.infer(parsedRdd, header, csvOptions) } else { @@ -173,51 +178,37 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } } - private def baseRdd( - sparkSession: SparkSession, - options: CSVOptions, - inputPaths: Seq[String]): RDD[String] = { - readText(sparkSession, options, inputPaths.mkString(",")) - } - - private def tokenRdd( - sparkSession: SparkSession, - options: CSVOptions, - header: Array[String], - inputPaths: Seq[String]): RDD[Array[String]] = { - val rdd = baseRdd(sparkSession, options, inputPaths) - // Make sure firstLine is materialized before sending to executors - val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null - CSVRelation.univocityTokenizer(rdd, firstLine, options) - } - /** * Returns the first line of the first non-empty file in path */ - private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = { + private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = { + import lines.sqlContext.implicits._ + val nonEmptyLines = lines.filter(length(trim($"value")) > 0) if 
(options.isCommentSet) { - val comment = options.comment.toString - rdd.filter { line => - line.trim.nonEmpty && !line.startsWith(comment) - }.first() + nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first() } else { - rdd.filter { line => - line.trim.nonEmpty - }.first() + nonEmptyLines.first() } } private def readText( sparkSession: SparkSession, options: CSVOptions, - location: String): RDD[String] = { + inputPaths: Seq[String]): Dataset[String] = { if (Charset.forName(options.charset) == StandardCharsets.UTF_8) { - sparkSession.sparkContext.textFile(location) + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = inputPaths, + className = classOf[TextFileFormat].getName + ).resolveRelation(checkFilesExist = false)) + .select("value").as[String](Encoders.STRING) } else { val charset = options.charset - sparkSession.sparkContext - .hadoopFile[LongWritable, Text, TextInputFormat](location) + val rdd = sparkSession.sparkContext + .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(",")) .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) + sparkSession.createDataset(rdd)(Encoders.STRING) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 52de11d403..e4ce7a94be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -34,12 +34,12 @@ import org.apache.spark.sql.types._ object CSVRelation extends Logging { def univocityTokenizer( - file: RDD[String], + file: Dataset[String], firstLine: String, params: CSVOptions): RDD[Array[String]] = { // If header is set, make sure firstLine is materialized before sending to executors. 
val commentPrefix = params.comment.toString - file.mapPartitions { iter => + file.rdd.mapPartitions { iter => val parser = new CsvReader(params) val filteredIter = iter.filter { line => line.trim.nonEmpty && !line.startsWith(commentPrefix) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index d5b11e7bec..2bdc432541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD( rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) - CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close()) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + new InterruptibleIterator(context, rowsIterator), close()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index c957914c5a..a9d8ddfe9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -51,13 +51,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord 
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val jsonFiles = files.filterNot { status => - val name = status.getPath.getName - (name.startsWith("_") && !name.contains("=")) || name.startsWith(".") - }.toArray - val jsonSchema = InferSchema.infer( - createBaseRdd(sparkSession, jsonFiles), + createBaseRdd(sparkSession, files), columnNameOfCorruptRecord, parsedOptions) checkConstraints(jsonSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 031a0fe578..0efe6dae7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -61,7 +61,7 @@ class ParquetFileFormat override def shortName(): String = "parquet" - override def toString: String = "ParquetFormat" + override def toString: String = "Parquet" override def hashCode(): Int = getClass.hashCode() @@ -241,12 +241,7 @@ class ParquetFileFormat commonMetadata: Seq[FileStatus]) private def splitFiles(allFiles: Seq[FileStatus]): FileTypes = { - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. 
- val leaves = allFiles.filter { f => - isSummaryFile(f.getPath) || - !((f.getPath.getName.startsWith("_") && !f.getPath.getName.contains("=")) || - f.getPath.getName.startsWith(".")) - }.toArray.sortBy(_.getPath.toString) + val leaves = allFiles.toArray.sortBy(_.getPath.toString) FileTypes( data = leaves.filterNot(f => isSummaryFile(f.getPath)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7154e3e41c..2b2fbddd12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.datasources -import java.util.regex.Pattern - import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogUtils, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ @@ -122,9 +119,12 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl } private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = { - val normalizedPartitionCols = tableDesc.partitionColumnNames.map { colName => - normalizeColumnName(tableDesc.identifier, schema, colName, "partition") - } + val normalizedPartitionCols = CatalogUtils.normalizePartCols( + tableName = tableDesc.identifier.unquotedString, + tableCols = schema.map(_.name), + partCols = tableDesc.partitionColumnNames, 
+ resolver = sparkSession.sessionState.conf.resolver) + checkDuplication(normalizedPartitionCols, "partition") if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) { @@ -149,25 +149,21 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl private def checkBucketColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = { tableDesc.bucketSpec match { - case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) => - val normalizedBucketCols = bucketColumnNames.map { colName => - normalizeColumnName(tableDesc.identifier, schema, colName, "bucket") - } - checkDuplication(normalizedBucketCols, "bucket") - - val normalizedSortCols = sortColumnNames.map { colName => - normalizeColumnName(tableDesc.identifier, schema, colName, "sort") - } - checkDuplication(normalizedSortCols, "sort") - - schema.filter(f => normalizedSortCols.contains(f.name)).map(_.dataType).foreach { + case Some(bucketSpec) => + val normalizedBucketing = CatalogUtils.normalizeBucketSpec( + tableName = tableDesc.identifier.unquotedString, + tableCols = schema.map(_.name), + bucketSpec = bucketSpec, + resolver = sparkSession.sessionState.conf.resolver) + checkDuplication(normalizedBucketing.bucketColumnNames, "bucket") + checkDuplication(normalizedBucketing.sortColumnNames, "sort") + + normalizedBucketing.sortColumnNames.map(schema(_)).map(_.dataType).foreach { case dt if RowOrdering.isOrderable(dt) => // OK case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column") } - tableDesc.copy( - bucketSpec = Some(BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols)) - ) + tableDesc.copy(bucketSpec = Some(normalizedBucketing)) case None => tableDesc } @@ -182,19 +178,6 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl } } - private def normalizeColumnName( - tableIdent: TableIdentifier, - schema: StructType, - colName: String, - colType: String): String = { - val tableCols = 
schema.map(_.name) - val resolver = sparkSession.sessionState.conf.resolver - tableCols.find(resolver(_, colName)).getOrElse { - failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " + - s"defined table columns are: ${tableCols.mkString(", ")}") - } - } - private def failAnalysis(msg: String) = throw new AnalysisException(msg) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 178160cd71..897e535953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -39,6 +39,8 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "text" + override def toString: String = "Text" + private def verifySchema(schema: StructType): Unit = { if (schema.size != 1) { throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index ce5013daeb..7be5d31d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -128,8 +128,7 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout) - .asInstanceOf[broadcast.Broadcast[T]] + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8821c0dea9..b9f6601ea8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -670,9 +670,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap var offset: Long = Platform.LONG_ARRAY_OFFSET val end = len * 8L + Platform.LONG_ARRAY_OFFSET while (offset < end) { - val size = Math.min(buffer.length, (end - offset).toInt) + val size = Math.min(buffer.length, end - offset) Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) - writeBuffer(buffer, 0, size) + writeBuffer(buffer, 0, size.toInt) offset += size } } @@ -710,8 +710,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap var offset: Long = Platform.LONG_ARRAY_OFFSET val end = length * 8L + Platform.LONG_ARRAY_OFFSET while (offset < end) { - val size = Math.min(buffer.length, (end - offset).toInt) - readBuffer(buffer, 0, size) + val size = Math.min(buffer.length, end - offset) + readBuffer(buffer, 0, size.toInt) Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) offset += size } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 9918ac327f..757fe2185d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -70,10 +70,10 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val stopEarly = ctx.freshName("stopEarly") ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") - ctx.addNewFunction("shouldStop", s""" + ctx.addNewFunction("stopEarly", s""" @Override - protected boolean shouldStop() { - return !currentRows.isEmpty() || $stopEarly; + protected boolean 
stopEarly() { + return $stopEarly; } """) val countTerm = ctx.freshName("count") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index dcaf2c76d4..7a5ac48f1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -119,26 +119,23 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val pickle = new Pickler(needConversion) // Input iterator to Python: input rows are grouped so we send them in batches to Python. // For each row, add it to the queue. - val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { inputRow => - queue.add(inputRow.asInstanceOf[UnsafeRow]) - val row = projection(inputRow) - if (needConversion) { - EvaluatePython.toJava(row, schema) - } else { - // fast path for these types that does not need conversion in Python - val fields = new Array[Any](row.numFields) - var i = 0 - while (i < row.numFields) { - val dt = dataTypes(i) - fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) - i += 1 - } - fields + val inputIterator = iter.map { inputRow => + queue.add(inputRow.asInstanceOf[UnsafeRow]) + val row = projection(inputRow) + if (needConversion) { + EvaluatePython.toJava(row, schema) + } else { + // fast path for these types that does not need conversion in Python + val fields = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + val dt = dataTypes(i) + fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) + i += 1 } - }.toArray - pickle.dumps(toBePickled) - } + fields + } + }.grouped(100).map(x => pickle.dumps(x.toArray)) val context = TaskContext.get() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 16e44845d5..69b4b7bb07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FilterExec, SparkPlan} /** @@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -object ExtractPythonUDFs extends Rule[SparkPlan] { +object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined @@ -126,10 +126,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] { plan } else { val attributeMap = mutable.HashMap[PythonUDF, Expression]() + val splitFilter = trySplitFilter(plan) // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => + val newChildren = splitFilter.children.map { child => // Pick the UDF we are going to evaluate - val validUdfs = udfs.filter { case udf => + val validUdfs = udfs.filter { udf => // Check to make sure that the UDF can be evaluated with only the input of this child. 
udf.references.subsetOf(child.outputSet) }.toArray // Turn it into an array since iterators cannot be serialized in Scala 2.10 @@ -150,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] { sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") } - val rewritten = plan.withNewChildren(newChildren).transformExpressions { + val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions { case p: PythonUDF if attributeMap.contains(p) => attributeMap(p) } @@ -165,4 +166,22 @@ object ExtractPythonUDFs extends Rule[SparkPlan] { } } } + + // Split the original FilterExec to two FilterExecs. Only push down the first few predicates + // that are all deterministic. + private def trySplitFilter(plan: SparkPlan): SparkPlan = { + plan match { + case filter: FilterExec => + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) + if (pushDown.nonEmpty) { + val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) + FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild) + } else { + filter + } + case o => o + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala index 8529ceac30..5a6f9e87f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -52,6 +52,8 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( /** Needed to serialize type T into JSON when using Jackson */ private implicit val manifest = Manifest.classType[T](implicitly[ClassTag[T]].runtimeClass) + protected val minBatchesToRetain = sparkSession.sessionState.conf.minBatchesToRetain + 
/** * If we delete the old files after compaction at once, there is a race condition in S3: other * processes may see the old files are deleted but still cannot see the compaction file using @@ -152,11 +154,16 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( } override def add(batchId: Long, logs: Array[T]): Boolean = { - if (isCompactionBatch(batchId, compactInterval)) { - compact(batchId, logs) - } else { - super.add(batchId, logs) + val batchAdded = + if (isCompactionBatch(batchId, compactInterval)) { + compact(batchId, logs) + } else { + super.add(batchId, logs) + } + if (batchAdded && isDeletingExpiredLog) { + deleteExpiredLog(batchId) } + batchAdded } /** @@ -167,9 +174,6 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) val allLogs = validBatches.flatMap(batchId => super.get(batchId)).flatten ++ logs if (super.add(batchId, compactLogs(allLogs).toArray)) { - if (isDeletingExpiredLog) { - deleteExpiredLog(batchId) - } true } else { // Return false as there is another writer. @@ -210,26 +214,41 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag]( } /** - * Since all logs before `compactionBatchId` are compacted and written into the - * `compactionBatchId` log file, they can be removed. However, due to the eventual consistency of - * S3, the compaction file may not be seen by other processes at once. So we only delete files - * created `fileCleanupDelayMs` milliseconds ago. + * Delete expired log entries that proceed the currentBatchId and retain + * sufficient minimum number of batches (given by minBatchsToRetain). This + * equates to retaining the earliest compaction log that proceeds + * batch id position currentBatchId + 1 - minBatchesToRetain. All log entries + * prior to the earliest compaction log proceeding that position will be removed. 
+ * However, due to the eventual consistency of S3, the compaction file may not + * be seen by other processes at once. So we only delete files created + * `fileCleanupDelayMs` milliseconds ago. */ - private def deleteExpiredLog(compactionBatchId: Long): Unit = { - val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs - fileManager.list(metadataPath, new PathFilter { - override def accept(path: Path): Boolean = { - try { - val batchId = getBatchIdFromFileName(path.getName) - batchId < compactionBatchId - } catch { - case _: NumberFormatException => - false + private def deleteExpiredLog(currentBatchId: Long): Unit = { + if (compactInterval <= currentBatchId + 1 - minBatchesToRetain) { + // Find the first compaction batch id that maintains minBatchesToRetain + val minBatchId = currentBatchId + 1 - minBatchesToRetain + val minCompactionBatchId = minBatchId - (minBatchId % compactInterval) - 1 + assert(isCompactionBatch(minCompactionBatchId, compactInterval), + s"$minCompactionBatchId is not a compaction batch") + + logInfo(s"Current compact batch id = $currentBatchId " + + s"min compaction batch id to delete = $minCompactionBatchId") + + val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs + fileManager.list(metadataPath, new PathFilter { + override def accept(path: Path): Boolean = { + try { + val batchId = getBatchIdFromFileName(path.getName) + batchId < minCompactionBatchId + } catch { + case _: NumberFormatException => + false + } + } + }).foreach { f => + if (f.getModificationTime <= expiredTime) { + fileManager.delete(f.getPath) } - } - }).foreach { f => - if (f.getModificationTime <= expiredTime) { - fileManager.delete(f.getPath) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala index 4c8cb069d2..5a9a99e111 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/EventTimeWatermarkExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import scala.math.max - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} @@ -28,24 +26,48 @@ import org.apache.spark.sql.types.MetadataBuilder import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.AccumulatorV2 -/** Tracks the maximum positive long seen. */ -class MaxLong(protected var currentValue: Long = 0) - extends AccumulatorV2[Long, Long] { +/** Class for collecting event time stats with an accumulator */ +case class EventTimeStats(var max: Long, var min: Long, var sum: Long, var count: Long) { + def add(eventTime: Long): Unit = { + this.max = math.max(this.max, eventTime) + this.min = math.min(this.min, eventTime) + this.sum += eventTime + this.count += 1 + } + + def merge(that: EventTimeStats): Unit = { + this.max = math.max(this.max, that.max) + this.min = math.min(this.min, that.min) + this.sum += that.sum + this.count += that.count + } - override def isZero: Boolean = value == 0 - override def value: Long = currentValue - override def copy(): AccumulatorV2[Long, Long] = new MaxLong(currentValue) + def avg: Long = sum / count +} + +object EventTimeStats { + def zero: EventTimeStats = EventTimeStats( + max = Long.MinValue, min = Long.MaxValue, sum = 0L, count = 0L) +} + +/** Accumulator that collects stats on event time in a batch. 
*/ +class EventTimeStatsAccum(protected var currentStats: EventTimeStats = EventTimeStats.zero) + extends AccumulatorV2[Long, EventTimeStats] { + + override def isZero: Boolean = value == EventTimeStats.zero + override def value: EventTimeStats = currentStats + override def copy(): AccumulatorV2[Long, EventTimeStats] = new EventTimeStatsAccum(currentStats) override def reset(): Unit = { - currentValue = 0 + currentStats = EventTimeStats.zero } override def add(v: Long): Unit = { - currentValue = max(v, value) + currentStats.add(v) } - override def merge(other: AccumulatorV2[Long, Long]): Unit = { - currentValue = max(value, other.value) + override def merge(other: AccumulatorV2[Long, EventTimeStats]): Unit = { + currentStats.merge(other.value) } } @@ -54,22 +76,26 @@ class MaxLong(protected var currentValue: Long = 0) * adding appropriate metadata to this column, this operator also tracks the maximum observed event * time. Based on the maximum observed time and a user specified delay, we can calculate the * `watermark` after which we assume we will no longer see late records for a particular time - * period. + * period. Note that event time is measured in milliseconds. */ case class EventTimeWatermarkExec( eventTime: Attribute, delay: CalendarInterval, child: SparkPlan) extends SparkPlan { - // TODO: Use Spark SQL Metrics? 
- val maxEventTime = new MaxLong - sparkContext.register(maxEventTime) + val eventTimeStats = new EventTimeStatsAccum() + val delayMs = { + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + delay.milliseconds + delay.months * millisPerMonth + } + + sparkContext.register(eventTimeStats) override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitions { iter => val getEventTime = UnsafeProjection.create(eventTime :: Nil, child.output) iter.map { row => - maxEventTime.add(getEventTime(row).getLong(0)) + eventTimeStats.add(getEventTime(row).getLong(0) / 1000) row } } @@ -80,7 +106,7 @@ case class EventTimeWatermarkExec( if (a semanticEquals eventTime) { val updatedMetadata = new MetadataBuilder() .withMetadata(a.metadata) - .putLong(EventTimeWatermark.delayKey, delay.milliseconds) + .putLong(EventTimeWatermark.delayKey, delayMs) .build() a.withMetadata(updatedMetadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index fdea65cb10..25ebe1797b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -53,4 +53,18 @@ class FileStreamOptions(parameters: CaseInsensitiveMap) extends Logging { /** Options as specified by the user, in a case-insensitive map, without "path" set. */ val optionMapWithoutPath: Map[String, String] = parameters.filterKeys(_ != "path") + + /** + * Whether to scan latest files first. If it's true, when the source finds unprocessed files in a + * trigger, it will first process the latest files. 
+ */ + val latestFirst: Boolean = parameters.get("latestFirst").map { str => + try { + str.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException( + s"Invalid value '$str' for option 'latestFirst', must be 'true' or 'false'") + } + }.getOrElse(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 8494aef004..39c0b49796 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -57,11 +57,20 @@ class FileStreamSource( private val metadataLog = new FileStreamSourceLog(FileStreamSourceLog.VERSION, sparkSession, metadataPath) - private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) + private var metadataLogCurrentOffset = metadataLog.getLatest().map(_._1).getOrElse(-1L) /** Maximum number of new files to be considered in each batch */ private val maxFilesPerBatch = sourceOptions.maxFilesPerTrigger + private val fileSortOrder = if (sourceOptions.latestFirst) { + logWarning( + """'latestFirst' is true. New files will be processed first. + |It may affect the watermark value""".stripMargin) + implicitly[Ordering[Long]].reverse + } else { + implicitly[Ordering[Long]] + } + /** A mapping from a file that we have processed to some timestamp it was last modified. */ // Visible for testing and debugging in production. val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs) @@ -79,7 +88,7 @@ class FileStreamSource( * `synchronized` on this method is for solving race conditions in tests. In the normal usage, * there is no race here, so the cost of `synchronized` should be rare. 
*/ - private def fetchMaxOffset(): LongOffset = synchronized { + private def fetchMaxOffset(): FileStreamSourceOffset = synchronized { // All the new files found - ignore aged files and files that we have seen. val newFiles = fetchAllFiles().filter { case (path, timestamp) => seenFiles.isNewFile(path, timestamp) @@ -104,14 +113,14 @@ class FileStreamSource( """.stripMargin) if (batchFiles.nonEmpty) { - maxBatchId += 1 - metadataLog.add(maxBatchId, batchFiles.map { case (path, timestamp) => - FileEntry(path = path, timestamp = timestamp, batchId = maxBatchId) + metadataLogCurrentOffset += 1 + metadataLog.add(metadataLogCurrentOffset, batchFiles.map { case (p, timestamp) => + FileEntry(path = p, timestamp = timestamp, batchId = metadataLogCurrentOffset) }.toArray) - logInfo(s"Max batch id increased to $maxBatchId with ${batchFiles.size} new files") + logInfo(s"Log offset set to $metadataLogCurrentOffset with ${batchFiles.size} new files") } - new LongOffset(maxBatchId) + FileStreamSourceOffset(metadataLogCurrentOffset) } /** @@ -122,21 +131,19 @@ class FileStreamSource( func } - /** Return the latest offset in the source */ - def currentOffset: LongOffset = synchronized { - new LongOffset(maxBatchId) - } + /** Return the latest offset in the [[FileStreamSourceLog]] */ + def currentLogOffset: Long = synchronized { metadataLogCurrentOffset } /** * Returns the data that is between the offsets (`start`, `end`]. 
*/ override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - val startId = start.flatMap(LongOffset.convert(_)).getOrElse(LongOffset(-1L)).offset - val endId = LongOffset.convert(end).getOrElse(LongOffset(0)).offset + val startOffset = start.map(FileStreamSourceOffset(_).logOffset).getOrElse(-1L) + val endOffset = FileStreamSourceOffset(end).logOffset - assert(startId <= endId) - val files = metadataLog.get(Some(startId + 1), Some(endId)).flatMap(_._2) - logInfo(s"Processing ${files.length} files from ${startId + 1}:$endId") + assert(startOffset <= endOffset) + val files = metadataLog.get(Some(startOffset + 1), Some(endOffset)).flatMap(_._2) + logInfo(s"Processing ${files.length} files from ${startOffset + 1}:$endOffset") logTrace(s"Files are:\n\t" + files.mkString("\n\t")) val newDataSource = DataSource( @@ -157,7 +164,7 @@ class FileStreamSource( val startTime = System.nanoTime val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) val catalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) - val files = catalog.allFiles().sortBy(_.getModificationTime).map { status => + val files = catalog.allFiles().sortBy(_.getModificationTime)(fileSortOrder).map { status => (status.getPath.toUri.toString, status.getModificationTime) } val endTime = System.nanoTime @@ -172,7 +179,7 @@ class FileStreamSource( files } - override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1) + override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.logOffset == -1) override def toString: String = s"FileStreamSource[$qualifiedBasePath]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala index 327b3ac267..81908c0cef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -78,7 +78,7 @@ class FileStreamSourceLog( override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, Array[FileEntry])] = { val startBatchId = startId.getOrElse(0L) - val endBatchId = getLatest().map(_._1).getOrElse(0L) + val endBatchId = endId.orElse(getLatest().map(_._1)).getOrElse(0L) val (existedBatches, removedBatches) = (startBatchId to endBatchId).map { id => if (isCompactionBatch(id, compactInterval) && fileEntryCache.containsKey(id)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala new file mode 100644 index 0000000000..06d0fe6c18 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceOffset.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.control.Exception._ + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +/** + * Offset for the [[FileStreamSource]]. 
+ * @param logOffset Position in the [[FileStreamSourceLog]] + */ +case class FileStreamSourceOffset(logOffset: Long) extends Offset { + override def json: String = { + Serialization.write(this)(FileStreamSourceOffset.format) + } +} + +object FileStreamSourceOffset { + implicit val format = Serialization.formats(NoTypeHints) + + def apply(offset: Offset): FileStreamSourceOffset = { + offset match { + case f: FileStreamSourceOffset => f + case SerializedOffset(str) => + catching(classOf[NumberFormatException]).opt { + FileStreamSourceOffset(str.toLong) + }.getOrElse { + Serialization.read[FileStreamSourceOffset](str) + } + case _ => + throw new IllegalArgumentException( + s"Invalid conversion from offset of ${offset.getClass} to FileStreamSourceOffset") + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala index c93fcfb77c..de09fb568d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter} -import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde +import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter} +import org.apache.spark.sql.catalyst.encoders.encoderFor /** * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by @@ -32,46 +31,26 @@ import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { override def addBatch(batchId: Long, data: DataFrame): Unit = { - // TODO: Refine this method when SPARK-16264 is resolved; see comments below. 
- // This logic should've been as simple as: // ``` // data.as[T].foreachPartition { iter => ... } // ``` // // Unfortunately, doing that would just break the incremental planing. The reason is, - // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just - // does not support `IncrementalExecution`. + // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` will + // create a new plan. Because StreamExecution uses the existing plan to collect metrics and + // update watermark, we should never create a new plan. Otherwise, metrics and watermark are + // updated in the new plan, and StreamExecution cannot retrieve them. // - // So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()` - // method supporting incremental planning. But in the long run, we should generally make newly - // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to - // resolve). - val incrementalExecution = data.queryExecution.asInstanceOf[IncrementalExecution] - val datasetWithIncrementalExecution = - new Dataset(data.sparkSession, incrementalExecution, implicitly[Encoder[T]]) { - override lazy val rdd: RDD[T] = { - val objectType = exprEnc.deserializer.dataType - val deserialized = CatalystSerde.deserialize[T](logicalPlan) - - // was originally: sparkSession.sessionState.executePlan(deserialized) ... - val newIncrementalExecution = new IncrementalExecution( - this.sparkSession, - deserialized, - incrementalExecution.outputMode, - incrementalExecution.checkpointLocation, - incrementalExecution.currentBatchId, - incrementalExecution.currentEventTimeWatermark) - newIncrementalExecution.toRdd.mapPartitions { rows => - rows.map(_.get(0, objectType)) - }.asInstanceOf[RDD[T]] - } - } - datasetWithIncrementalExecution.foreachPartition { iter => + // Hence, we need to manually convert internal rows to objects using encoder. 
+ val encoder = encoderFor[T].resolveAndBind( + data.logicalPlan.output, + data.sparkSession.sessionState.analyzer) + data.queryExecution.toRdd.foreachPartition { iter => if (writer.open(TaskContext.getPartitionId(), batchId)) { try { while (iter.hasNext) { - writer.process(iter.next()) + writer.process(encoder.fromRow(iter.next())) } } catch { case e: Throwable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 7469caeee3..e5a1997d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -17,13 +17,16 @@ package org.apache.spark.sql.execution.streaming +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance * vector clock that must progress linearly forward. */ -case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[String] = None) { +case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of @@ -54,6 +57,26 @@ object OffsetSeq { * `nulls` in the sequence are converted to `None`s. */ def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = { - OffsetSeq(offsets.map(Option(_)), metadata) + OffsetSeq(offsets.map(Option(_)), metadata.map(OffsetSeqMetadata.apply)) } } + + +/** + * Contains metadata associated with a [[OffsetSeq]]. This information is + * persisted to the offset log in the checkpoint location via the [[OffsetSeq]] metadata field. 
+ * + * @param batchWatermarkMs: The current eventTime watermark, used to + * bound the lateness of data that will be processed. Time unit: milliseconds + * @param batchTimestampMs: The current batch processing timestamp. + * Time unit: milliseconds + */ +case class OffsetSeqMetadata(var batchWatermarkMs: Long = 0, var batchTimestampMs: Long = 0) { + def json: String = Serialization.write(this)(OffsetSeqMetadata.format) +} + +object OffsetSeqMetadata { + private implicit val format = Serialization.formats(NoTypeHints) + def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index cc25b4474b..3210d8ad64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -74,7 +74,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) // write metadata out.write('\n') - out.write(offsetSeq.metadata.getOrElse("").getBytes(UTF_8)) + out.write(offsetSeq.metadata.map(_.json).getOrElse("").getBytes(UTF_8)) // write offsets, one per line offsetSeq.offsets.map(_.map(_.json)).foreach { offset => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index ba77e7c7bf..c5e9eae607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution.streaming -import java.util.UUID +import java.text.SimpleDateFormat +import java.util.{Date, TimeZone, UUID} import scala.collection.mutable import 
scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock /** @@ -39,10 +41,13 @@ import org.apache.spark.util.Clock trait ProgressReporter extends Logging { case class ExecutionStats( - inputRows: Map[Source, Long], stateOperators: Seq[StateOperatorProgress]) + inputRows: Map[Source, Long], + stateOperators: Seq[StateOperatorProgress], + eventTimeStats: Map[String, String]) // Internal state of the stream, required for computing metrics. protected def id: UUID + protected def runId: UUID protected def name: String protected def triggerClock: Clock protected def logicalPlan: LogicalPlan @@ -52,9 +57,10 @@ trait ProgressReporter extends Logging { protected def committedOffsets: StreamProgress protected def sources: Seq[Source] protected def sink: Sink - protected def streamExecutionMetadata: StreamExecutionMetadata + protected def offsetSeqMetadata: OffsetSeqMetadata protected def currentBatchId: Long protected def sparkSession: SparkSession + protected def postEvent(event: StreamingQueryListener.Event): Unit // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L @@ -69,6 +75,15 @@ trait ProgressReporter extends Logging { /** Holds the most recent query progress updates. Accesses must lock on the queue itself. 
*/ private val progressBuffer = new mutable.Queue[StreamingQueryProgress]() + private val noDataProgressEventInterval = + sparkSession.sessionState.conf.streamingNoDataProgressEventInterval + + // The timestamp we report an event that has no input data + private var lastNoDataProgressEventTime = Long.MinValue + + private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 + timestampFormat.setTimeZone(TimeZone.getTimeZone("UTC")) + @volatile protected var currentStatus: StreamingQueryStatus = { new StreamingQueryStatus( @@ -81,13 +96,13 @@ trait ProgressReporter extends Logging { def status: StreamingQueryStatus = currentStatus /** Returns an array containing the most recent query progress updates. */ - def recentProgresses: Array[StreamingQueryProgress] = progressBuffer.synchronized { + def recentProgress: Array[StreamingQueryProgress] = progressBuffer.synchronized { progressBuffer.toArray } - /** Returns the most recent query progress update. */ + /** Returns the most recent query progress update or null if there were no progress updates. */ def lastProgress: StreamingQueryProgress = progressBuffer.synchronized { - progressBuffer.last + progressBuffer.lastOption.orNull } /** Begins recording statistics about query progress for a given trigger. */ @@ -99,16 +114,22 @@ trait ProgressReporter extends Logging { currentDurationsMs.clear() } + private def updateProgress(newProgress: StreamingQueryProgress): Unit = { + progressBuffer.synchronized { + progressBuffer += newProgress + while (progressBuffer.length >= sparkSession.sqlContext.conf.streamingProgressRetention) { + progressBuffer.dequeue() + } + } + postEvent(new QueryProgressEvent(newProgress)) + logInfo(s"Streaming query made progress: $newProgress") + } + /** Finalizes the query progress and adds it to list of recent status updates. 
*/ protected def finishTrigger(hasNewData: Boolean): Unit = { currentTriggerEndTimestamp = triggerClock.getTimeMillis() - val executionStats: ExecutionStats = if (!hasNewData) { - ExecutionStats(Map.empty, Seq.empty) - } else { - extractExecutionStats - } - + val executionStats = extractExecutionStats(hasNewData) val processingTimeSec = (currentTriggerEndTimestamp - currentTriggerStartTimestamp).toDouble / 1000 @@ -134,28 +155,42 @@ trait ProgressReporter extends Logging { val newProgress = new StreamingQueryProgress( id = id, + runId = runId, name = name, - timestamp = currentTriggerStartTimestamp, + timestamp = formatTimestamp(currentTriggerStartTimestamp), batchId = currentBatchId, - durationMs = currentDurationsMs.toMap.mapValues(long2Long).asJava, - currentWatermark = streamExecutionMetadata.batchWatermarkMs, + durationMs = new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).asJava), + eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava), stateOperators = executionStats.stateOperators.toArray, sources = sourceProgress.toArray, sink = sinkProgress) - progressBuffer.synchronized { - progressBuffer += newProgress - while (progressBuffer.length >= sparkSession.sqlContext.conf.streamingProgressRetention) { - progressBuffer.dequeue() + if (hasNewData) { + // Reset noDataEventTimestamp if we processed any data + lastNoDataProgressEventTime = Long.MinValue + updateProgress(newProgress) + } else { + val now = triggerClock.getTimeMillis() + if (now - noDataProgressEventInterval >= lastNoDataProgressEventTime) { + lastNoDataProgressEventTime = now + updateProgress(newProgress) } } - logInfo(s"Streaming query made progress: $newProgress") currentStatus = currentStatus.copy(isTriggerActive = false) } /** Extracts statistics from the most recent query execution. 
*/ - private def extractExecutionStats: ExecutionStats = { + private def extractExecutionStats(hasNewData: Boolean): ExecutionStats = { + val hasEventTime = logicalPlan.collect { case e: EventTimeWatermark => e }.nonEmpty + val watermarkTimestamp = + if (hasEventTime) Map("watermark" -> formatTimestamp(offsetSeqMetadata.batchWatermarkMs)) + else Map.empty[String, String] + + if (!hasNewData) { + return ExecutionStats(Map.empty, Seq.empty, watermarkTimestamp) + } + // We want to associate execution plan leaves to sources that generate them, so that we match // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. // Consider the translation from the streaming logical plan to the final executed plan. @@ -212,7 +247,16 @@ trait ProgressReporter extends Logging { numRowsUpdated = node.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)) } - ExecutionStats(numInputRows, stateOperators) + val eventTimeStats = lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + val stats = e.eventTimeStats.value + Map( + "max" -> stats.max, + "min" -> stats.min, + "avg" -> stats.avg).mapValues(formatTimestamp) + }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp + + ExecutionStats(numInputRows, stateOperators, eventTimeStats) } /** Records the duration of running `body` for the next query progress update. */ @@ -228,6 +272,10 @@ trait ProgressReporter extends Logging { result } + private def formatTimestamp(millis: Long): String = { + timestampFormat.format(new Date(millis)) + } + /** Updates the message returned in `status`. 
*/ protected def updateStatusMessage(message: String): Unit = { currentStatus = currentStatus.copy(message = message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 7af978a9c4..0551e4b4a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution -import org.apache.spark.sql.InternalOutputModes._ -import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan @@ -108,6 +108,30 @@ case class StateStoreSaveExec( "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + /** Generate a predicate that matches data older than the watermark */ + private lazy val watermarkPredicate: Option[Predicate] = { + val optionalWatermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + + optionalWatermarkAttribute.map { watermarkAttribute 
=> + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + newPredicate(evictionExpression, keyExpressions) + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver assert(outputMode.nonEmpty, @@ -151,25 +175,8 @@ case class StateStoreSaveExec( numUpdatedStateRows += 1 } - val watermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)).get - // If we are evicting based on a window, use the end of the window. Otherwise just - // use the attribute itself. - val evictionExpression = - if (watermarkAttribute.dataType.isInstanceOf[StructType]) { - LessThanOrEqual( - GetStructField(watermarkAttribute, 1), - Literal(eventTimeWatermark.get * 1000)) - } else { - LessThanOrEqual( - watermarkAttribute, - Literal(eventTimeWatermark.get * 1000)) - } - - logInfo(s"Filtering state store on: $evictionExpression") - val predicate = newPredicate(evictionExpression, keyExpressions) - store.remove(predicate.eval) - + // Assumption: Append mode can be done only when watermark has been specified + store.remove(watermarkPredicate.get.eval) store.commit() numTotalStateRows += store.numKeys() @@ -180,11 +187,19 @@ case class StateStoreSaveExec( // Update and output modified rows from the StateStore. 
case Some(Update) => + new Iterator[InternalRow] { - private[this] val baseIterator = iter + + // Filter late data using watermark if specified + private[this] val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } override def hasNext: Boolean = { if (!baseIterator.hasNext) { + // Remove old aggregates if watermark specified + if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval) store.commit() numTotalStateRows += store.numKeys() false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 6d0e269d34..a35950e2dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -25,15 +25,12 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.json4s.NoTypeHints -import org.json4s.jackson.Serialization import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.streaming._ @@ -49,7 +46,7 @@ class StreamExecution( override val sparkSession: SparkSession, override val name: String, checkpointRoot: String, - val logicalPlan: LogicalPlan, + analyzedPlan: LogicalPlan, val sink: Sink, val trigger: Trigger, val triggerClock: Clock, @@ -58,17 +55,18 @@ class
StreamExecution( import org.apache.spark.sql.streaming.StreamingQueryListener._ - // TODO: restore this from the checkpoint directory. - override val id: UUID = UUID.randomUUID() - private val pollingDelayMs = sparkSession.sessionState.conf.streamingPollingDelay + private val minBatchesToRetain = sparkSession.sessionState.conf.minBatchesToRetain + require(minBatchesToRetain > 0, "minBatchesToRetain has to be positive") + /** * A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation. */ private val awaitBatchLock = new ReentrantLock(true) private val awaitBatchLockCondition = awaitBatchLock.newCondition() + private val initializationLatch = new CountDownLatch(1) private val startLatch = new CountDownLatch(1) private val terminationLatch = new CountDownLatch(1) @@ -90,20 +88,65 @@ class StreamExecution( * once, since the field's value may change at any time. */ @volatile - protected var availableOffsets = new StreamProgress + var availableOffsets = new StreamProgress /** The current batchId or -1 if execution has not yet been initialized. */ protected var currentBatchId: Long = -1 - /** Stream execution metadata */ - protected var streamExecutionMetadata = StreamExecutionMetadata() + /** Metadata associated with the whole query */ + protected val streamMetadata: StreamMetadata = { + val metadataPath = new Path(checkpointFile("metadata")) + val hadoopConf = sparkSession.sessionState.newHadoopConf() + StreamMetadata.read(metadataPath, hadoopConf).getOrElse { + val newMetadata = new StreamMetadata(UUID.randomUUID.toString) + StreamMetadata.write(newMetadata, metadataPath, hadoopConf) + newMetadata + } + } + + /** Metadata associated with the offset seq of a batch in the query. */ + protected var offsetSeqMetadata = OffsetSeqMetadata() - /** All stream sources present in the query plan. 
*/ - protected val sources = - logicalPlan.collect { case s: StreamingExecutionRelation => s.source } + override val id: UUID = UUID.fromString(streamMetadata.id) - /** A list of unique sources in the query plan. */ - private val uniqueSources = sources.distinct + override val runId: UUID = UUID.randomUUID + + /** + * Pretty identifier string for printing in logs. Format is: + * If name is set "queryName [id = xyz, runId = abc]" else "[id = xyz, runId = abc]" + */ + private val prettyIdString = + Option(name).map(_ + " ").getOrElse("") + s"[id = $id, runId = $runId]" + + /** + * All stream sources present in the query plan. This will be set when generating logical plan. + */ + @volatile protected var sources: Seq[Source] = Seq.empty + + /** + * A list of unique sources in the query plan. This will be set when generating logical plan. + */ + @volatile private var uniqueSources: Seq[Source] = Seq.empty + + override lazy val logicalPlan: LogicalPlan = { + assert(microBatchThread eq Thread.currentThread, + "logicalPlan must be initialized in StreamExecutionThread " + + s"but the current thread was ${Thread.currentThread}") + var nextSourceId = 0L + val _logicalPlan = analyzedPlan.transform { + case StreamingRelation(dataSource, _, output) => + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`.
+ StreamingExecutionRelation(source, output) + } + sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } + uniqueSources = sources.distinct + _logicalPlan + } private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -111,7 +154,7 @@ class StreamExecution( /** Defines the internal state of execution */ @volatile - private var state: State = INITIALIZED + private var state: State = INITIALIZING @volatile var lastExecution: QueryExecution = _ @@ -125,8 +168,9 @@ class StreamExecution( /* Get the call site in the caller thread; will pass this into the micro batch thread */ private val callSite = Utils.getCallSite() - /** Used to report metrics to coda-hale. */ - lazy val streamMetrics = new MetricsReporter(this, s"spark.streaming.$name") + /** Used to report metrics to coda-hale. This uses id for easier tracking across restarts. */ + lazy val streamMetrics = new MetricsReporter( + this, s"spark.streaming.${Option(name).getOrElse(id)}") /** * The thread that runs the micro-batches of this stream. Note that this thread must be @@ -134,7 +178,7 @@ class StreamExecution( * [[HDFSMetadataLog]]. See SPARK-14131 for more details. 
*/ val microBatchThread = - new StreamExecutionThread(s"stream execution thread for $name") { + new StreamExecutionThread(s"stream execution thread for $prettyIdString") { override def run(): Unit = { // To fix call site like "run at :0", we bridge the call site from the caller // thread to this micro batch thread @@ -151,8 +195,11 @@ class StreamExecution( */ val offsetLog = new OffsetSeqLog(sparkSession, checkpointFile("offsets")) + /** Whether all fields of the query have been initialized */ + private def isInitialized: Boolean = state != INITIALIZING + /** Whether the query is currently active or not */ - override def isActive: Boolean = state == ACTIVE + override def isActive: Boolean = state != TERMINATED /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */ override def exception: Option[StreamingQueryException] = Option(streamDeathCause) @@ -181,14 +228,12 @@ class StreamExecution( */ private def runBatches(): Unit = { try { - // Mark ACTIVE and then post the event. QueryStarted event is synchronously sent to listeners, - // so must mark this as ACTIVE first. - state = ACTIVE if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } - postEvent(new QueryStartedEvent(id, name)) // Assumption: Does not throw exception. + // `postEvent` does not throw non fatal exception. + postEvent(new QueryStartedEvent(id, runId, name)) // Unblock starting thread startLatch.countDown() @@ -196,6 +241,13 @@ class StreamExecution( // While active, repeatedly attempt to run batches. 
SparkSession.setActiveSession(sparkSession) + updateStatusMessage("Initializing sources") + // force initialization of the logical plan so that the sources can be created + logicalPlan + state = ACTIVE + // Unblock `awaitInitialization` + initializationLatch.countDown() + triggerExecutor.execute(() => { startTrigger() @@ -218,8 +270,6 @@ class StreamExecution( // Report trigger as finished and construct progress object. finishTrigger(dataAvailable) - postEvent(new QueryProgressEvent(lastProgress)) - if (dataAvailable) { // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 @@ -244,11 +294,12 @@ class StreamExecution( updateStatusMessage("Stopped") case e: Throwable => streamDeathCause = new StreamingQueryException( - this, - s"Query $name terminated with exception: ${e.getMessage}", + toDebugString(includeLogicalPlan = isInitialized), + s"Query $prettyIdString terminated with exception: ${e.getMessage}", e, - Some(committedOffsets.toOffsetSeq(sources, streamExecutionMetadata.json))) - logError(s"Query $name terminated with error", e) + committedOffsets.toOffsetSeq(sources, offsetSeqMetadata).toString, + availableOffsets.toOffsetSeq(sources, offsetSeqMetadata).toString) + logError(s"Query $prettyIdString terminated with error", e) updateStatusMessage(s"Terminated with exception: ${e.getMessage}") // Rethrow the fatal errors to allow the user using `Thread.UncaughtExceptionHandler` to // handle them @@ -256,17 +307,25 @@ class StreamExecution( throw e } } finally { - state = TERMINATED - currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) + // Release latches to unblock the user codes since exception can happen in any place and we + // may not get a chance to release them + startLatch.countDown() + initializationLatch.countDown() - // Update metrics and status - sparkSession.sparkContext.env.metricsSystem.removeSource(streamMetrics) + try { + state = TERMINATED + currentStatus = 
status.copy(isTriggerActive = false, isDataAvailable = false) + + // Update metrics and status + sparkSession.sparkContext.env.metricsSystem.removeSource(streamMetrics) - // Notify others - sparkSession.streams.notifyQueryTermination(StreamExecution.this) - postEvent( - new QueryTerminatedEvent(id, exception.map(_.cause).map(Utils.exceptionString))) - terminationLatch.countDown() + // Notify others + sparkSession.streams.notifyQueryTermination(StreamExecution.this) + postEvent( + new QueryTerminatedEvent(id, runId, exception.map(_.cause).map(Utils.exceptionString))) + } finally { + terminationLatch.countDown() + } } } @@ -284,9 +343,9 @@ class StreamExecution( logInfo(s"Resuming streaming query, starting with batch $batchId") currentBatchId = batchId availableOffsets = nextOffsets.toStreamProgress(sources) - streamExecutionMetadata = StreamExecutionMetadata(nextOffsets.metadata.getOrElse("{}")) + offsetSeqMetadata = nextOffsets.metadata.getOrElse(OffsetSeqMetadata()) logDebug(s"Found possibly unprocessed offsets $availableOffsets " + - s"at batch timestamp ${streamExecutionMetadata.batchTimestampMs}") + s"at batch timestamp ${offsetSeqMetadata.batchTimestampMs}") offsetLog.get(batchId - 1).foreach { case lastOffsets => @@ -342,15 +401,33 @@ class StreamExecution( } if (hasNewData) { // Current batch timestamp in milliseconds - streamExecutionMetadata.batchTimestampMs = triggerClock.getTimeMillis() + offsetSeqMetadata.batchTimestampMs = triggerClock.getTimeMillis() + // Update the eventTime watermark if we find one in the plan. 
+ if (lastExecution != null) { + lastExecution.executedPlan.collect { + case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats: ${e.eventTimeStats.value}") + e.eventTimeStats.value.max - e.delayMs + }.headOption.foreach { newWatermarkMs => + if (newWatermarkMs > offsetSeqMetadata.batchWatermarkMs) { + logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") + offsetSeqMetadata.batchWatermarkMs = newWatermarkMs + } else { + logDebug( + s"Event time didn't move: $newWatermarkMs < " + + s"${offsetSeqMetadata.batchWatermarkMs}") + } + } + } + updateStatusMessage("Writing offsets to log") reportTimeTaken("walCommit") { assert(offsetLog.add( currentBatchId, - availableOffsets.toOffsetSeq(sources, streamExecutionMetadata.json)), + availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)), s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") logInfo(s"Committed offsets for batch $currentBatchId. " + - s"Metadata ${streamExecutionMetadata.toString}") + s"Metadata ${offsetSeqMetadata.toString}") // NOTE: The following code is correct because runBatches() processes exactly one // batch at a time. If we add pipeline parallelism (multiple batches in flight at @@ -365,10 +442,11 @@ class StreamExecution( } } - // Now that we have logged the new batch, no further processing will happen for - // the batch before the previous batch, and it is safe to discard the old metadata. + // It is now safe to discard the metadata beyond the minimum number to retain. // Note that purge is exclusive, i.e. it purges everything before the target ID. 
- offsetLog.purge(currentBatchId - 1) + if (minBatchesToRetain < currentBatchId) { + offsetLog.purge(currentBatchId - minBatchesToRetain) + } } } else { awaitBatchLock.lock() @@ -420,21 +498,21 @@ class StreamExecution( val triggerLogicalPlan = withNewSources transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a) case ct: CurrentTimestamp => - CurrentBatchTimestamp(streamExecutionMetadata.batchTimestampMs, + CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, ct.dataType) case cd: CurrentDate => - CurrentBatchTimestamp(streamExecutionMetadata.batchTimestampMs, + CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs, cd.dataType) } - val executedPlan = reportTimeTaken("queryPlanning") { + reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSession, triggerLogicalPlan, outputMode, checkpointFile("state"), currentBatchId, - streamExecutionMetadata.batchWatermarkMs) + offsetSeqMetadata.batchWatermarkMs) lastExecution.executedPlan // Force the lazy generation of execution plan } @@ -445,21 +523,6 @@ class StreamExecution( sink.addBatch(currentBatchId, nextBatch) } - // Update the eventTime watermark if we find one in the plan. - lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec => - logTrace(s"Maximum observed eventTime: ${e.maxEventTime.value}") - (e.maxEventTime.value / 1000) - e.delay.milliseconds() - }.headOption.foreach { newWatermark => - if (newWatermark > streamExecutionMetadata.batchWatermarkMs) { - logInfo(s"Updating eventTime watermark to: $newWatermark ms") - streamExecutionMetadata.batchWatermarkMs = newWatermark - } else { - logTrace(s"Event time didn't move: $newWatermark < " + - s"$streamExecutionMetadata.currentEventTimeWatermark") - } - } - awaitBatchLock.lock() try { // Wake up any threads that are waiting for the stream to progress. 
@@ -469,7 +532,7 @@ class StreamExecution( } } - private def postEvent(event: StreamingQueryListener.Event) { + override protected def postEvent(event: StreamingQueryListener.Event): Unit = { sparkSession.streams.postListenerEvent(event) } @@ -486,7 +549,7 @@ class StreamExecution( microBatchThread.join() } uniqueSources.foreach(_.stop()) - logInfo(s"Query $name was stopped") + logInfo(s"Query $prettyIdString was stopped") } /** @@ -494,6 +557,7 @@ class StreamExecution( * least the given `Offset`. This method is intended for use primarily when writing tests. */ private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { + assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset @@ -516,7 +580,38 @@ class StreamExecution( /** A flag to indicate that a batch has completed with no new data available. */ @volatile private var noNewData = false + /** + * Assert that the await APIs should not be called in the stream thread. Otherwise, it may cause + * dead-lock, e.g., calling any await APIs in `StreamingQueryListener.onQueryStarted` will block + * the stream thread forever. + */ + private def assertAwaitThread(): Unit = { + if (microBatchThread eq Thread.currentThread) { + throw new IllegalStateException( + "Cannot wait for a query state from the same thread that is running the query") + } + } + + /** + * Await until all fields of the query have been initialized. 
+ */ + def awaitInitialization(timeoutMs: Long): Unit = { + assertAwaitThread() + require(timeoutMs > 0, "Timeout has to be positive") + if (streamDeathCause != null) { + throw streamDeathCause + } + initializationLatch.await(timeoutMs, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } + } + override def processAllAvailable(): Unit = { + assertAwaitThread() + if (streamDeathCause != null) { + throw streamDeathCause + } awaitBatchLock.lock() try { noNewData = false @@ -535,9 +630,7 @@ class StreamExecution( } override def awaitTermination(): Unit = { - if (state == INITIALIZED) { - throw new IllegalStateException("Cannot wait for termination on a query that has not started") - } + assertAwaitThread() terminationLatch.await() if (streamDeathCause != null) { throw streamDeathCause @@ -545,9 +638,7 @@ class StreamExecution( } override def awaitTermination(timeoutMs: Long): Boolean = { - if (state == INITIALIZED) { - throw new IllegalStateException("Cannot wait for termination on a query that has not started") - } + assertAwaitThread() require(timeoutMs > 0, "Timeout has to be positive") terminationLatch.await(timeoutMs, TimeUnit.MILLISECONDS) if (streamDeathCause != null) { @@ -577,61 +668,31 @@ class StreamExecution( override def explain(): Unit = explain(extended = false) override def toString: String = { - s"Streaming Query - $name [state = $state]" + s"Streaming Query $prettyIdString [state = $state]" } - def toDebugString: String = { - val deathCauseStr = if (streamDeathCause != null) { - "Error:\n" + stackTraceToString(streamDeathCause.cause) - } else "" - s""" - |=== Streaming Query === - |Name: $name - |Current Offsets: $committedOffsets - | - |Current State: $state - |Thread State: ${microBatchThread.getState} - | - |Logical Plan: - |$logicalPlan - | - |$deathCauseStr - """.stripMargin + private def toDebugString(includeLogicalPlan: Boolean): String = { + val debugString = + s"""|=== Streaming Query === + |Identifier: 
$prettyIdString + |Current Committed Offsets: $committedOffsets + |Current Available Offsets: $availableOffsets + | + |Current State: $state + |Thread State: ${microBatchThread.getState}""".stripMargin + if (includeLogicalPlan) { + debugString + s"\n\nLogical Plan:\n$logicalPlan" + } else { + debugString + } } trait State - case object INITIALIZED extends State + case object INITIALIZING extends State case object ACTIVE extends State case object TERMINATED extends State } -/** - * Contains metadata associated with a stream execution. This information is - * persisted to the offset log via the OffsetSeq metadata field. Current - * information contained in this object includes: - * - * @param batchWatermarkMs: The current eventTime watermark, used to - * bound the lateness of data that will processed. Time unit: milliseconds - * @param batchTimestampMs: The current batch processing timestamp. - * Time unit: milliseconds - */ -case class StreamExecutionMetadata( - var batchWatermarkMs: Long = 0, - var batchTimestampMs: Long = 0) { - private implicit val formats = StreamExecutionMetadata.formats - - /** - * JSON string representation of this object. - */ - def json: String = Serialization.write(this) -} - -object StreamExecutionMetadata { - private implicit val formats = Serialization.formats(NoTypeHints) - - def apply(json: String): StreamExecutionMetadata = - Serialization.read[StreamExecutionMetadata](json) -} /** * A special thread to run the stream query. 
Some codes require to run in the StreamExecutionThread diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala new file mode 100644 index 0000000000..7807c9fae8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetadata.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStreamReader, OutputStreamWriter} +import java.nio.charset.StandardCharsets + +import scala.util.control.NonFatal + +import org.apache.commons.io.IOUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, FSDataOutputStream, Path} +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.streaming.StreamingQuery + +/** + * Contains metadata associated with a [[StreamingQuery]]. This information is written + * in the checkpoint location the first time a query is started and recovered every time the query + * is restarted. 
+ * + * @param id unique id of the [[StreamingQuery]] that needs to be persisted across restarts + */ +case class StreamMetadata(id: String) { + def json: String = Serialization.write(this)(StreamMetadata.format) +} + +object StreamMetadata extends Logging { + implicit val format = Serialization.formats(NoTypeHints) + + /** Read the metadata from file if it exists */ + def read(metadataFile: Path, hadoopConf: Configuration): Option[StreamMetadata] = { + val fs = FileSystem.get(hadoopConf) + if (fs.exists(metadataFile)) { + var input: FSDataInputStream = null + try { + input = fs.open(metadataFile) + val reader = new InputStreamReader(input, StandardCharsets.UTF_8) + val metadata = Serialization.read[StreamMetadata](reader) + Some(metadata) + } catch { + case NonFatal(e) => + logError(s"Error reading stream metadata from $metadataFile", e) + throw e + } finally { + IOUtils.closeQuietly(input) + } + } else None + } + + /** Write metadata to file */ + def write( + metadata: StreamMetadata, + metadataFile: Path, + hadoopConf: Configuration): Unit = { + var output: FSDataOutputStream = null + try { + val fs = FileSystem.get(hadoopConf) + output = fs.create(metadataFile) + val writer = new OutputStreamWriter(output) + Serialization.write(metadata, writer) + writer.close() + } catch { + case NonFatal(e) => + logError(s"Error writing stream metadata $metadata to $metadataFile", e) + throw e + } finally { + IOUtils.closeQuietly(output) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 21b8750ca9..a3f3662e6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -26,7 +26,7 @@ class StreamProgress( val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset]) 
extends scala.collection.immutable.Map[Source, Offset] { - def toOffsetSeq(source: Seq[Source], metadata: String): OffsetSeq = { + def toOffsetSeq(source: Seq[Source], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index 22e4c6380f..a2153d27e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.streaming +import java.util.UUID + +import scala.collection.mutable + import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.ListenerBus @@ -25,7 +29,11 @@ import org.apache.spark.util.ListenerBus * A bus to forward events to [[StreamingQueryListener]]s. This one will send received * [[StreamingQueryListener.Event]]s to the Spark listener bus. It also registers itself with * Spark listener bus, so that it can receive [[StreamingQueryListener.Event]]s and dispatch them - * to StreamingQueryListener. + * to StreamingQueryListeners. + * + * Note that each bus and its registered listeners are associated with a single SparkSession + * and StreamingQueryManager. So this bus will dispatch events to registered listeners for only + * those queries that were started in the associated SparkSession. 
*/ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) extends SparkListener with ListenerBus[StreamingQueryListener, StreamingQueryListener.Event] { @@ -35,12 +43,30 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) sparkListenerBus.addListener(this) /** - * Post a StreamingQueryListener event to the Spark listener bus asynchronously. This event will - * be dispatched to all StreamingQueryListener in the thread of the Spark listener bus. + * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus + * to registered `StreamingQueryListeners`. + * + * Note 1: We need to track runIds instead of ids because the runId is unique for every started + * query, even if it is a restart. So even if a query is restarted, this bus will identify them + * separately and correctly account for the restart. + * + * Note 2: This list needs to be maintained separately from the + * `StreamingQueryManager.activeQueries` because a terminated query is cleared from + * `StreamingQueryManager.activeQueries` as soon as it is stopped, but this ListenerBus + * must clear a query only after the termination event of that query has been posted. + */ + private val activeQueryRunIds = new mutable.HashSet[UUID] + + /** + * Post a StreamingQueryListener event to the added StreamingQueryListeners. + * Note that only the QueryStarted event is posted to the listener synchronously. Other events + * are dispatched to Spark listener bus. This method is guaranteed to be called by queries in + * the same SparkSession as this listener. */ def post(event: StreamingQueryListener.Event) { event match { case s: QueryStartedEvent => + activeQueryRunIds.synchronized { activeQueryRunIds += s.runId } sparkListenerBus.post(s) // post to local listeners to trigger callbacks postToAll(s) @@ -63,18 +89,32 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) } } + /** + * Dispatch events to registered StreamingQueryListeners.
Only the events associated with queries + * started in the same SparkSession as this ListenerBus will be dispatched to the listeners. + */ override protected def doPostEvent( listener: StreamingQueryListener, event: StreamingQueryListener.Event): Unit = { + def shouldReport(runId: UUID): Boolean = { + activeQueryRunIds.synchronized { activeQueryRunIds.contains(runId) } + } + event match { case queryStarted: QueryStartedEvent => - listener.onQueryStarted(queryStarted) + if (shouldReport(queryStarted.runId)) { + listener.onQueryStarted(queryStarted) + } case queryProgress: QueryProgressEvent => - listener.onQueryProgress(queryProgress) + if (shouldReport(queryProgress.progress.runId)) { + listener.onQueryProgress(queryProgress) + } case queryTerminated: QueryTerminatedEvent => - listener.onQueryTerminated(queryTerminated) + if (shouldReport(queryTerminated.runId)) { + listener.onQueryTerminated(queryTerminated) + activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId } + } case _ => } } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala new file mode 100644 index 0000000000..020c9cb4a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus} + +/** + * Wrap non-serializable StreamExecution to make the query serializable as it's easy for it to + * get captured with normal usage. It's safe to capture the query but not use it in executors. + * However, if the user tries to call its methods, it will throw `IllegalStateException`. + */ +class StreamingQueryWrapper(@transient private val _streamingQuery: StreamExecution) + extends StreamingQuery with Serializable { + + def streamingQuery: StreamExecution = { + /** Assert the codes run in the driver.
*/ + if (_streamingQuery == null) { + throw new IllegalStateException("StreamingQuery cannot be used in executors") + } + _streamingQuery + } + + override def name: String = { + streamingQuery.name + } + + override def id: UUID = { + streamingQuery.id + } + + override def runId: UUID = { + streamingQuery.runId + } + + override def awaitTermination(): Unit = { + streamingQuery.awaitTermination() + } + + override def awaitTermination(timeoutMs: Long): Boolean = { + streamingQuery.awaitTermination(timeoutMs) + } + + override def stop(): Unit = { + streamingQuery.stop() + } + + override def processAllAvailable(): Unit = { + streamingQuery.processAllAvailable() + } + + override def isActive: Boolean = { + streamingQuery.isActive + } + + override def lastProgress: StreamingQueryProgress = { + streamingQuery.lastProgress + } + + override def explain(): Unit = { + streamingQuery.explain() + } + + override def explain(extended: Boolean): Unit = { + streamingQuery.explain(extended) + } + + /** + * This method is called in Python. Python cannot call "explain" directly as it outputs in the JVM + * process, which may not be visible in Python process. 
+ */ + def explainInternal(extended: Boolean): String = { + streamingQuery.explainInternal(extended) + } + + override def sparkSession: SparkSession = { + streamingQuery.sparkSession + } + + override def recentProgress: Array[StreamingQueryProgress] = { + streamingQuery.recentProgress + } + + override def status: StreamingQueryStatus = { + streamingQuery.status + } + + override def exception: Option[StreamingQueryException] = { + streamingQuery.exception + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index adf6963577..91da6b3846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -70,11 +71,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def schema: StructType = encoder.schema - def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { + def toDS(): Dataset[A] = { Dataset(sqlContext.sparkSession, logicalPlan) } - def toDF()(implicit sqlContext: SQLContext): DataFrame = { + def toDF(): DataFrame = { Dataset.ofRows(sqlContext.sparkSession, logicalPlan) } @@ -186,16 +187,23 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi }.mkString("\n") } - override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - if (latestBatchId.isEmpty || batchId > latestBatchId.get) { + override def addBatch(batchId: Long, data: 
DataFrame): Unit = { + val notCommitted = synchronized { + latestBatchId.isEmpty || batchId > latestBatchId.get + } + if (notCommitted) { logDebug(s"Committing batch $batchId to $this") outputMode match { - case InternalOutputModes.Append | InternalOutputModes.Update => - batches.append(AddedData(batchId, data.collect())) + case Append | Update => + val rows = AddedData(batchId, data.collect()) + synchronized { batches += rows } - case InternalOutputModes.Complete => - batches.clear() - batches += AddedData(batchId, data.collect()) + case Complete => + val rows = AddedData(batchId, data.collect()) + synchronized { + batches.clear() + batches += rows + } case _ => throw new IllegalArgumentException( @@ -206,7 +214,7 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi } } - def clear(): Unit = { + def clear(): Unit = synchronized { batches.clear() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 493fdaaec5..4f3f8181d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -303,7 +303,6 @@ private[state] class HDFSBackedStateStoreProvider( val mapFromFile = readSnapshotFile(version).getOrElse { val prevMap = loadMap(version - 1) val newMap = new MapType(prevMap) - newMap.putAll(prevMap) updateFromDeltaFile(version, newMap) newMap } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index de72f1cf27..acfaa8e5eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -26,9 +26,11 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot - val minVersionsToRetain = conf.stateStoreMinVersionsToRetain + val minVersionsToRetain = conf.minBatchesToRetain } private[streaming] object StateStoreConf { val empty = new StateStoreConf() + + def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 6d984621cc..41ed9d7180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdenti import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.types.StructType @@ -393,6 +394,19 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } } + /** + * Recover all the partitions in the directory of a table and update the catalog. + * + * @param tableName the name of the table to be repaired. + * @group ddl_ops + * @since 2.1.1 + */ + override def recoverPartitions(tableName: String): Unit = { + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + sparkSession.sessionState.executePlan( + AlterTableRecoverPartitionsCommand(tableIdent)).toRdd + } + /** * Returns true if the table is currently cached in-memory. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 200f0603e1..cce16264d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -136,7 +136,7 @@ object SQLConf { "That is to say by default the optimizer will not choose to broadcast a table unless it " + "knows for sure its size is small enough.") .longConf - .createWithDefault(-1) + .createWithDefault(Long.MaxValue) val SHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.shuffle.partitions") .doc("The default number of partitions to use when shuffling data for joins or aggregations.") @@ -466,6 +466,19 @@ object SQLConf { .longConf .createWithDefault(4 * 1024 * 1024) + val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles") + .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + + "encountering corrupted or non-existing and contents that have been read will still be " + + "returned.") + .booleanConf + .createWithDefault(false) + + val MAX_RECORDS_PER_FILE = SQLConfigBuilder("spark.sql.files.maxRecordsPerFile") + .doc("Maximum number of records to write out to a single file. 
" + + "If this value is zero or negative, there is no limit.") + .longConf + .createWithDefault(0) + val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse") .internal() .doc("When true, the planner will try to find out duplicated exchanges and re-use them.") @@ -480,18 +493,17 @@ object SQLConf { .intConf .createWithDefault(10) - val STATE_STORE_MIN_VERSIONS_TO_RETAIN = - SQLConfigBuilder("spark.sql.streaming.stateStore.minBatchesToRetain") - .internal() - .doc("Minimum number of versions of a state store's data to retain after cleaning.") - .intConf - .createWithDefault(2) - val CHECKPOINT_LOCATION = SQLConfigBuilder("spark.sql.streaming.checkpointLocation") .doc("The default location for storing checkpoint data for streaming queries.") .stringConf .createOptional + val MIN_BATCHES_TO_RETAIN = SQLConfigBuilder("spark.sql.streaming.minBatchesToRetain") + .internal() + .doc("The minimum number of batches that must be retained and made recoverable.") + .intConf + .createWithDefault(100) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = SQLConfigBuilder("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -603,6 +615,13 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(10L) + val STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL = + SQLConfigBuilder("spark.sql.streaming.noDataProgressEventInterval") + .internal() + .doc("How long to wait between two progress events when there is no data") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(10000L) + val STREAMING_METRICS_ENABLED = SQLConfigBuilder("spark.sql.streaming.metricsEnabled") .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") @@ -610,7 +629,7 @@ object SQLConf { .createWithDefault(false) val STREAMING_PROGRESS_RETENTION = - SQLConfigBuilder("spark.sql.streaming.numRecentProgresses") + SQLConfigBuilder("spark.sql.streaming.numRecentProgressUpdates") .doc("The number of progress updates to retain for a streaming query") .intConf 
.createWithDefault(100) @@ -623,12 +642,6 @@ object SQLConf { .doubleConf .createWithDefault(0.05) - val IGNORE_CORRUPT_FILES = SQLConfigBuilder("spark.sql.files.ignoreCorruptFiles") - .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupt files and contents that have been read will still be returned.") - .booleanConf - .createWithDefault(false) - object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -660,8 +673,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) - def stateStoreMinVersionsToRetain: Int = getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) - def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED) @@ -684,6 +695,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def streamingPollingDelay: Long = getConf(STREAMING_POLLING_DELAY) + def streamingNoDataProgressEventInterval: Long = + getConf(STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL) + def streamingMetricsEnabled: Boolean = getConf(STREAMING_METRICS_ENABLED) def streamingProgressRetention: Int = getConf(STREAMING_PROGRESS_RETENTION) @@ -692,6 +706,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) + def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) + + def maxRecordsPerFile: Long = getConf(MAX_RECORDS_PER_FILE) + def useCompression: Boolean = getConf(COMPRESS_CACHED) def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) @@ -712,6 +730,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) + def minBatchesToRetain: Int = 
getConf(MIN_BATCHES_TO_RETAIN) + def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) @@ -753,7 +773,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) - def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, Long.MaxValue) + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES) def isParquetSchemaMergingEnabled: Boolean = getConf(PARQUET_SCHEMA_MERGING_ENABLED) @@ -811,8 +831,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def warehousePath: String = new Path(getConf(StaticSQLConf.WAREHOUSE_PATH)).toString - def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES) - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 70122f2599..da787b4859 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -36,6 +36,8 @@ private object MsSqlServerDialect extends JdbcDialect { override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) + case StringType => Some(JdbcType("NVARCHAR(MAX)", java.sql.Types.NVARCHAR)) + case BooleanType => Some(JdbcType("BIT", java.sql.Types.BIT)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b3c600ae53..6c0c5e0c95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} @@ -65,9 +66,11 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { OutputMode.Append case "complete" => OutputMode.Complete + case "update" => + OutputMode.Update case _ => throw new IllegalArgumentException(s"Unknown output mode $outputMode. " + - "Accepted output modes are 'append' and 'complete'") + "Accepted output modes are 'append', 'complete', 'update'") } this } @@ -99,7 +102,6 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } - /** * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. * This name must be unique among all the currently active queries in the associated SQLContext. @@ -219,11 +221,20 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { if (extraOptions.get("queryName").isEmpty) { throw new AnalysisException("queryName must be specified for memory sink") } - + val supportedModes = "Output modes supported by the memory sink are 'append' and 'complete'." + outputMode match { + case Append | Complete => // allowed + case Update => + throw new AnalysisException( + s"Update output mode is not supported for memory sink. $supportedModes") + case _ => + throw new AnalysisException( + s"$outputMode is not supported for memory sink. 
$supportedModes") + } val sink = new MemorySink(df.schema, outputMode) val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) val chkpointLoc = extraOptions.get("checkpointLocation") - val recoverFromChkpoint = chkpointLoc.isDefined && outputMode == OutputMode.Complete() + val recoverFromChkpoint = outputMode == OutputMode.Complete() val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), chkpointLoc, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 8fc4e43b6d..596bd90140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -32,20 +32,31 @@ import org.apache.spark.sql.SparkSession trait StreamingQuery { /** - * Returns the name of the query. This name is unique across all active queries. This can be - * set in the `org.apache.spark.sql.streaming.DataStreamWriter` as - * `dataframe.writeStream.queryName("query").start()`. + * Returns the user-specified name of the query, or null if not specified. + * This name can be specified in the `org.apache.spark.sql.streaming.DataStreamWriter` + * as `dataframe.writeStream.queryName("query").start()`. + * This name, if set, must be unique across all active queries. * * @since 2.0.0 */ def name: String /** - * Returns the unique id of this query. + * Returns the unique id of this query that persists across restarts from checkpoint data. + * That is, this id is generated when a query is started for the first time, and + * will be the same every time it is restarted from checkpoint data. Also see [[runId]]. + * * @since 2.1.0 */ def id: UUID + /** + * Returns the unique id of this run of the query. That is, every start/restart of a query will + * generate a unique runId. 
Therefore, every time a query is restarted from + * checkpoint, it will have the same [[id]] but different [[runId]]s. + */ + def runId: UUID + /** * Returns the `SparkSession` associated with `this`. * @@ -76,11 +87,11 @@ trait StreamingQuery { /** * Returns an array of the most recent [[StreamingQueryProgress]] updates for this query. * The number of progress updates retained for each stream is configured by Spark session - * configuration `spark.sql.streaming.numRecentProgresses`. + * configuration `spark.sql.streaming.numRecentProgressUpdates`. * * @since 2.1.0 */ - def recentProgresses: Array[StreamingQueryProgress] + def recentProgress: Array[StreamingQueryProgress] /** * Returns the most recent [[StreamingQueryProgress]] update of this streaming query. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 13f11ba1c9..c53c29591a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -18,38 +18,30 @@ package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.{Offset, OffsetSeq, StreamExecution} /** * :: Experimental :: * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception * that caused the failure. 
- * @param query Query that caused the exception * @param message Message of this exception * @param cause Internal cause of this exception - * @param startOffset Starting offset (if known) of the range of data in which exception occurred - * @param endOffset Ending offset (if known) of the range of data in exception occurred + * @param startOffset Starting offset in json of the range of data in which exception occurred + * @param endOffset Ending offset in json of the range of data in which exception occurred * @since 2.0.0 */ @Experimental class StreamingQueryException private[sql]( - @transient val query: StreamingQuery, + private val queryDebugString: String, val message: String, val cause: Throwable, - val startOffset: Option[OffsetSeq] = None, - val endOffset: Option[OffsetSeq] = None) + val startOffset: String, + val endOffset: String) extends Exception(message, cause) { /** Time when the exception occurred */ val time: Long = System.currentTimeMillis - override def toString(): String = { - val causeStr = - s"${cause.getMessage} ${cause.getStackTrace.take(10).mkString("", "\n|\t", "\n")}" - s""" - |$causeStr - | - |${query.asInstanceOf[StreamExecution].toDebugString} - """.stripMargin - } + override def toString(): String = + s"""${classOf[StreamingQueryException].getName}: ${cause.getMessage} + |$queryDebugString""".stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index d9ee75c064..817733286b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -83,14 +83,21 @@ object StreamingQueryListener { /** * :: Experimental :: * Event representing the start of a query + * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. 
+ * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. + * @param name User-specified name of the query, null if not specified. * @since 2.1.0 */ @Experimental - class QueryStartedEvent private[sql](val id: UUID, val name: String) extends Event + class QueryStartedEvent private[sql]( + val id: UUID, + val runId: UUID, + val name: String) extends Event /** * :: Experimental :: * Event representing any progress updates in a query. + * @param progress The query progress updates. * @since 2.1.0 */ @Experimental @@ -100,11 +107,15 @@ object StreamingQueryListener { * :: Experimental :: * Event representing that termination of a query. * - * @param id The query id. + * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. + * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. * @param exception The exception message of the query if the query was terminated * with an exception. Otherwise, it will be `None`. 
* @since 2.1.0 */ @Experimental - class QueryTerminatedEvent private[sql](val id: UUID, val exception: Option[String]) extends Event + class QueryTerminatedEvent private[sql]( + val id: UUID, + val runId: UUID, + val exception: Option[String]) extends Event } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index c448468bea..8c26ee2bd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.streaming import java.util.UUID -import java.util.concurrent.atomic.AtomicLong +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -44,10 +44,13 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) private val listenerBus = new StreamingQueryListenerBus(sparkSession.sparkContext.listenerBus) + + @GuardedBy("activeQueriesLock") private val activeQueries = new mutable.HashMap[UUID, StreamingQuery] private val activeQueriesLock = new Object private val awaitTerminationLock = new Object + @GuardedBy("awaitTerminationLock") private var lastTerminatedQuery: StreamingQuery = null /** @@ -181,8 +184,65 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { listenerBus.post(event) } + private def createQuery( + userSpecifiedName: Option[String], + userSpecifiedCheckpointLocation: Option[String], + df: DataFrame, + sink: Sink, + outputMode: OutputMode, + useTempCheckpointLocation: Boolean, + recoverFromCheckpointLocation: Boolean, + trigger: Trigger, + triggerClock: Clock): StreamingQueryWrapper = { + val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified => + new 
Path(userSpecified).toUri.toString + }.orElse { + df.sparkSession.sessionState.conf.checkpointLocation.map { location => + new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toUri.toString + } + }.getOrElse { + if (useTempCheckpointLocation) { + Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath + } else { + throw new AnalysisException( + "checkpointLocation must be specified either " + + """through option("checkpointLocation", ...) or """ + + s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""") + } + } + + // If offsets have already been created, we are trying to resume a query. + if (!recoverFromCheckpointLocation) { + val checkpointPath = new Path(checkpointLocation, "offsets") + val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.newHadoopConf()) + if (fs.exists(checkpointPath)) { + throw new AnalysisException( + s"This query does not support recovering from checkpoint location. " + + s"Delete $checkpointPath to start over.") + } + } + + val analyzedPlan = df.queryExecution.analyzed + df.queryExecution.assertAnalyzed() + + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) + } + + new StreamingQueryWrapper(new StreamExecution( + sparkSession, + userSpecifiedName.orNull, + checkpointLocation, + analyzedPlan, + sink, + trigger, + triggerClock, + outputMode)) + } + /** * Start a [[StreamingQuery]]. + * * @param userSpecifiedName Query name optionally specified by the user. * @param userSpecifiedCheckpointLocation Checkpoint location optionally specified by the user. * @param df Streaming DataFrame. 
@@ -206,72 +266,50 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { recoverFromCheckpointLocation: Boolean = true, trigger: Trigger = ProcessingTime(0), triggerClock: Clock = new SystemClock()): StreamingQuery = { - activeQueriesLock.synchronized { - val name = userSpecifiedName.getOrElse(s"query-${StreamingQueryManager.nextId}") - if (activeQueries.values.exists(_.name == name)) { - throw new IllegalArgumentException( - s"Cannot start query with name $name as a query with that name is already active") - } - val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified => - new Path(userSpecified).toUri.toString - }.orElse { - df.sparkSession.sessionState.conf.checkpointLocation.map { location => - new Path(location, name).toUri.toString - } - }.getOrElse { - if (useTempCheckpointLocation) { - Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath - } else { - throw new AnalysisException( - "checkpointLocation must be specified either " + - """through option("checkpointLocation", ...) or """ + - s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""") - } - } + val query = createQuery( + userSpecifiedName, + userSpecifiedCheckpointLocation, + df, + sink, + outputMode, + useTempCheckpointLocation, + recoverFromCheckpointLocation, + trigger, + triggerClock) - // If offsets have already been created, we trying to resume a query. - if (!recoverFromCheckpointLocation) { - val checkpointPath = new Path(checkpointLocation, "offsets") - val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.newHadoopConf()) - if (fs.exists(checkpointPath)) { - throw new AnalysisException( - s"This query does not support recovering from checkpoint location. 
" + - s"Delete $checkpointPath to start over.") + activeQueriesLock.synchronized { + // Make sure no other query with same name is active + userSpecifiedName.foreach { name => + if (activeQueries.values.exists(_.name == name)) { + throw new IllegalArgumentException( + s"Cannot start query with name $name as a query with that name is already active") } } - val analyzedPlan = df.queryExecution.analyzed - df.queryExecution.assertAnalyzed() - - if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { - UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) + // Make sure no other query with same id is active + if (activeQueries.values.exists(_.id == query.id)) { + throw new IllegalStateException( + s"Cannot start query with id ${query.id} as another query with same id is " + + s"already active. Perhaps you are attempting to restart a query from checkpoint " + + s"that is already active.") } - var nextSourceId = 0L - - val logicalPlan = analyzedPlan.transform { - case StreamingRelation(dataSource, _, output) => - // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointLocation/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. - StreamingExecutionRelation(source, output) - } - val query = new StreamExecution( - sparkSession, - name, - checkpointLocation, - logicalPlan, - sink, - trigger, - triggerClock, - outputMode) - query.start() activeQueries.put(query.id, query) - query } + try { + // When starting a query, it will call `StreamingQueryListener.onQueryStarted` synchronously. + // As it's provided by the user and can run arbitrary codes, we must not hold any lock here. + // Otherwise, it's easy to cause dead-lock, or block too long if the user codes take a long + // time to finish. 
+ query.streamingQuery.start() + } catch { + case e: Throwable => + activeQueriesLock.synchronized { + activeQueries -= query.id + } + throw e + } + query } /** Notify (by the StreamingQuery) that the query has been terminated */ @@ -287,8 +325,3 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) { } } } - -private object StreamingQueryManager { - private val _nextId = new AtomicLong(0) - private def nextId: Long = _nextId.getAndIncrement() -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala index 44befa0d2f..c2befa6343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -22,7 +22,10 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.Experimental + /** + * :: Experimental :: * Reports information about the instantaneous status of a streaming query. * * @param message A human readable description of what the stream is currently doing. @@ -32,10 +35,11 @@ import org.json4s.jackson.JsonMethods._ * * @since 2.1.0 */ +@Experimental class StreamingQueryStatus protected[sql]( val message: String, val isDataAvailable: Boolean, - val isTriggerActive: Boolean) { + val isTriggerActive: Boolean) extends Serializable { /** The compact JSON representation of this status. 
*/ def json: String = compact(render(jsonValue)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala index 4c8247458f..fde61c52ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.{util => ju} +import java.lang.{Long => JLong} import java.util.UUID import scala.collection.JavaConverters._ @@ -37,7 +38,14 @@ import org.apache.spark.annotation.Experimental @Experimental class StateOperatorProgress private[sql]( val numRowsTotal: Long, - val numRowsUpdated: Long) { + val numRowsUpdated: Long) extends Serializable { + + /** The compact JSON representation of this progress. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this progress. */ + def prettyJson: String = pretty(render(jsonValue)) + private[sql] def jsonValue: JValue = { ("numRowsTotal" -> JInt(numRowsTotal)) ~ ("numRowsUpdated" -> JInt(numRowsUpdated)) @@ -50,15 +58,23 @@ class StateOperatorProgress private[sql]( * a trigger. Each event relates to processing done for a single trigger of the streaming * query. Events are emitted even when no new data is available to be processed. * - * @param id A unique id of the query. - * @param name Name of the query. This name is unique across all active queries. - * @param timestamp Timestamp (ms) of the beginning of the trigger. + * @param id An unique query id that persists across restarts. See `StreamingQuery.id()`. + * @param runId A query id that is unique for every start/restart. See `StreamingQuery.runId()`. + * @param name User-specified name of the query, null if not specified. + * @param timestamp Beginning time of the trigger in ISO8601 format, i.e. UTC timestamps. 
* @param batchId A unique id for the current batch of data being processed. Note that in the * case of retries after a failure a given batchId my be executed more than once. * Similarly, when there is no data to be processed, the batchId will not be * incremented. * @param durationMs The amount of time taken to perform various operations in milliseconds. - * @param currentWatermark The current event time watermark in milliseconds + * @param eventTime Statistics of event time seen in this batch. It may contain the following keys: + * {{{ + * "max" -> "2016-12-05T20:54:20.827Z" // maximum event time seen in this trigger + * "min" -> "2016-12-05T20:54:20.827Z" // minimum event time seen in this trigger + * "avg" -> "2016-12-05T20:54:20.827Z" // average event time seen in this trigger + * "watermark" -> "2016-12-05T20:54:20.827Z" // watermark used in this trigger + * }}} + * All timestamps are in ISO8601 format, i.e. UTC timestamps. * @param stateOperators Information about operators in the query that store state. * @param sources detailed statistics on data being read from each of the streaming sources. * @since 2.1.0 @@ -66,14 +82,15 @@ class StateOperatorProgress private[sql]( @Experimental class StreamingQueryProgress private[sql]( val id: UUID, + val runId: UUID, val name: String, - val timestamp: Long, + val timestamp: String, val batchId: Long, - val durationMs: ju.Map[String, java.lang.Long], - val currentWatermark: Long, + val durationMs: ju.Map[String, JLong], + val eventTime: ju.Map[String, String], val stateOperators: Array[StateOperatorProgress], val sources: Array[SourceProgress], - val sink: SinkProgress) { + val sink: SinkProgress) extends Serializable { /** The aggregate (across all sources) number of records processed in a trigger. 
*/ def numInputRows: Long = sources.map(_.numInputRows).sum @@ -97,21 +114,25 @@ class StreamingQueryProgress private[sql]( if (value.isNaN || value.isInfinity) JNothing else JDouble(value) } + /** Convert map to JValue while handling empty maps. Also, this sorts the keys. */ + def safeMapToJValue[T](map: ju.Map[String, T], valueToJValue: T => JValue): JValue = { + if (map.isEmpty) return JNothing + val keys = map.asScala.keySet.toSeq.sorted + keys.map { k => k -> valueToJValue(map.get(k)) : JObject }.reduce(_ ~ _) + } + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ ("name" -> JString(name)) ~ - ("timestamp" -> JInt(timestamp)) ~ + ("timestamp" -> JString(timestamp)) ~ ("numInputRows" -> JInt(numInputRows)) ~ ("inputRowsPerSecond" -> safeDoubleToJValue(inputRowsPerSecond)) ~ ("processedRowsPerSecond" -> safeDoubleToJValue(processedRowsPerSecond)) ~ - ("durationMs" -> durationMs - .asScala - .map { case (k, v) => k -> JInt(v.toLong): JObject } - .reduce(_ ~ _)) ~ - ("currentWatermark" -> JInt(currentWatermark)) ~ + ("durationMs" -> safeMapToJValue[JLong](durationMs, v => JInt(v.toLong))) ~ + ("eventTime" -> safeMapToJValue[String](eventTime, s => JString(s))) ~ ("stateOperators" -> JArray(stateOperators.map(_.jsonValue).toList)) ~ ("sources" -> JArray(sources.map(_.jsonValue).toList)) ~ ("sink" -> sink.jsonValue) - } } @@ -136,7 +157,7 @@ class SourceProgress protected[sql]( val endOffset: String, val numInputRows: Long, val inputRowsPerSecond: Double, - val processedRowsPerSecond: Double) { + val processedRowsPerSecond: Double) extends Serializable { /** The compact JSON representation of this progress. */ def json: String = compact(render(jsonValue)) @@ -176,7 +197,7 @@ class SourceProgress protected[sql]( */ @Experimental class SinkProgress protected[sql]( - val description: String) { + val description: String) extends Serializable { /** The compact JSON representation of this progress. 
*/ def json: String = compact(render(jsonValue)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/change-column.sql b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql new file mode 100644 index 0000000000..818b19c50f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/change-column.sql @@ -0,0 +1,55 @@ +-- Create the origin table +CREATE TABLE test_change(a INT, b STRING, c INT); +DESC test_change; + +-- Change column name (not supported yet) +ALTER TABLE test_change CHANGE a a1 INT; +DESC test_change; + +-- Change column dataType (not supported yet) +ALTER TABLE test_change CHANGE a a STRING; +DESC test_change; + +-- Change column position (not supported yet) +ALTER TABLE test_change CHANGE a a INT AFTER b; +ALTER TABLE test_change CHANGE b b STRING FIRST; +DESC test_change; + +-- Change column comment +ALTER TABLE test_change CHANGE a a INT COMMENT 'this is column a'; +ALTER TABLE test_change CHANGE b b STRING COMMENT '#*02?`'; +ALTER TABLE test_change CHANGE c c INT COMMENT ''; +DESC test_change; + +-- Don't change anything. 
+ALTER TABLE test_change CHANGE a a INT COMMENT 'this is column a'; +DESC test_change; + +-- Change an invalid column +ALTER TABLE test_change CHANGE invalid_col invalid_col INT; +DESC test_change; + +-- Change column name/dataType/position/comment together (not supported yet) +ALTER TABLE test_change CHANGE a a1 STRING COMMENT 'this is column a1' AFTER b; +DESC test_change; + +-- Check the behavior with different values of CASE_SENSITIVE +SET spark.sql.caseSensitive=false; +ALTER TABLE test_change CHANGE a A INT COMMENT 'this is column A'; +SET spark.sql.caseSensitive=true; +ALTER TABLE test_change CHANGE a A INT COMMENT 'this is column A1'; +DESC test_change; + +-- Change column can't apply to a temporary/global_temporary view +CREATE TEMPORARY VIEW temp_view(a, b) AS SELECT 1, "one"; +ALTER TABLE temp_view CHANGE a a INT COMMENT 'this is column a'; +CREATE GLOBAL TEMPORARY VIEW global_temp_view(a, b) AS SELECT 1, "one"; +ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column a'; + +-- Change column in partition spec (not supported yet) +CREATE TABLE partition_table(a INT, b STRING) PARTITIONED BY (c INT, d STRING); +ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT; + +-- DROP TEST TABLE +DROP TABLE test_change; +DROP TABLE partition_table; diff --git a/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql new file mode 100644 index 0000000000..3acc9db09c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/scalar-subquery.sql @@ -0,0 +1,20 @@ +CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv); +CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv); + +-- SPARK-18814.1: Simplified version of TPCDS-Q32 +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT avg(c1.cv) + FROM c c1 + WHERE c1.ck = p.pk); + +-- SPARK-18814.2: Adding stack of aggregates +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck 
+AND c.cv = (SELECT max(avg) + FROM (SELECT c1.cv, avg(c1.cv) avg + FROM c c1 + WHERE c1.ck = p.pk + GROUP BY c1.cv)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql index a16c39819a..18d02e150e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/show-tables.sql @@ -16,11 +16,11 @@ SHOW TABLES 'show_t*'; SHOW TABLES LIKE 'show_t1*|show_t2*'; SHOW TABLES IN showdb 'show_t*'; --- SHOW TABLES EXTENDED +-- SHOW TABLE EXTENDED -- Ignore these because there exist timestamp results, e.g. `Created`. --- SHOW TABLES EXTENDED LIKE 'show_t*'; -SHOW TABLES EXTENDED; -SHOW TABLES EXTENDED LIKE 'show_t1' PARTITION(c='Us'); +-- SHOW TABLE EXTENDED LIKE 'show_t*'; +SHOW TABLE EXTENDED; +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us'); -- Clean Up DROP TABLE show_t1; diff --git a/sql/core/src/test/resources/sql-tests/results/change-column.sql.out b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out new file mode 100644 index 0000000000..156ddb86ad --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/change-column.sql.out @@ -0,0 +1,306 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 32 + + +-- !query 0 +CREATE TABLE test_change(a INT, b STRING, c INT) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +DESC test_change +-- !query 1 schema +struct +-- !query 1 output +a int +b string +c int + + +-- !query 2 +ALTER TABLE test_change CHANGE a a1 INT +-- !query 2 schema +struct<> +-- !query 2 output +org.apache.spark.sql.AnalysisException +ALTER TABLE CHANGE COLUMN is not supported for changing column 'a' with type 'IntegerType' to 'a1' with type 'IntegerType'; + + +-- !query 3 +DESC test_change +-- !query 3 schema +struct +-- !query 3 output +a int +b string +c int + + +-- !query 4 +ALTER TABLE test_change CHANGE a a STRING +-- !query 4 schema 
+struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +ALTER TABLE CHANGE COLUMN is not supported for changing column 'a' with type 'IntegerType' to 'a' with type 'StringType'; + + +-- !query 5 +DESC test_change +-- !query 5 schema +struct +-- !query 5 output +a int +b string +c int + + +-- !query 6 +ALTER TABLE test_change CHANGE a a INT AFTER b +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.catalyst.parser.ParseException + +Operation not allowed: ALTER TABLE table [PARTITION partition_spec] CHANGE COLUMN ... FIRST | AFTER otherCol(line 1, pos 0) + +== SQL == +ALTER TABLE test_change CHANGE a a INT AFTER b +^^^ + + +-- !query 7 +ALTER TABLE test_change CHANGE b b STRING FIRST +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.catalyst.parser.ParseException + +Operation not allowed: ALTER TABLE table [PARTITION partition_spec] CHANGE COLUMN ... FIRST | AFTER otherCol(line 1, pos 0) + +== SQL == +ALTER TABLE test_change CHANGE b b STRING FIRST +^^^ + + +-- !query 8 +DESC test_change +-- !query 8 schema +struct +-- !query 8 output +a int +b string +c int + + +-- !query 9 +ALTER TABLE test_change CHANGE a a INT COMMENT 'this is column a' +-- !query 9 schema +struct<> +-- !query 9 output + + + +-- !query 10 +ALTER TABLE test_change CHANGE b b STRING COMMENT '#*02?`' +-- !query 10 schema +struct<> +-- !query 10 output + + + +-- !query 11 +ALTER TABLE test_change CHANGE c c INT COMMENT '' +-- !query 11 schema +struct<> +-- !query 11 output + + + +-- !query 12 +DESC test_change +-- !query 12 schema +struct +-- !query 12 output +a int this is column a +b string #*02?` +c int + + +-- !query 13 +ALTER TABLE test_change CHANGE a a INT COMMENT 'this is column a' +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +DESC test_change +-- !query 14 schema +struct +-- !query 14 output +a int this is column a +b string #*02?` +c int + + +-- !query 15 +ALTER TABLE test_change CHANGE invalid_col 
invalid_col INT +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Invalid column reference 'invalid_col', table schema is 'StructType(StructField(a,IntegerType,true), StructField(b,StringType,true), StructField(c,IntegerType,true))'; + + +-- !query 16 +DESC test_change +-- !query 16 schema +struct +-- !query 16 output +a int this is column a +b string #*02?` +c int + + +-- !query 17 +ALTER TABLE test_change CHANGE a a1 STRING COMMENT 'this is column a1' AFTER b +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.catalyst.parser.ParseException + +Operation not allowed: ALTER TABLE table [PARTITION partition_spec] CHANGE COLUMN ... FIRST | AFTER otherCol(line 1, pos 0) + +== SQL == +ALTER TABLE test_change CHANGE a a1 STRING COMMENT 'this is column a1' AFTER b +^^^ + + +-- !query 18 +DESC test_change +-- !query 18 schema +struct +-- !query 18 output +a int this is column a +b string #*02?` +c int + + +-- !query 19 +SET spark.sql.caseSensitive=false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.caseSensitive + + +-- !query 20 +ALTER TABLE test_change CHANGE a A INT COMMENT 'this is column A' +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +SET spark.sql.caseSensitive=true +-- !query 21 schema +struct +-- !query 21 output +spark.sql.caseSensitive + + +-- !query 22 +ALTER TABLE test_change CHANGE a A INT COMMENT 'this is column A1' +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +ALTER TABLE CHANGE COLUMN is not supported for changing column 'a' with type 'IntegerType' to 'A' with type 'IntegerType'; + + +-- !query 23 +DESC test_change +-- !query 23 schema +struct +-- !query 23 output +a int this is column A +b string #*02?` +c int + + +-- !query 24 +CREATE TEMPORARY VIEW temp_view(a, b) AS SELECT 1, "one" +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +ALTER TABLE temp_view CHANGE a a INT COMMENT 
'this is column a' +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.catalyst.analysis.NoSuchTableException +Table or view 'temp_view' not found in database 'default'; + + +-- !query 26 +CREATE GLOBAL TEMPORARY VIEW global_temp_view(a, b) AS SELECT 1, "one" +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +ALTER TABLE global_temp.global_temp_view CHANGE a a INT COMMENT 'this is column a' +-- !query 27 schema +struct<> +-- !query 27 output +org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException +Database 'global_temp' not found; + + +-- !query 28 +CREATE TABLE partition_table(a INT, b STRING) PARTITIONED BY (c INT, d STRING) +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.catalyst.parser.ParseException + +Operation not allowed: ALTER TABLE table PARTITION partition_spec CHANGE COLUMN(line 1, pos 0) + +== SQL == +ALTER TABLE partition_table PARTITION (c = 1) CHANGE COLUMN a new_a INT +^^^ + + +-- !query 30 +DROP TABLE test_change +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +DROP TABLE partition_table +-- !query 31 schema +struct<> +-- !query 31 output + diff --git a/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out new file mode 100644 index 0000000000..c249329d6a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/scalar-subquery.sql.out @@ -0,0 +1,46 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW p AS VALUES (1, 1) AS T(pk, pv) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW c AS VALUES (1, 1) AS T(ck, cv) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 
+SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT avg(c1.cv) + FROM c c1 + WHERE c1.ck = p.pk) +-- !query 2 schema +struct +-- !query 2 output +1 1 + + +-- !query 3 +SELECT pk, cv +FROM p, c +WHERE p.pk = c.ck +AND c.cv = (SELECT max(avg) + FROM (SELECT c1.cv, avg(c1.cv) avg + FROM c c1 + WHERE c1.ck = p.pk + GROUP BY c1.cv)) +-- !query 3 schema +struct +-- !query 3 output +1 1 diff --git a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out index a4f411258d..904601bf11 100644 --- a/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/show-tables.sql.out @@ -114,28 +114,30 @@ show_t3 -- !query 12 -SHOW TABLES EXTENDED +SHOW TABLE EXTENDED -- !query 12 schema struct<> -- !query 12 output org.apache.spark.sql.catalyst.parser.ParseException -SHOW TABLES EXTENDED must have identifier_with_wildcards specified. +mismatched input '' expecting 'LIKE'(line 1, pos 19) + == SQL == -SHOW TABLES EXTENDED +SHOW TABLE EXTENDED +-------------------^^^ -- !query 13 -SHOW TABLES EXTENDED LIKE 'show_t1' PARTITION(c='Us') +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') -- !query 13 schema struct<> -- !query 13 output org.apache.spark.sql.catalyst.parser.ParseException -Operation not allowed: SHOW TABLES [EXTENDED] ... PARTITION(line 1, pos 0) +Operation not allowed: SHOW TABLE EXTENDED ... 
PARTITION(line 1, pos 0) == SQL == -SHOW TABLES EXTENDED LIKE 'show_t1' PARTITION(c='Us') +SHOW TABLE EXTENDED LIKE 'show_t1' PARTITION(c='Us') ^^^ diff --git a/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/7.compact b/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/7.compact new file mode 100644 index 0000000000..e1ec8a74f0 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/7.compact @@ -0,0 +1,9 @@ +v1 +{"path":"/a/b/0","size":1,"isDir":false,"modificationTime":1,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/1","size":100,"isDir":false,"modificationTime":100,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/2","size":200,"isDir":false,"modificationTime":200,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/3","size":300,"isDir":false,"modificationTime":300,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/4","size":400,"isDir":false,"modificationTime":400,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/5","size":500,"isDir":false,"modificationTime":500,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/6","size":600,"isDir":false,"modificationTime":600,"blockReplication":1,"blockSize":100,"action":"add"} +{"path":"/a/b/7","size":700,"isDir":false,"modificationTime":700,"blockReplication":1,"blockSize":100,"action":"add"} diff --git a/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/8 b/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/8 new file mode 100644 index 0000000000..e7989804e8 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/8 @@ -0,0 +1,3 @@ +v1 +{"path":"/a/b/8","size":800,"isDir":false,"modificationTime":800,"blockReplication":1,"blockSize":100,"action":"add"} 
+{"path":"/a/b/0","size":100,"isDir":false,"modificationTime":100,"blockReplication":1,"blockSize":100,"action":"delete"} diff --git a/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/9 b/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/9 new file mode 100644 index 0000000000..42fb0ee416 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-sink-log-version-2.1.0/9 @@ -0,0 +1,2 @@ +v1 +{"path":"/a/b/9","size":900,"isDir":false,"modificationTime":900,"blockReplication":3,"blockSize":200,"action":"add"} diff --git a/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/2.compact b/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/2.compact new file mode 100644 index 0000000000..95f78bb262 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/2.compact @@ -0,0 +1,4 @@ +v1 +{"path":"/a/b/0","timestamp":1480730949000,"batchId":0} +{"path":"/a/b/1","timestamp":1480730950000,"batchId":1} +{"path":"/a/b/2","timestamp":1480730950000,"batchId":2} diff --git a/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/3 b/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/3 new file mode 100644 index 0000000000..2caa5972e4 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/3 @@ -0,0 +1,2 @@ +v1 +{"path":"/a/b/3","timestamp":1480730950000,"batchId":3} diff --git a/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/4 b/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/4 new file mode 100644 index 0000000000..e54b943229 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-source-log-version-2.1.0/4 @@ -0,0 +1,2 @@ +v1 +{"path":"/a/b/4","timestamp":1480730951000,"batchId":4} diff --git 
a/sql/core/src/test/resources/structured-streaming/file-source-offset-version-2.1.0-json.txt b/sql/core/src/test/resources/structured-streaming/file-source-offset-version-2.1.0-json.txt new file mode 100644 index 0000000000..e266a47368 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-source-offset-version-2.1.0-json.txt @@ -0,0 +1 @@ +{"logOffset":345} diff --git a/sql/core/src/test/resources/structured-streaming/file-source-offset-version-2.1.0-long.txt b/sql/core/src/test/resources/structured-streaming/file-source-offset-version-2.1.0-long.txt new file mode 100644 index 0000000000..51b4008129 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/file-source-offset-version-2.1.0-long.txt @@ -0,0 +1 @@ +345 diff --git a/sql/core/src/test/resources/structured-streaming/offset-log-version-2.1.0/0 b/sql/core/src/test/resources/structured-streaming/offset-log-version-2.1.0/0 new file mode 100644 index 0000000000..988a98a758 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/offset-log-version-2.1.0/0 @@ -0,0 +1,4 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1480981499528} +{"logOffset":345} +{"topic-0":{"0":1}} diff --git a/sql/core/src/test/resources/structured-streaming/query-metadata-logs-version-2.1.0.txt b/sql/core/src/test/resources/structured-streaming/query-metadata-logs-version-2.1.0.txt new file mode 100644 index 0000000000..79613e2362 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/query-metadata-logs-version-2.1.0.txt @@ -0,0 +1,3 @@ +{ + "id": "d366a8bf-db79-42ca-b5a4-d9ca0a11d63e" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 26e1a9f75d..b0339a88fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -533,31 +533,54 @@ class 
ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } - test("input_file_name - FileScanRDD") { + test("input_file_name, input_file_block_start, input_file_block_length - FileScanRDD") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name()) - .head.getString(0) - assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(input_file_name()).limit(1), Row("")) + // Test the 3 expressions when reading from files + val q = spark.read.parquet(dir.getCanonicalPath).select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.getCanonicalPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. 
+ checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) } } - test("input_file_name - HadoopRDD") { + test("input_file_name, input_file_block_start, input_file_block_length - HadoopRDD") { withTempPath { dir => val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF() data.write.text(dir.getCanonicalPath) val df = spark.sparkContext.textFile(dir.getCanonicalPath).toDF() - val answer = df.select(input_file_name()).head.getString(0) - assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(input_file_name()).limit(1), Row("")) + // Test the 3 expressions when reading from files + val q = df.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.getCanonicalPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. 
+ checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) } } - test("input_file_name - NewHadoopRDD") { + test("input_file_name, input_file_block_start, input_file_block_length - NewHadoopRDD") { withTempPath { dir => val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF() data.write.text(dir.getCanonicalPath) @@ -567,10 +590,22 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { classOf[LongWritable], classOf[Text]) val df = rdd.map(pair => pair._2.toString).toDF() - val answer = df.select(input_file_name()).head.getString(0) - assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(input_file_name()).limit(1), Row("")) + // Test the 3 expressions when reading from files + val q = df.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()")) + val firstRow = q.head() + assert(firstRow.getString(0).contains(dir.getCanonicalPath)) + assert(firstRow.getLong(1) == 0) + assert(firstRow.getLong(2) > 0) + + // Now read directly from the original RDD without going through any files to make sure + // we are returning empty string, -1, and -1. 
+ checkAnswer( + data.select( + input_file_name(), expr("input_file_block_start()"), expr("input_file_block_length()") + ).limit(1), + Row("", -1L, -1L)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7aa4f0026f..645175900f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -513,4 +513,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) } + + test("SPARK-18004 limit + aggregates") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 47b55e2547..fd829846ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", "col1" :: Nil), Row("test", null)) + + checkAnswer( + Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 9123146560113991650L)) + .toDF("a", "b").na.fill(0), + Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 9123146560113991650L) :: Nil + ) + + checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(2.34), + Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil + ) + + 
checkAnswer( + Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 3.45)) + .toDF("a", "b").na.fill(5), + Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil + ) } test("fill with map") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 1bbe1354d5..a8d854ccbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -208,4 +208,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ ) } + test("pivot with column definition in groupby") { + checkAnswer( + courseSales.groupBy(substring(col("course"), 0, 1).as("foo")) + .pivot("year", Seq(2012, 2013)) + .sum("earnings"), + Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 36b2651e5a..0e7eaa9e88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -92,13 +92,13 @@ object NameAgg extends Aggregator[AggData, String, String] { } -object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] { +object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] { def zero: Seq[Int] = Nil def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2 - def finish(r: Seq[Int]): Seq[Int] = r + def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i) override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder() - override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder() } @@ -281,7 +281,7 @@ class DatasetAggregatorSuite 
extends QueryTest with SharedSQLContext { checkDataset( ds.groupByKey(_.b).agg(SeqAgg.toColumn), - "a" -> Seq(1, 2) + "a" -> Seq(1 -> 1, 2 -> 2) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d31c766cb7..c27b815dfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -1110,8 +1109,45 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage.contains("Cannot create encoder for Option of Product type")) } + + test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") { + // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt + // instead of Int for avoiding possible overflow. 
+ val ds = (0 to 10000).map( i => + (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() + val sizeInBytes = ds.logicalPlan.statistics.sizeInBytes + // sizeInBytes is 2404280404, before the fix, it overflows to a negative number + assert(sizeInBytes > 0) + } + + test("SPARK-18717: code generation works for both scala.collection.Map" + + " and scala.collection.imutable.Map") { + val ds = Seq(WithImmutableMap("hi", Map(42L -> "foo"))).toDS + checkDataset(ds.map(t => t), WithImmutableMap("hi", Map(42L -> "foo"))) + + val ds2 = Seq(WithMap("hi", Map(42L -> "foo"))).toDS + checkDataset(ds2.map(t => t), WithMap("hi", Map(42L -> "foo"))) + } + + test("SPARK-18746: add implicit encoder for BigDecimal, date, timestamp") { + // For this implicit encoder, 18 is the default scale + assert(spark.range(1).map { x => new java.math.BigDecimal(1) }.head == + new java.math.BigDecimal(1).setScale(18)) + + assert(spark.range(1).map { x => scala.math.BigDecimal(1, 18) }.head == + scala.math.BigDecimal(1, 18)) + + assert(spark.range(1).map { x => new java.sql.Date(2016, 12, 12) }.head == + new java.sql.Date(2016, 12, 12)) + + assert(spark.range(1).map { x => new java.sql.Timestamp(100000) }.head == + new java.sql.Timestamp(100000)) + } } +case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) +case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) + case class Generic[T](id: T, value: Double) case class OtherTuple(_1: String, _2: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d2ec3cfc05..e89599be2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1546,7 +1546,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("specifying database name for a temporary table is not allowed") 
{ withTempPath { dir => - val path = dir.getCanonicalPath + val path = dir.toURI.toString val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") df diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 1fcccd0610..c663b31351 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -21,10 +21,12 @@ import java.{lang => jl} import java.sql.{Date, Timestamp} import scala.collection.mutable +import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.test.SQLTestData.ArrayData import org.apache.spark.sql.types._ @@ -133,6 +135,40 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared } } + test("column stats round trip serialization") { + // Make sure we serialize and then deserialize and we will get the result data + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + stats.zip(df.schema).foreach { case ((k, v), field) => + withClue(s"column $k with type ${field.dataType}") { + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + assert(roundtrip == Some(v)) + } + } + } + + test("analyze column command - result verification") { + // (data.head.productArity - 1) because the last column does not support stats collection. 
+ assert(stats.size == data.head.productArity - 1) + val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + checkColStats(df, stats) + } + + test("column stats collection for null columns") { + val dataTypes: Seq[(DataType, Int)] = Seq( + BooleanType, ByteType, ShortType, IntegerType, LongType, + DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, + StringType, BinaryType, DateType, TimestampType + ).zipWithIndex + + val df = sql("select " + dataTypes.map { case (tpe, idx) => + s"cast(null as ${tpe.sql}) as col$idx" + }.mkString(", ")) + + val expectedColStats = dataTypes.map { case (tpe, idx) => + (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong)) + } + checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*)) + } } @@ -180,39 +216,48 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) ) - test("column stats round trip serialization") { - // Make sure we serialize and then deserialize and we will get the result data - val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) - stats.zip(df.schema).foreach { case ((k, v), field) => - withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) - assert(roundtrip == Some(v)) - } - } - } - - test("analyze column command - result verification") { - val tableName = "column_stats_test2" - // (data.head.productArity - 1) because the last column does not support stats collection. - assert(stats.size == data.head.productArity - 1) - val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) + private val randomName = new Random(31) + /** + * Compute column stats for the given DataFrame and compare it with colStats. 
+ */ + def checkColStats( + df: DataFrame, + colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = { + val tableName = "column_stats_test_" + randomName.nextInt(1000) withTable(tableName) { df.write.saveAsTable(tableName) // Collect statistics - sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) + sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + + colStats.keys.mkString(", ")) // Validate statistics val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) assert(table.stats.isDefined) - assert(table.stats.get.colStats.size == stats.size) + assert(table.stats.get.colStats.size == colStats.size) - stats.foreach { case (k, v) => + colStats.foreach { case (k, v) => withClue(s"column $k") { assert(table.stats.get.colStats(k) == v) } } } } + + // This test will be run twice: with and without Hive support + test("SPARK-18856: non-empty partitioned table should not report zero size") { + withTable("ds_tbl", "hive_tbl") { + spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.statistics + assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + + if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { + sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") + sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.statistics + assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 73a5394496..2ef8b18c04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -54,6 +54,24 
@@ class SubquerySuite extends QueryTest with SharedSQLContext { t.createOrReplaceTempView("t") } + test("SPARK-18854 numberedTreeString for subquery") { + val df = sql("select * from range(10) where id not in " + + "(select id from range(2) union all select id from range(2))") + + // The depth first traversal of the plan tree + val dfs = Seq("Project", "Filter", "Union", "Project", "Range", "Project", "Range", "Range") + val numbered = df.queryExecution.analyzed.numberedTreeString.split("\n") + + // There should be 8 plan nodes in total + assert(numbered.size == dfs.size) + + for (i <- dfs.indices) { + val node = df.queryExecution.analyzed(i) + assert(node.nodeName == dfs(i)) + assert(numbered(i).contains(node.nodeName)) + } + } + test("rdd deserialization does not crash [SPARK-15791]") { sql("select (select 1 as b) as b").rdd.count() } @@ -491,7 +509,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by t2.c2) sum from t t1") } assert(errMsg.getMessage.contains( - "a GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) + "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns:")) } } @@ -789,4 +807,22 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } } + + // Generate operator + test("Correlated subqueries in LATERAL VIEW") { + withTempView("t1", "t2") { + Seq((1, 1), (2, 0)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq[(Int, Array[Int])]((1, Array(1, 2)), (2, Array(-1, -3))) + .toDF("c1", "arr_c2").createTempView("t2") + checkAnswer( + sql( + """ + | select c2 + | from t1 + | where exists (select * + | from t2 lateral view explode(arr_c2) q as c2 + where t1.c1 = t2.c1)""".stripMargin), + Row(1) :: Row(0) :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 547d3c1abe..e8ccefa69a 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -64,7 +64,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { data.write.parquet(dir.getCanonicalPath) spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) - assert(answer.contains(dir.getCanonicalPath)) + assert(answer.contains(dir.toURI.getPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) spark.catalog.dropTempView("test_table") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 5ef5f8ee77..1a5e5226c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} // TODO: merge this with DDLSuite (SPARK-14441) @@ -660,6 +660,34 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("alter table: change column name/type/comment") { + val sql1 = "ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT" + val sql2 = "ALTER TABLE table_name CHANGE COLUMN col_name col_name INT COMMENT 'new_comment'" + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val tableIdent = TableIdentifier("table_name", None) + val expected1 = 
AlterTableChangeColumnCommand( + tableIdent, + "col_old_name", + StructField("col_new_name", IntegerType)) + val expected2 = AlterTableChangeColumnCommand( + tableIdent, + "col_name", + StructField("col_name", IntegerType).withComment("new_comment")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("alter table: change column position (not supported)") { + assertUnsupported("ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT FIRST") + assertUnsupported( + "ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT AFTER other_col") + } + + test("alter table: change column in partition spec") { + assertUnsupported("ALTER TABLE table_name PARTITION (a='1', a='2') CHANGE COLUMN a new_a INT") + } + test("alter table: touch (not supported)") { assertUnsupported("ALTER TABLE table_name TOUCH") assertUnsupported("ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')") @@ -695,19 +723,6 @@ class DDLCommandSuite extends PlanTest { assertUnsupported("ALTER TABLE table_name SKEWED BY (key) ON (1,5,6) STORED AS DIRECTORIES") } - test("alter table: change column name/type/position/comment (not allowed)") { - assertUnsupported("ALTER TABLE table_name CHANGE col_old_name col_new_name INT") - assertUnsupported( - """ - |ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT - |COMMENT 'col_comment' FIRST CASCADE - """.stripMargin) - assertUnsupported(""" - |ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT - |COMMENT 'col_comment' AFTER column_name RESTRICT - """.stripMargin) - } - test("alter table: add/replace columns (not allowed)") { assertUnsupported( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 2a004ba2f1..ac3878e849 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -84,12 +84,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { serde = None, compressed = false, properties = Map()) + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, schema = new StructType() - .add("col1", "int") + .add("col1", "int", nullable = true, metadata = metadata) .add("col2", "string") .add("a", "int") .add("b", "int"), @@ -312,7 +315,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { pathToNonPartitionedTable, userSpecifiedSchema = Option("num int, str string"), userSpecifiedPartitionCols = partitionCols, - expectedSchema = new StructType().add("num", IntegerType).add("str", StringType), + expectedSchema = if (partitionCols.isDefined) { + // we skipped inference, so the partition col is ordered at the end + new StructType().add("str", StringType).add("num", IntegerType) + } else { + // no inferred partitioning, so schema is in original order + new StructType().add("num", IntegerType).add("str", StringType) + }, expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String])) } } @@ -336,7 +345,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int, b string) USING json PARTITIONED BY (c)") } - assert(e.message == "partition column c is not defined in table `tbl`, " + + assert(e.message == "partition column c is not defined in table tbl, " + "defined table columns are: a, b") } @@ -344,7 +353,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val e = intercept[AnalysisException] { sql("CREATE TABLE tbl(a int, b string) USING json CLUSTERED BY (c) INTO 4 BUCKETS") } - assert(e.message == "bucket column c is not defined in table `tbl`, " + + 
assert(e.message == "bucket column c is not defined in table tbl, " + "defined table columns are: a, b") } @@ -565,7 +574,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val table = catalog.getTableMetadata(TableIdentifier("tbl")) assert(table.tableType == CatalogTableType.MANAGED) assert(table.provider == Some("parquet")) - assert(table.schema == new StructType().add("a", IntegerType).add("b", IntegerType)) + // a is ordered last since it is a user-specified partitioning column + assert(table.schema == new StructType().add("b", IntegerType).add("a", IntegerType)) assert(table.partitionColumnNames == Seq("a")) } } @@ -764,6 +774,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testSetSerdePartition(isDatasourceTable = true) } + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -878,7 +896,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testRenamePartitions(isDatasourceTable = true) } - test("show tables") { + test("show table extended") { withTempView("show1a", "show2b") { sql( """ @@ -902,9 +920,9 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |) """.stripMargin) assert( - sql("SHOW TABLES EXTENDED LIKE 'show*'").count() >= 2) + sql("SHOW TABLE EXTENDED LIKE 'show*'").count() >= 2) assert( - sql("SHOW TABLES EXTENDED LIKE 'show*'").schema == + sql("SHOW TABLE EXTENDED LIKE 'show*'").schema == StructType(StructField("database", StringType, false) :: StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: @@ -1361,6 +1379,26 @@ class DDLSuite extends QueryTest with SharedSQLContext with 
BeforeAndAfterEach { Set(Map("a" -> "1", "b" -> "p"), Map("a" -> "20", "b" -> "c"), Map("a" -> "3", "b" -> "p"))) } + private def testChangeColumn(isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val resolver = spark.sessionState.conf.resolver + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + def getMetadata(colName: String): Metadata = { + val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => + resolver(field.name, colName) + } + column.map(_.metadata).getOrElse(Metadata.empty) + } + // Ensure that change column will preserve other metadata fields. + sql("ALTER TABLE dbx.tab1 CHANGE COLUMN col1 col1 INT COMMENT 'this is col1'") + assert(getMetadata("col1").getString("key") == "value") + } + test("drop build-in function") { Seq("true", "false").foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala new file mode 100644 index 0000000000..9d892bbdba --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BucketingUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.SparkFunSuite + +class BucketingUtilsSuite extends SparkFunSuite { + + test("generate bucket id") { + assert(BucketingUtils.bucketIdToString(0) == "_00000") + assert(BucketingUtils.bucketIdToString(10) == "_00010") + assert(BucketingUtils.bucketIdToString(999999) == "_999999") + } + + test("match bucket ids") { + def testCase(filename: String, expected: Option[Int]): Unit = withClue(s"name: $filename") { + assert(BucketingUtils.getBucketId(filename) == expected) + } + + testCase("a_1", Some(1)) + testCase("a_1.txt", Some(1)) + testCase("a_9999999", Some(9999999)) + testCase("a_9999999.txt", Some(9999999)) + testCase("a_1.c2.txt", Some(1)) + testCase("a_1.", Some(1)) + + testCase("a_1:txt", None) + testCase("a_1-c2.txt", None) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index d900ce7bb2..f36162858b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -476,6 +476,17 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("[SPARK-18753] keep pushed-down null literal as a filter in Spark-side post-filter") { + val ds = Seq(Tuple1(Some(true)), Tuple1(None), Tuple1(Some(false))).toDS() + withTempPath { p 
=> + val path = p.getAbsolutePath + ds.write.parquet(path) + val readBack = spark.read.parquet(path).filter($"_1" === "true") + val filtered = ds.filter($"_1" === "true").toDF() + checkAnswer(readBack, filtered) + } + } + // Helpers for checking the arguments passed to the FileFormat. protected val checkPartitionSchema = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 598e44ec8c..a324183b43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -68,7 +68,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser => jsonParser.nextToken() - val converter = parser.makeRootConverter(dataType) + val converter = parser.makeConverter(dataType) converter.apply(jsonParser) } } @@ -845,7 +845,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() - val path = dir.getCanonicalPath + val path = dir.toURI.toString primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 22e35a1bc0..f433a74da8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -969,4 +969,15 @@ class ParquetPartitionDiscoverySuite 
extends QueryTest with ParquetTest with Sha )) } } + + test("SPARK-18108 Parquet reader fails when data column types conflict with partition ones") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = Seq((1L, 2.0)).toDF("a", "b") + df.write.parquet(s"$path/a=1") + checkAnswer(spark.read.parquet(s"$path"), Seq(Row(1, 2.0))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 83db81ea3f..119d6e25df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{LongType, ShortType} +import org.apache.spark.util.Utils /** * Test various broadcast join operators. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala new file mode 100644 index 0000000000..81bea2fef8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} +import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType + +class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder + + override def beforeAll(): Unit = { + super.beforeAll() + spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF) + } + + override def afterAll(): Unit = { + spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF") + super.afterAll() + } + + test("Python UDF: push down deterministic FilterExec predicates") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec( + And(_: AttributeReference, _: AttributeReference), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b + } + assert(qualifiedPlanNodes.size == 2) + } + + test("Nested Python UDF: push down deterministic FilterExec predicates") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("dummyPythonUDF(a, 
dummyPythonUDF(a, b)) and a in (3, 4)") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b + } + assert(qualifiedPlanNodes.size == 2) + } + + test("Python UDF: no push down on non-deterministic") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("b > 4 and dummyPythonUDF(a) and rand() > 3") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec( + And(_: AttributeReference, _: GreaterThan), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b + } + assert(qualifiedPlanNodes.size == 2) + } + + test("Python UDF: no push down on predicates starting from the first non-deterministic") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("dummyPythonUDF(a) and rand() > 3 and b > 4") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f + } + assert(qualifiedPlanNodes.size == 1) + } + + test("Python UDF refers to the attributes from more than one child") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = Seq(("Hello", 4)).toDF("c", "d") + val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)") + + val e = intercept[RuntimeException] { + joinDF.queryExecution.executedPlan + }.getMessage + assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child") + .forall(e.contains)) + } +} + +// This Python UDF is dummy and just for testing. Unable to execute. 
+class DummyUDF extends PythonFunction( + command = Array[Byte](), + envVars = Map("" -> "").asJava, + pythonIncludes = ArrayBuffer("").asJava, + pythonExec = "", + pythonVer = "", + broadcastVars = null, + accumulator = null) + +class MyDummyPythonUDF + extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index e511fda579..435d874d75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -104,6 +104,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext withFakeCompactibleFileStreamLog( fileCleanupDelayMs = Long.MaxValue, defaultCompactInterval = 3, + defaultMinBatchesToRetain = 1, compactibleLog => { assert("0" === compactibleLog.batchIdToPath(0).getName) assert("1" === compactibleLog.batchIdToPath(1).getName) @@ -118,6 +119,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext withFakeCompactibleFileStreamLog( fileCleanupDelayMs = Long.MaxValue, defaultCompactInterval = 3, + defaultMinBatchesToRetain = 1, compactibleLog => { val logs = Array("entry_1", "entry_2", "entry_3") val expected = s"""${FakeCompactibleFileStreamLog.VERSION} @@ -138,6 +140,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext withFakeCompactibleFileStreamLog( fileCleanupDelayMs = Long.MaxValue, defaultCompactInterval = 3, + defaultMinBatchesToRetain = 1, compactibleLog => { val logs = s"""${FakeCompactibleFileStreamLog.VERSION} |"entry_1" @@ -157,6 +160,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext withFakeCompactibleFileStreamLog( 
fileCleanupDelayMs = Long.MaxValue, defaultCompactInterval = 3, + defaultMinBatchesToRetain = 1, compactibleLog => { for (batchId <- 0 to 10) { compactibleLog.add(batchId, Array("some_path_" + batchId)) @@ -175,6 +179,7 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext withFakeCompactibleFileStreamLog( fileCleanupDelayMs = 0, defaultCompactInterval = 3, + defaultMinBatchesToRetain = 1, compactibleLog => { val fs = compactibleLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) @@ -194,25 +199,29 @@ class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext compactibleLog.add(1, Array("some_path_1")) assert(Set("0", "1") === listBatchFiles()) compactibleLog.add(2, Array("some_path_2")) - assert(Set("2.compact") === listBatchFiles()) + assert(Set("0", "1", "2.compact") === listBatchFiles()) compactibleLog.add(3, Array("some_path_3")) assert(Set("2.compact", "3") === listBatchFiles()) compactibleLog.add(4, Array("some_path_4")) assert(Set("2.compact", "3", "4") === listBatchFiles()) compactibleLog.add(5, Array("some_path_5")) - assert(Set("5.compact") === listBatchFiles()) + assert(Set("2.compact", "3", "4", "5.compact") === listBatchFiles()) + compactibleLog.add(6, Array("some_path_6")) + assert(Set("5.compact", "6") === listBatchFiles()) }) } private def withFakeCompactibleFileStreamLog( fileCleanupDelayMs: Long, defaultCompactInterval: Int, + defaultMinBatchesToRetain: Int, f: FakeCompactibleFileStreamLog => Unit ): Unit = { withTempDir { file => val compactibleLog = new FakeCompactibleFileStreamLog( fileCleanupDelayMs, defaultCompactInterval, + defaultMinBatchesToRetain, spark, file.getCanonicalPath) f(compactibleLog) @@ -227,6 +236,7 @@ object FakeCompactibleFileStreamLog { class FakeCompactibleFileStreamLog( _fileCleanupDelayMs: Long, _defaultCompactInterval: Int, + _defaultMinBatchesToRetain: Int, sparkSession: SparkSession, path: String) extends CompactibleFileStreamLog[String]( @@ -241,5 
+251,7 @@ class FakeCompactibleFileStreamLog( override protected def defaultCompactInterval: Int = _defaultCompactInterval + override protected val minBatchesToRetain: Int = _defaultMinBatchesToRetain + override def compactLogs(logs: Seq[String]): Seq[String] = logs } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index e046fee0c0..7e0de5e265 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -151,10 +151,11 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { testWithUninterruptibleThread("delete expired file") { // Set FILE_SINK_LOG_CLEANUP_DELAY to 0 so that we can detect the deleting behaviour - // deterministically + // deterministically and one min batches to retain withSQLConf( SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3", - SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") { + SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0", + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { withFileStreamSinkLog { sinkLog => val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) @@ -174,17 +175,71 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { sinkLog.add(1, Array(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) assert(Set("0", "1") === listBatchFiles()) sinkLog.add(2, Array(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) - assert(Set("2.compact") === listBatchFiles()) + assert(Set("0", "1", "2.compact") === listBatchFiles()) sinkLog.add(3, Array(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact", "3") === listBatchFiles()) sinkLog.add(4, Array(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) 
assert(Set("2.compact", "3", "4") === listBatchFiles()) sinkLog.add(5, Array(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) - assert(Set("5.compact") === listBatchFiles()) + assert(Set("2.compact", "3", "4", "5.compact") === listBatchFiles()) + sinkLog.add(6, Array(newFakeSinkFileStatus("/a/b/6", FileStreamSinkLog.ADD_ACTION))) + assert(Set("5.compact", "6") === listBatchFiles()) + } + } + + withSQLConf( + SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3", + SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0", + SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2") { + withFileStreamSinkLog { sinkLog => + val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) + + def listBatchFiles(): Set[String] = { + fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName => + try { + getBatchIdFromFileName(fileName) + true + } catch { + case _: NumberFormatException => false + } + }.toSet + } + + sinkLog.add(0, Array(newFakeSinkFileStatus("/a/b/0", FileStreamSinkLog.ADD_ACTION))) + assert(Set("0") === listBatchFiles()) + sinkLog.add(1, Array(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) + assert(Set("0", "1") === listBatchFiles()) + sinkLog.add(2, Array(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) + assert(Set("0", "1", "2.compact") === listBatchFiles()) + sinkLog.add(3, Array(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) + assert(Set("0", "1", "2.compact", "3") === listBatchFiles()) + sinkLog.add(4, Array(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) + assert(Set("2.compact", "3", "4") === listBatchFiles()) + sinkLog.add(5, Array(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) + assert(Set("2.compact", "3", "4", "5.compact") === listBatchFiles()) + sinkLog.add(6, Array(newFakeSinkFileStatus("/a/b/6", FileStreamSinkLog.ADD_ACTION))) + assert(Set("2.compact", "3", "4", "5.compact", "6") === listBatchFiles()) + sinkLog.add(7, 
Array(newFakeSinkFileStatus("/a/b/7", FileStreamSinkLog.ADD_ACTION))) + assert(Set("5.compact", "6", "7") === listBatchFiles()) } } } + test("read Spark 2.1.0 log format") { + assert(readFromResource("file-sink-log-version-2.1.0") === Seq( + // SinkFileStatus("/a/b/0", 100, false, 100, 1, 100, FileStreamSinkLog.ADD_ACTION), -> deleted + SinkFileStatus("/a/b/1", 100, false, 100, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/2", 200, false, 200, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/3", 300, false, 300, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/4", 400, false, 400, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/5", 500, false, 500, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/6", 600, false, 600, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/7", 700, false, 700, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/8", 800, false, 800, 1, 100, FileStreamSinkLog.ADD_ACTION), + SinkFileStatus("/a/b/9", 900, false, 900, 3, 200, FileStreamSinkLog.ADD_ACTION) + )) + } + /** * Create a fake SinkFileStatus using path and action. Most of tests don't care about other fields * in SinkFileStatus. 
@@ -206,4 +261,10 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { f(sinkLog) } } + + private def readFromResource(dir: String): Seq[SinkFileStatus] = { + val input = getClass.getResource(s"/structured-streaming/$dir") + val log = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, input.toString) + log.allFiles() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala deleted file mode 100644 index 4a47c04d3f..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming - -import java.io.File -import java.net.URI - -import scala.util.Random - -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.StructType - -class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { - - import FileStreamSource._ - - test("SeenFilesMap") { - val map = new SeenFilesMap(maxAgeMs = 10) - - map.add("a", 5) - assert(map.size == 1) - map.purge() - assert(map.size == 1) - - // Add a new entry and purge should be no-op, since the gap is exactly 10 ms. - map.add("b", 15) - assert(map.size == 2) - map.purge() - assert(map.size == 2) - - // Add a new entry that's more than 10 ms than the first entry. We should be able to purge now. - map.add("c", 16) - assert(map.size == 3) - map.purge() - assert(map.size == 2) - - // Override existing entry shouldn't change the size - map.add("c", 25) - assert(map.size == 2) - - // Not a new file because we have seen c before - assert(!map.isNewFile("c", 20)) - - // Not a new file because timestamp is too old - assert(!map.isNewFile("d", 5)) - - // Finally a new file: never seen and not too old - assert(map.isNewFile("e", 20)) - } - - test("SeenFilesMap should only consider a file old if it is earlier than last purge time") { - val map = new SeenFilesMap(maxAgeMs = 10) - - map.add("a", 20) - assert(map.size == 1) - - // Timestamp 5 should still considered a new file because purge time should be 0 - assert(map.isNewFile("b", 9)) - assert(map.isNewFile("b", 10)) - - // Once purge, purge time should be 10 and then b would be a old file if it is less than 10. 
- map.purge() - assert(!map.isNewFile("b", 9)) - assert(map.isNewFile("b", 10)) - } - - testWithUninterruptibleThread("do not recheck that files exist during getBatch") { - withTempDir { temp => - spark.conf.set( - s"fs.$scheme.impl", - classOf[ExistsThrowsExceptionFileSystem].getName) - // add the metadata entries as a pre-req - val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir - val metadataLog = - new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) - assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) - - val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, - dir.getAbsolutePath, Map.empty) - // this method should throw an exception if `fs.exists` is called during resolveRelation - newSource.getBatch(None, LongOffset(1)) - } - } -} - -/** Fake FileSystem to test whether the method `fs.exists` is called during - * `DataSource.resolveRelation`. - */ -class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } - - override def exists(f: Path): Boolean = { - throw new IllegalArgumentException("Exists shouldn't have been called!") - } - - /** Simply return an empty file for now. 
*/ - override def listStatus(file: Path): Array[FileStatus] = { - val emptyFile = new FileStatus() - emptyFile.setPath(file) - Array(emptyFile) - } -} - -object ExistsThrowsExceptionFileSystem { - val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index ee6261036f..9137d650e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -171,7 +171,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf } } - test("foreach with watermark") { + test("foreach with watermark: complete") { val inputData = MemoryStream[Int] val windowedAggregation = inputData.toDF() @@ -204,6 +204,72 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf query.stop() } } + + test("foreach with watermark: append") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"count".as[Long]) + .map(_.toInt) + .repartition(1) + + val query = windowedAggregation + .writeStream + .outputMode(OutputMode.Append) + .foreach(new TestForeachWriter()) + .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + inputData.addData(25) // Advance watermark to 15 seconds + query.processAllAvailable() + inputData.addData(25) // Evict items less than previous watermark + query.processAllAvailable() + + // There should be 3 batches and only does the last batch contain a value. 
+ val allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 3) + val expectedEvents = Seq( + Seq( + ForeachSinkSuite.Open(partition = 0, version = 0), + ForeachSinkSuite.Close(None) + ), + Seq( + ForeachSinkSuite.Open(partition = 0, version = 1), + ForeachSinkSuite.Close(None) + ), + Seq( + ForeachSinkSuite.Open(partition = 0, version = 2), + ForeachSinkSuite.Process(value = 3), + ForeachSinkSuite.Close(None) + ) + ) + assert(allEvents === expectedEvents) + } finally { + query.stop() + } + } + + test("foreach sink should support metrics") { + val inputData = MemoryStream[Int] + val query = inputData.toDS() + .writeStream + .foreach(new TestForeachWriter()) + .start() + try { + inputData.addData(10, 11, 12) + query.processAllAvailable() + val recentProgress = query.recentProgress.filter(_.numInputRows != 0).headOption + assert(recentProgress.isDefined && recentProgress.get.numInputRows === 3, + s"recentProgress[${query.recentProgress.toList}] doesn't contain correct metrics") + } finally { + query.stop() + } + } } /** A global object to collect events in the executor */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index d03e08d9a5..d556861a48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -88,14 +88,14 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { s"fs.$scheme.impl", classOf[FakeFileSystem].getName) withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://$temp") + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === 
Some("batch0")) assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://$temp") + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.getLatest() === Some(0 -> "batch0")) assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) @@ -209,14 +209,13 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } // Open and delete - val f1 = fm.open(path) + fm.open(path).close() fm.delete(path) assert(!fm.exists(path)) intercept[IOException] { fm.open(path) } - fm.delete(path) // should not throw exception - f1.close() + fm.delete(path) // should not throw exception // Rename val path1 = new Path(s"$dir/file1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala similarity index 90% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 4e9fba9dba..ca724fc5cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -15,15 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.sql.execution.streaming import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -37,7 +36,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, InternalOutputModes.Append) + val sink = new MemorySink(schema, OutputMode.Append) // Before adding data, check output assert(sink.latestBatchId === None) @@ -71,7 +70,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, InternalOutputModes.Update) + val sink = new MemorySink(schema, OutputMode.Update) // Before adding data, check output assert(sink.latestBatchId === None) @@ -105,7 +104,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, InternalOutputModes.Complete) + val sink = new MemorySink(schema, OutputMode.Complete) // Before adding data, check output assert(sink.latestBatchId === None) @@ -138,7 +137,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { } - test("registering as a table in Append output mode") { + test("registering as a table in Append output mode - supported") { val input = MemoryStream[Int] val query = input.toDF().writeStream .format("memory") @@ 
-161,7 +160,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { query.stop() } - test("registering as a table in Complete output mode") { + test("registering as a table in Complete output mode - supported") { val input = MemoryStream[Int] val query = input.toDF() .groupBy("value") @@ -187,9 +186,23 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { query.stop() } + test("registering as a table in Update output mode - not supported") { + val input = MemoryStream[Int] + val df = input.toDF() + .groupBy("value") + .count() + intercept[AnalysisException] { + df.writeStream + .format("memory") + .outputMode("update") + .queryName("memStream") + .start() + } + } + test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, InternalOutputModes.Append) + val sink = new MemorySink(schema, OutputMode.Append) val plan = new MemoryPlan(sink) // Before adding data, check output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala index 3afd11fa46..bb4274a162 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLogSuite.scala @@ -27,10 +27,19 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { /** test string offset type */ case class StringOffset(override val json: String) extends Offset - testWithUninterruptibleThread("serialization - deserialization") { + test("OffsetSeqMetadata - deserialization") { + assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}""")) + assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}""")) + assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}""")) + assert( + OffsetSeqMetadata(1, 2) === 
+ OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + } + + testWithUninterruptibleThread("OffsetSeqLog - serialization - deserialization") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir - val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath) + val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath) val batch0 = OffsetSeq.fill(LongOffset(0), LongOffset(1), LongOffset(2)) val batch1 = OffsetSeq.fill(StringOffset("one"), StringOffset("two"), StringOffset("three")) @@ -60,4 +69,20 @@ class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext { Array(0 -> batch0Serialized, 1 -> batch1Serialized)) } } + + test("read Spark 2.1.0 log format") { + val (batchId, offsetSeq) = readFromResource("offset-log-version-2.1.0") + assert(batchId === 0) + assert(offsetSeq.offsets === Seq( + Some(SerializedOffset("""{"logOffset":345}""")), + Some(SerializedOffset("""{"topic-0":{"0":1}}""")) + )) + assert(offsetSeq.metadata === Some(OffsetSeqMetadata(0L, 1480981499528L))) + } + + private def readFromResource(dir: String): (Long, OffsetSeq) = { + val input = getClass.getResource(s"/structured-streaming/$dir") + val log = new OffsetSeqLog(spark, input.toString) + log.getLatest().get + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetadataSuite.scala new file mode 100644 index 0000000000..87f8004ab9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetadataSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.File +import java.util.UUID + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.streaming.StreamTest + +class StreamMetadataSuite extends StreamTest { + + test("writing and reading") { + withTempDir { dir => + val id = UUID.randomUUID.toString + val metadata = StreamMetadata(id) + val file = new Path(new File(dir, "test").toString) + StreamMetadata.write(metadata, file, hadoopConf) + val readMetadata = StreamMetadata.read(file, hadoopConf) + assert(readMetadata.nonEmpty) + assert(readMetadata.get.id === id) + } + } + + test("read Spark 2.1.0 format") { + // query-metadata-logs-version-2.1.0.txt has the execution metadata generated by Spark 2.1.0 + assert( + readForResource("query-metadata-logs-version-2.1.0.txt") === + StreamMetadata("d366a8bf-db79-42ca-b5a4-d9ca0a11d63e")) + } + + private def readForResource(fileName: String): StreamMetadata = { + val input = getClass.getResource(s"/structured-streaming/$fileName") + StreamMetadata.read(new Path(input.toString), hadoopConf).get + } + + private val hadoopConf = new Configuration() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 
05fc7345a7..6b38b6a097 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -376,7 +376,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString val storeId = StateStoreId(dir, opId, 0) - val storeConf = StateStoreConf.empty + val sqlConf = new SQLConf() + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) + val storeConf = StateStoreConf(sqlConf) val hadoopConf = new Configuration() val provider = new HDFSBackedStateStoreProvider( storeId, keySchema, valueSchema, storeConf, hadoopConf) @@ -393,6 +395,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } } + val timeoutDuration = 60 seconds + quietly { withSpark(new SparkContext(conf)) { sc => withCoordinatorRef(sc) { coordinatorRef => @@ -401,7 +405,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Generate sufficient versions of store for snapshots generateStoreVersions() - eventually(timeout(10 seconds)) { + eventually(timeout(timeoutDuration)) { // Store should have been reported to the coordinator assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") @@ -420,14 +424,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth generateStoreVersions() // Earliest delta file should get cleaned up - eventually(timeout(10 seconds)) { + eventually(timeout(timeoutDuration)) { assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") } // If driver decides to deactivate all instances of the store, then this instance // should be unloaded coordinatorRef.deactivateInstances(dir) - eventually(timeout(10 seconds)) { + eventually(timeout(timeoutDuration)) { assert(!StateStore.isLoaded(storeId)) } @@ 
-437,7 +441,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // If some other executor loads the store, then this instance should be unloaded coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec") - eventually(timeout(10 seconds)) { + eventually(timeout(timeoutDuration)) { assert(!StateStore.isLoaded(storeId)) } @@ -448,7 +452,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Verify if instance is unloaded if SparkContext is stopped - eventually(timeout(10 seconds)) { + eventually(timeout(timeoutDuration)) { require(SparkEnv.get === null) assert(!StateStore.isLoaded(storeId)) assert(!StateStore.isMaintenanceRunning) @@ -458,7 +462,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ - val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toString + val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath val conf = new Configuration() conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName) val provider = newStoreProvider(dir = dir, hadoopConf = conf) @@ -606,6 +610,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth ): HDFSBackedStateStoreProvider = { val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot) + sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2) new HDFSBackedStateStoreProvider( StateStoreId(dir, opId, partition), keySchema, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 4a85b5975e..13284ba649 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ 
-20,7 +20,6 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -38,7 +37,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { |CREATE TEMPORARY TABLE jsonTable (a int, b string) |USING org.apache.spark.sql.json.DefaultSource |OPTIONS ( - | path '${path.toString}' + | path '${path.toURI.toString}' |) """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a2decadbe0..953604e4ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.io.File + import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -61,4 +63,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i")) } } + + test("maxRecordsPerFile setting in non-partitioned write path") { + withTempDir { f => + spark.range(start = 0, end = 4, step = 1, numPartitions = 1) + .write.option("maxRecordsPerFile", 1).mode("overwrite").parquet(f.getAbsolutePath) + assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + + spark.range(start = 0, end = 4, step = 1, numPartitions = 1) + .write.option("maxRecordsPerFile", 2).mode("overwrite").parquet(f.getAbsolutePath) + assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 2) + + spark.range(start = 0, end = 4, step = 1, numPartitions = 1) + .write.option("maxRecordsPerFile", -1).mode("overwrite").parquet(f.getAbsolutePath) + 
assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 1) + } + } + + test("maxRecordsPerFile setting in dynamic partition writes") { + withTempDir { f => + spark.range(start = 0, end = 4, step = 1, numPartitions = 1).selectExpr("id", "id id1") + .write + .partitionBy("id") + .option("maxRecordsPerFile", 1) + .mode("overwrite") + .parquet(f.getAbsolutePath) + assert(recursiveList(f).count(_.getAbsolutePath.endsWith("parquet")) == 4) + } + } + + /** Lists files recursively. */ + private def recursiveList(f: File): Array[File] = { + require(f.isDirectory) + val current = f.listFiles + current ++ current.filter(_.isDirectory).flatMap(recursiveList) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index bef47aacd3..faf9afc49a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -53,8 +55,8 @@ class TestOptionsRelation(val options: Map[String, String])(@transient val sessi // We can't get the relation directly for write path, here we put the path option in schema // metadata, so that we can test it later. 
override def schema: StructType = { - val metadataWithPath = pathOption.map { - path => new MetadataBuilder().putString("path", path).build() + val metadataWithPath = pathOption.map { path => + new MetadataBuilder().putString("path", path).build() } new StructType().add("i", IntegerType, true, metadataWithPath.getOrElse(Metadata.empty)) } @@ -82,15 +84,16 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { test("path option also exist for write path") { withTable("src") { - withTempPath { path => + withTempPath { p => + val path = new Path(p.getAbsolutePath).toString sql( s""" |CREATE TABLE src |USING ${classOf[TestOptionsSource].getCanonicalName} - |OPTIONS (PATH '${path.getAbsolutePath}') + |OPTIONS (PATH '$path') |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == path.getAbsolutePath) + assert(spark.table("src").schema.head.metadata.getString("path") == path) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala similarity index 54% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 12f3c3e5ff..23f51ff11d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.streaming +import java.{util => ju} +import java.text.SimpleDateFormat +import java.util.Date + import org.scalatest.BeforeAndAfter import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} +import 
org.apache.spark.sql.streaming.OutputMode._ -class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { +class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Logging { import testImplicits._ @@ -49,38 +54,70 @@ class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { assert(e.getMessage contains "int") } + test("event time and watermark metrics") { + // No event time metrics when there is no watermarking + val inputData1 = MemoryStream[Int] + val aggWithoutWatermark = inputData1.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - test("watermark metric") { - - val inputData = MemoryStream[Int] + testStream(aggWithoutWatermark, outputMode = Complete)( + AddData(inputData1, 15), + CheckAnswer((15, 1)), + assertEventStats { e => assert(e.isEmpty) }, + AddData(inputData1, 10, 12, 14), + CheckAnswer((10, 3), (15, 1)), + assertEventStats { e => assert(e.isEmpty) } + ) - val windowedAggregation = inputData.toDF() + // All event time metrics where watermarking is set + val inputData2 = MemoryStream[Int] + val aggWithWatermark = inputData2.toDF() .withColumn("eventTime", $"value".cast("timestamp")) .withWatermark("eventTime", "10 seconds") .groupBy(window($"eventTime", "5 seconds") as 'window) .agg(count("*") as 'count) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - testStream(windowedAggregation)( - AddData(inputData, 15), + testStream(aggWithWatermark)( + AddData(inputData2, 15), CheckAnswer(), - AssertOnQuery { query => - query.lastProgress.currentWatermark === 5000 + assertEventStats { e => + assert(e.get("max") === formatTimestamp(15)) + assert(e.get("min") === formatTimestamp(15)) + assert(e.get("avg") === formatTimestamp(15)) + assert(e.get("watermark") === formatTimestamp(0)) }, - AddData(inputData, 15), 
+ AddData(inputData2, 10, 12, 14), CheckAnswer(), - AssertOnQuery { query => - query.lastProgress.currentWatermark === 5000 + assertEventStats { e => + assert(e.get("max") === formatTimestamp(14)) + assert(e.get("min") === formatTimestamp(10)) + assert(e.get("avg") === formatTimestamp(12)) + assert(e.get("watermark") === formatTimestamp(5)) }, - AddData(inputData, 25), + AddData(inputData2, 25), CheckAnswer(), - AssertOnQuery { query => - query.lastProgress.currentWatermark === 15000 + assertEventStats { e => + assert(e.get("max") === formatTimestamp(25)) + assert(e.get("min") === formatTimestamp(25)) + assert(e.get("avg") === formatTimestamp(25)) + assert(e.get("watermark") === formatTimestamp(5)) + }, + AddData(inputData2, 25), + CheckAnswer((10, 3)), + assertEventStats { e => + assert(e.get("max") === formatTimestamp(25)) + assert(e.get("min") === formatTimestamp(25)) + assert(e.get("avg") === formatTimestamp(25)) + assert(e.get("watermark") === formatTimestamp(15)) } ) } - test("append-mode watermark aggregation") { + test("append mode") { val inputData = MemoryStream[Int] val windowedAggregation = inputData.toDF() @@ -92,11 +129,69 @@ class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { testStream(windowedAggregation)( AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckLastBatch(), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(), + assertNumStateRows(3), + AddData(inputData, 25), // Emit items less than watermark and drop their state + CheckLastBatch((10, 5)), + assertNumStateRows(2), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(2) + ) + } + + test("update mode") { + val inputData = MemoryStream[Int] + spark.conf.set("spark.sql.shuffle.partitions", "10") + + val windowedAggregation = inputData.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", 
"5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedAggregation, OutputMode.Update)( + AddData(inputData, 10, 11, 12, 13, 14, 15), + CheckLastBatch((10, 5), (15, 1)), + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch((25, 1)), + assertNumStateRows(3), + AddData(inputData, 10, 25), // Ignore 10 as its less than watermark + CheckLastBatch((25, 2)), + assertNumStateRows(2), + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(2) + ) + } + + test("delay in months and years handled correctly") { + val currentTimeMs = System.currentTimeMillis + val currentTime = new Date(currentTimeMs) + + val input = MemoryStream[Long] + val aggWithWatermark = input.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "2 years 5 months") + .groupBy(window($"eventTime", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + def monthsSinceEpoch(date: Date): Int = { date.getYear * 12 + date.getMonth } + + testStream(aggWithWatermark)( + AddData(input, currentTimeMs / 1000), CheckAnswer(), - AddData(inputData, 25), // Advance watermark to 15 seconds + AddData(input, currentTimeMs / 1000), CheckAnswer(), - AddData(inputData, 25), // Evict items less than previous watermark. 
- CheckAnswer((10, 5)) + assertEventStats { e => + assert(timestampFormat.parse(e.get("max")).getTime === (currentTimeMs / 1000) * 1000) + val watermarkTime = timestampFormat.parse(e.get("watermark")) + assert(monthsSinceEpoch(currentTime) - monthsSinceEpoch(watermarkTime) === 29) + } ) } @@ -206,4 +301,24 @@ class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging { CheckAnswer((10, 1)) ) } + + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => + val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) + true + } + + private def assertEventStats(body: ju.Map[String, String] => Unit): AssertOnQuery = { + AssertOnQuery { q => + body(q.recentProgress.filter(_.numInputRows > 0).lastOption.get.eventTime) + true + } + } + + private val timestampFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601 + timestampFormat.setTimeZone(ju.TimeZone.getTimeZone("UTC")) + + private def formatTimestamp(sec: Long): String = { + timestampFormat.format(new ju.Date(sec * 1000)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 54efae3fb4..22f59f63d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} @@ -210,6 +210,26 @@ class FileStreamSinkSuite extends StreamTest { } } + test("Update and Complete output mode 
not supported") { + val df = MemoryStream[Int].toDF().groupBy().count() + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + + withTempDir { dir => + + def testOutputMode(mode: String): Unit = { + val e = intercept[AnalysisException] { + df.writeStream.format("parquet").outputMode(mode).start(dir.getCanonicalPath) + } + Seq(mode, "not support").foreach { w => + assert(e.getMessage.toLowerCase.contains(w)) + } + } + + testOutputMode("update") + testOutputMode("complete") + } + } + test("parquet") { testFormat(None) // should not throw error as default format parquet when not specified testFormat(Some("parquet")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 8256c63d87..8a9fa94bea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -18,21 +18,27 @@ package org.apache.spark.sql.streaming import java.io.File +import java.net.URI -import scala.collection.mutable +import scala.util.Random +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.PrivateMethodTester +import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FileStreamSourceTest extends StreamTest with SharedSQLContext with PrivateMethodTester { +abstract class FileStreamSourceTest + extends 
StreamTest with SharedSQLContext with PrivateMethodTester { import testImplicits._ @@ -61,7 +67,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext with Private val source = sources.head val newOffset = source.withBatchingLocked { addData(source) - source.currentOffset + 1 + new FileStreamSourceOffset(source.currentLogOffset + 1) } logInfo(s"Added file to $source at offset $newOffset") (source, newOffset) @@ -745,7 +751,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .format("memory") .queryName("file_data") .start() - .asInstanceOf[StreamExecution] + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery q.processAllAvailable() val memorySink = q.sink.asInstanceOf[MemorySink] val fileSource = q.logicalPlan.collect { @@ -808,21 +815,31 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } test("max files per trigger - incorrect values") { - withTempDir { case src => - def testMaxFilePerTriggerValue(value: String): Unit = { - val df = spark.readStream.option("maxFilesPerTrigger", value).text(src.getCanonicalPath) - val e = intercept[IllegalArgumentException] { - testStream(df)() - } - Seq("maxFilesPerTrigger", value, "positive integer").foreach { s => - assert(e.getMessage.contains(s)) + val testTable = "maxFilesPerTrigger_test" + withTable(testTable) { + withTempDir { case src => + def testMaxFilePerTriggerValue(value: String): Unit = { + val df = spark.readStream.option("maxFilesPerTrigger", value).text(src.getCanonicalPath) + val e = intercept[StreamingQueryException] { + // Note: `maxFilesPerTrigger` is checked in the stream thread when creating the source + val q = df.writeStream.format("memory").queryName(testTable).start() + try { + q.processAllAvailable() + } finally { + q.stop() + } + } + assert(e.getCause.isInstanceOf[IllegalArgumentException]) + Seq("maxFilesPerTrigger", value, "positive integer").foreach { s => + assert(e.getMessage.contains(s)) + } } - } - testMaxFilePerTriggerValue("not-a-integer") - 
testMaxFilePerTriggerValue("-1") - testMaxFilePerTriggerValue("0") - testMaxFilePerTriggerValue("10.1") + testMaxFilePerTriggerValue("not-a-integer") + testMaxFilePerTriggerValue("-1") + testMaxFilePerTriggerValue("0") + testMaxFilePerTriggerValue("10.1") + } } } @@ -835,7 +852,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest { df.explain() val q = df.writeStream.queryName("file_explain").format("memory").start() - .asInstanceOf[StreamExecution] + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery try { assert("No physical plan. Waiting for data." === q.explainInternal(false)) assert("No physical plan. Waiting for data." === q.explainInternal(true)) @@ -849,13 +867,13 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val explainWithoutExtended = q.explainInternal(false) // `extended = false` only displays the physical plan. assert("Relation.*text".r.findAllMatchIn(explainWithoutExtended).size === 0) - assert("TextFileFormat".r.findAllMatchIn(explainWithoutExtended).size === 1) + assert(": Text".r.findAllMatchIn(explainWithoutExtended).size === 1) val explainWithExtended = q.explainInternal(true) // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical // plan. 
assert("Relation.*text".r.findAllMatchIn(explainWithExtended).size === 3) - assert("TextFileFormat".r.findAllMatchIn(explainWithExtended).size === 1) + assert(": Text".r.findAllMatchIn(explainWithExtended).size === 1) } finally { q.stop() } @@ -891,7 +909,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { // This is to avoid actually running a Spark job with 10000 tasks val df = files.filter("1 == 0").groupBy().count() - testStream(df, InternalOutputModes.Complete)( + testStream(df, OutputMode.Complete)( AddTextFileData("0", src, tmp), CheckAnswer(0) ) @@ -987,12 +1005,17 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val _sources = PrivateMethod[Seq[Source]]('sources) val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] - assert(fileSource.getBatch(None, LongOffset(2)).as[String].collect() === - List("keep1", "keep2", "keep3")) - assert(fileSource.getBatch(Some(LongOffset(0)), LongOffset(2)).as[String].collect() === - List("keep2", "keep3")) - assert(fileSource.getBatch(Some(LongOffset(1)), LongOffset(2)).as[String].collect() === - List("keep3")) + + def verify(startId: Option[Int], endId: Int, expected: String*): Unit = { + val start = startId.map(new FileStreamSourceOffset(_)) + val end = FileStreamSourceOffset(endId) + assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected) + } + + verify(startId = None, endId = 2, "keep1", "keep2", "keep3") + verify(startId = Some(0), endId = 1, "keep2") + verify(startId = Some(0), endId = 2, "keep2", "keep3") + verify(startId = Some(1), endId = 2, "keep3") true } ) @@ -1007,7 +1030,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { AddTextFileData("100", src, tmp), CheckAnswer("100"), AssertOnQuery { query => - val actualProgress = query.recentProgresses + val actualProgress = query.recentProgress .find(_.numInputRows > 0) .getOrElse(sys.error("Could not find records with data.")) assert(actualProgress.numInputRows === 1) @@ 
-1022,6 +1045,152 @@ class FileStreamSourceSuite extends FileStreamSourceTest { val options = new FileStreamOptions(Map("maxfilespertrigger" -> "1")) assert(options.maxFilesPerTrigger == Some(1)) } + + test("FileStreamSource offset - read Spark 2.1.0 offset json format") { + val offset = readOffsetFromResource("file-source-offset-version-2.1.0-json.txt") + assert(FileStreamSourceOffset(offset) === FileStreamSourceOffset(345)) + } + + test("FileStreamSource offset - read Spark 2.1.0 offset long format") { + val offset = readOffsetFromResource("file-source-offset-version-2.1.0-long.txt") + assert(FileStreamSourceOffset(offset) === FileStreamSourceOffset(345)) + } + + test("FileStreamSourceLog - read Spark 2.1.0 log format") { + assert(readLogFromResource("file-source-log-version-2.1.0") === Seq( + FileEntry("/a/b/0", 1480730949000L, 0L), + FileEntry("/a/b/1", 1480730950000L, 1L), + FileEntry("/a/b/2", 1480730950000L, 2L), + FileEntry("/a/b/3", 1480730950000L, 3L), + FileEntry("/a/b/4", 1480730951000L, 4L) + )) + } + + private def readLogFromResource(dir: String): Seq[FileEntry] = { + val input = getClass.getResource(s"/structured-streaming/$dir") + val log = new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, input.toString) + log.allFiles() + } + + private def readOffsetFromResource(file: String): SerializedOffset = { + import scala.io.Source + val str = Source.fromFile(getClass.getResource(s"/structured-streaming/$file").toURI).mkString + SerializedOffset(str.trim) + } + + test("FileStreamSource - latestFirst") { + withTempDir { src => + // Prepare two files: 1.txt, 2.txt, and make sure they have different modified time. 
+ val f1 = stringToFile(new File(src, "1.txt"), "1") + val f2 = stringToFile(new File(src, "2.txt"), "2") + f2.setLastModified(f1.lastModified + 1000) + + def runTwoBatchesAndVerifyResults( + latestFirst: Boolean, + firstBatch: String, + secondBatch: String): Unit = { + val fileStream = createFileStream( + "text", + src.getCanonicalPath, + options = Map("latestFirst" -> latestFirst.toString, "maxFilesPerTrigger" -> "1")) + val clock = new StreamManualClock() + testStream(fileStream)( + StartStream(trigger = ProcessingTime(10), triggerClock = clock), + AssertOnQuery { _ => + // Block until the first batch finishes. + eventually(timeout(streamingTimeout)) { + assert(clock.isStreamWaitingAt(0)) + } + true + }, + CheckLastBatch(firstBatch), + AdvanceManualClock(10), + AssertOnQuery { _ => + // Block until the second batch finishes. + eventually(timeout(streamingTimeout)) { + assert(clock.isStreamWaitingAt(10)) + } + true + }, + CheckLastBatch(secondBatch) + ) + } + + // Read oldest files first, so the first batch is "1", and the second batch is "2". + runTwoBatchesAndVerifyResults(latestFirst = false, firstBatch = "1", secondBatch = "2") + + // Read latest files first, so the first batch is "2", and the second batch is "1". + runTwoBatchesAndVerifyResults(latestFirst = true, firstBatch = "2", secondBatch = "1") + } + } + + test("SeenFilesMap") { + val map = new SeenFilesMap(maxAgeMs = 10) + + map.add("a", 5) + assert(map.size == 1) + map.purge() + assert(map.size == 1) + + // Add a new entry and purge should be no-op, since the gap is exactly 10 ms. + map.add("b", 15) + assert(map.size == 2) + map.purge() + assert(map.size == 2) + + // Add a new entry that's more than 10 ms than the first entry. We should be able to purge now. 
+ map.add("c", 16) + assert(map.size == 3) + map.purge() + assert(map.size == 2) + + // Override existing entry shouldn't change the size + map.add("c", 25) + assert(map.size == 2) + + // Not a new file because we have seen c before + assert(!map.isNewFile("c", 20)) + + // Not a new file because timestamp is too old + assert(!map.isNewFile("d", 5)) + + // Finally a new file: never seen and not too old + assert(map.isNewFile("e", 20)) + } + + test("SeenFilesMap should only consider a file old if it is earlier than last purge time") { + val map = new SeenFilesMap(maxAgeMs = 10) + + map.add("a", 20) + assert(map.size == 1) + + // Timestamp 5 should still considered a new file because purge time should be 0 + assert(map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) + + // Once purge, purge time should be 10 and then b would be a old file if it is less than 10. + map.purge() + assert(!map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) + } + + testWithUninterruptibleThread("do not recheck that files exist during getBatch") { + withTempDir { temp => + spark.conf.set( + s"fs.$scheme.impl", + classOf[ExistsThrowsExceptionFileSystem].getName) + // add the metadata entries as a pre-req + val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir + val metadataLog = + new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) + assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) + + val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, + dir.getAbsolutePath, Map.empty) + // this method should throw an exception if `fs.exists` is called during resolveRelation + newSource.getBatch(None, FileStreamSourceOffset(1)) + } + } } class FileStreamSourceStressTestSuite extends FileStreamSourceTest { @@ -1042,3 +1211,28 @@ class FileStreamSourceStressTestSuite extends FileStreamSourceTest { Utils.deleteRecursively(tmp) } } + +/** + * Fake FileSystem to test 
whether the method `fs.exists` is called during + * `DataSource.resolveRelation`. + */ +class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem { + override def getUri: URI = { + URI.create(s"$scheme:///") + } + + override def exists(f: Path): Boolean = { + throw new IllegalArgumentException("Exists shouldn't have been called!") + } + + /** Simply return an empty file for now. */ + override def listStatus(file: Path): Array[FileStatus] = { + val emptyFile = new FileStatus() + emptyFile.setPath(file) + Array(emptyFile) + } +} + +object ExistsThrowsExceptionFileSystem { + val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 6bdf47901a..34b0ee8064 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -21,10 +21,10 @@ import scala.reflect.ClassTag import scala.util.control.ControlThrowable import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.util.ManualClock class StreamSuite extends StreamTest { @@ -259,17 +259,17 @@ class StreamSuite extends StreamTest { override def stop(): Unit = {} } val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source)) + // These error are fatal errors and should be ignored in `testStream` to not fail the test. 
testStream(df)( - ExpectFailure()(ClassTag(e.getClass)) + ExpectFailure(isFatalError = true)(ClassTag(e.getClass)) ) } } test("output mode API in Scala") { - val o1 = OutputMode.Append - assert(o1 === InternalOutputModes.Append) - val o2 = OutputMode.Complete - assert(o2 === InternalOutputModes.Complete) + assert(OutputMode.Append === InternalOutputModes.Append) + assert(OutputMode.Complete === InternalOutputModes.Complete) + assert(OutputMode.Update === InternalOutputModes.Update) } test("explain") { @@ -278,7 +278,8 @@ class StreamSuite extends StreamTest { // Test `explain` not throwing errors df.explain() val q = df.writeStream.queryName("memory_explain").format("memory").start() - .asInstanceOf[StreamExecution] + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery try { assert("No physical plan. Waiting for data." === q.explainInternal(false)) assert("No physical plan. Waiting for data." === q.explainInternal(true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index a2629f7f68..709050d29b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -167,10 +167,17 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** Advance the trigger clock's time manually. */ case class AdvanceManualClock(timeToAdd: Long) extends StreamAction - /** Signals that a failure is expected and should not kill the test. */ - case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { + /** + * Signals that a failure is expected and should not kill the test. + * + * @param isFatalError if this is a fatal error. If so, the error should also be caught by + * UncaughtExceptionHandler. 
+ */ + case class ExpectFailure[T <: Throwable : ClassTag]( + isFatalError: Boolean = false) extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] - override def toString(): String = s"ExpectFailure[${causeClass.getName}]" + override def toString(): String = + s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]" } /** Assert that a body is true */ @@ -231,8 +238,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = { val stream = _stream.toDF() + val sparkSession = stream.sparkSession // use the session in DF, not the default session var pos = 0 - var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for @@ -240,7 +247,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { val resetConfValues = mutable.Map[String, Option[String]]() @volatile - var streamDeathCause: Throwable = null + var streamThreadDeathCause: Throwable = null // If the test doesn't manually start the stream, we do it automatically at the beginning. 
val startedManually = @@ -271,7 +278,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { |Output Mode: $outputMode |Stream state: $currentOffsets |Thread state: $threadState - |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} + |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""} | |== Sink == |${sink.toDebugString} @@ -319,7 +326,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { """.stripMargin) } - val testThread = Thread.currentThread() val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath var manualClockExpectedTime = -1L try { @@ -337,14 +343,16 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { additionalConfs.foreach(pair => { val value = - if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None + if (sparkSession.conf.contains(pair._1)) { + Some(sparkSession.conf.get(pair._1)) + } else None resetConfValues(pair._1) = value - spark.conf.set(pair._1, pair._2) + sparkSession.conf.set(pair._1, pair._2) }) lastStream = currentStream currentStream = - spark + sparkSession .streams .startQuery( None, @@ -354,13 +362,17 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { outputMode, trigger = trigger, triggerClock = triggerClock) - .asInstanceOf[StreamExecution] + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { override def uncaughtException(t: Thread, e: Throwable): Unit = { - streamDeathCause = e + streamThreadDeathCause = e } }) + // Wait until the initialization finishes, because some tests need to use `logicalPlan` + // after starting the query. 
+ currentStream.awaitInitialization(streamingTimeout.toMillis) case AdvanceManualClock(timeToAdd) => verify(currentStream != null, @@ -394,8 +406,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { currentStream.exception.map(_.toString()).getOrElse("")) } catch { case _: InterruptedException => - case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while stopping and waiting for microbatchthread to terminate.") + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest( + "Timed out while stopping and waiting for microbatchthread to terminate.", e) case t: Throwable => failTest("Error while stopping stream", t) } finally { @@ -412,8 +425,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { eventually("microbatch thread not stopped after termination with failure") { assert(!currentStream.microBatchThread.isAlive) } - verify(thrownException.query.eq(currentStream), - s"incorrect query reference in exception") verify(currentStream.exception === Some(thrownException), s"incorrect exception returned by query.exception()") @@ -421,16 +432,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { verify(exception.cause.getClass === ef.causeClass, "incorrect cause in exception returned by query.exception()\n" + s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}") + if (ef.isFatalError) { + // This is a fatal error, `streamThreadDeathCause` should be set to this error in + // UncaughtExceptionHandler. 
+ verify(streamThreadDeathCause != null && + streamThreadDeathCause.getClass === ef.causeClass, + "UncaughtExceptionHandler didn't receive the correct error\n" + + s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause") + streamThreadDeathCause = null + } } catch { case _: InterruptedException => - case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out while waiting for failure") + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out while waiting for failure", e) case t: Throwable => failTest("Error while checking stream failure", t) } finally { lastStream = currentStream currentStream = null - streamDeathCause = null } case a: AssertOnQuery => @@ -508,11 +527,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } pos += 1 } + if (streamThreadDeathCause != null) { + failTest("Stream Thread Died", streamThreadDeathCause) + } } catch { - case _: InterruptedException if streamDeathCause != null => - failTest("Stream Thread Died") - case _: org.scalatest.exceptions.TestFailedDueToTimeoutException => - failTest("Timed out waiting for stream") + case _: InterruptedException if streamThreadDeathCause != null => + failTest("Stream Thread Died", streamThreadDeathCause) + case e: org.scalatest.exceptions.TestFailedDueToTimeoutException => + failTest("Timed out waiting for stream", e) } finally { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop() @@ -520,8 +542,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { // Rollback prev configuration values resetConfValues.foreach { - case (key, Some(value)) => spark.conf.set(key, value) - case (key, None) => spark.conf.unset(key) + case (key, Some(value)) => sparkSession.conf.set(key, value) + case (key, None) => sparkSession.conf.unset(key) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index fbe560e8d9..eca2647dea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -23,13 +23,13 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode._ object FailureSinglton { var firstTime = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 07a13a48a1..4596aa1d34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming import java.util.UUID import scala.collection.mutable +import scala.concurrent.duration._ import org.scalactic.TolerantNumerics import org.scalatest.concurrent.AsyncAssertions.Waiter @@ -30,7 +31,9 @@ import org.scalatest.PrivateMethodTester._ import org.apache.spark.SparkException import org.apache.spark.scheduler._ +import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.util.JsonProtocol @@ -44,8 +47,9 @@ class 
StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { after { spark.streams.active.foreach(_.stop()) assert(spark.streams.active.isEmpty) - assert(addedListeners.isEmpty) + assert(addedListeners().isEmpty) // Make sure we don't leak any events to the next test + spark.sparkContext.listenerBus.waitUntilEmpty(10000) } testQuietly("single listener, check trigger events are generated correctly") { @@ -67,6 +71,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { AssertOnQuery { query => assert(listener.startEvent !== null) assert(listener.startEvent.id === query.id) + assert(listener.startEvent.runId === query.runId) assert(listener.startEvent.name === query.name) assert(listener.progressEvents.isEmpty) assert(listener.terminationEvent === null) @@ -79,7 +84,11 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { CheckAnswer(10, 5), AssertOnQuery { query => assert(listener.progressEvents.nonEmpty) - assert(listener.progressEvents.last.json === query.lastProgress.json) + // SPARK-18868: We can't use query.lastProgress, because in progressEvents, we filter + // out non-zero input rows, but the lastProgress may be a zero input row trigger + val lastNonZeroProgress = query.recentProgress.filter(_.numInputRows > 0).lastOption + .getOrElse(fail("No progress updates received in StreamingQuery!")) + assert(listener.progressEvents.last.json === lastNonZeroProgress.json) assert(listener.terminationEvent === null) true }, @@ -90,6 +99,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { eventually(Timeout(streamingTimeout)) { assert(listener.terminationEvent !== null) assert(listener.terminationEvent.id === query.id) + assert(listener.terminationEvent.runId === query.runId) assert(listener.terminationEvent.exception === None) } listener.checkAsyncErrors() @@ -101,16 +111,19 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { StartStream(ProcessingTime(100), 
triggerClock = clock), AddData(inputData, 0), AdvanceManualClock(100), - ExpectFailure[SparkException], + ExpectFailure[SparkException](), AssertOnQuery { query => - assert(listener.terminationEvent !== null) - assert(listener.terminationEvent.id === query.id) - assert(listener.terminationEvent.exception.nonEmpty) - // Make sure that the exception message reported through listener - // contains the actual exception and relevant stack trace - assert(!listener.terminationEvent.exception.get.contains("StreamingQueryException")) - assert(listener.terminationEvent.exception.get.contains("java.lang.ArithmeticException")) - assert(listener.terminationEvent.exception.get.contains("StreamingQueryListenerSuite")) + eventually(Timeout(streamingTimeout)) { + assert(listener.terminationEvent !== null) + assert(listener.terminationEvent.id === query.id) + assert(listener.terminationEvent.exception.nonEmpty) + // Make sure that the exception message reported through listener + // contains the actual exception and relevant stack trace + assert(!listener.terminationEvent.exception.get.contains("StreamingQueryException")) + assert( + listener.terminationEvent.exception.get.contains("java.lang.ArithmeticException")) + assert(listener.terminationEvent.exception.get.contains("StreamingQueryListenerSuite")) + } listener.checkAsyncErrors() true } @@ -144,7 +157,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { assert(isListenerActive(listener1) === false) assert(isListenerActive(listener2) === true) } finally { - addedListeners.foreach(spark.streams.removeListener) + addedListeners().foreach(spark.streams.removeListener) } } @@ -165,30 +178,140 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } test("QueryStartedEvent serialization") { - val queryStarted = new StreamingQueryListener.QueryStartedEvent(UUID.randomUUID(), "name") - val json = JsonProtocol.sparkEventToJson(queryStarted) - val newQueryStarted = 
JsonProtocol.sparkEventFromJson(json) - .asInstanceOf[StreamingQueryListener.QueryStartedEvent] + def testSerialization(event: QueryStartedEvent): Unit = { + val json = JsonProtocol.sparkEventToJson(event) + val newEvent = JsonProtocol.sparkEventFromJson(json).asInstanceOf[QueryStartedEvent] + assert(newEvent.id === event.id) + assert(newEvent.runId === event.runId) + assert(newEvent.name === event.name) + } + + testSerialization(new QueryStartedEvent(UUID.randomUUID, UUID.randomUUID, "name")) + testSerialization(new QueryStartedEvent(UUID.randomUUID, UUID.randomUUID, null)) } test("QueryProgressEvent serialization") { - val event = new StreamingQueryListener.QueryProgressEvent( - StreamingQueryStatusAndProgressSuite.testProgress) - val json = JsonProtocol.sparkEventToJson(event) - val newEvent = JsonProtocol.sparkEventFromJson(json) - .asInstanceOf[StreamingQueryListener.QueryProgressEvent] - assert(event.progress.json === newEvent.progress.json) + def testSerialization(event: QueryProgressEvent): Unit = { + import scala.collection.JavaConverters._ + val json = JsonProtocol.sparkEventToJson(event) + val newEvent = JsonProtocol.sparkEventFromJson(json).asInstanceOf[QueryProgressEvent] + assert(newEvent.progress.json === event.progress.json) // json as a proxy for equality + assert(newEvent.progress.durationMs.asScala === event.progress.durationMs.asScala) + assert(newEvent.progress.eventTime.asScala === event.progress.eventTime.asScala) + } + testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress1)) + testSerialization(new QueryProgressEvent(StreamingQueryStatusAndProgressSuite.testProgress2)) } test("QueryTerminatedEvent serialization") { + def testSerialization(event: QueryTerminatedEvent): Unit = { + val json = JsonProtocol.sparkEventToJson(event) + val newEvent = JsonProtocol.sparkEventFromJson(json).asInstanceOf[QueryTerminatedEvent] + assert(newEvent.id === event.id) + assert(newEvent.runId === event.runId) + 
assert(newEvent.exception === event.exception) + } + val exception = new RuntimeException("exception") - val queryQueryTerminated = new StreamingQueryListener.QueryTerminatedEvent( - UUID.randomUUID, Some(exception.getMessage)) - val json = JsonProtocol.sparkEventToJson(queryQueryTerminated) - val newQueryTerminated = JsonProtocol.sparkEventFromJson(json) - .asInstanceOf[StreamingQueryListener.QueryTerminatedEvent] - assert(queryQueryTerminated.id === newQueryTerminated.id) - assert(queryQueryTerminated.exception === newQueryTerminated.exception) + testSerialization( + new QueryTerminatedEvent(UUID.randomUUID, UUID.randomUUID, Some(exception.getMessage))) + } + + test("only one progress event per interval when no data") { + // This test will start a query but not push any data, and then check if we push too many events + withSQLConf(SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "100ms") { + @volatile var numProgressEvent = 0 + val listener = new StreamingQueryListener { + override def onQueryStarted(event: QueryStartedEvent): Unit = {} + override def onQueryProgress(event: QueryProgressEvent): Unit = { + numProgressEvent += 1 + } + override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {} + } + spark.streams.addListener(listener) + try { + val input = new MemoryStream[Int](0, sqlContext) { + @volatile var numTriggers = 0 + override def getOffset: Option[Offset] = { + numTriggers += 1 + super.getOffset + } + } + val clock = new StreamManualClock() + val actions = mutable.ArrayBuffer[StreamAction]() + actions += StartStream(trigger = ProcessingTime(10), triggerClock = clock) + for (_ <- 1 to 100) { + actions += AdvanceManualClock(10) + } + actions += AssertOnQuery { _ => + eventually(timeout(streamingTimeout)) { + assert(input.numTriggers > 100) // at least 100 triggers have occurred + } + true + } + // `recentProgress` should not receive too many no data events + actions += AssertOnQuery { q => + q.recentProgress.size > 1 && 
q.recentProgress.size <= 11 + } + testStream(input.toDS)(actions: _*) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) + // 11 is the max value of the possible numbers of events. + assert(numProgressEvent > 1 && numProgressEvent <= 11) + } finally { + spark.streams.removeListener(listener) + } + } + } + + test("listener only posts events from queries started in the related sessions") { + val session1 = spark.newSession() + val session2 = spark.newSession() + val collector1 = new EventCollector + val collector2 = new EventCollector + + def runQuery(session: SparkSession): Unit = { + collector1.reset() + collector2.reset() + val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext) + testStream(mem.toDS)( + AddData(mem, 1, 2, 3), + CheckAnswer(1, 2, 3) + ) + session.sparkContext.listenerBus.waitUntilEmpty(5000) + } + + def assertEventsCollected(collector: EventCollector): Unit = { + assert(collector.startEvent !== null) + assert(collector.progressEvents.nonEmpty) + assert(collector.terminationEvent !== null) + } + + def assertEventsNotCollected(collector: EventCollector): Unit = { + assert(collector.startEvent === null) + assert(collector.progressEvents.isEmpty) + assert(collector.terminationEvent === null) + } + + assert(session1.ne(session2)) + assert(session1.streams.ne(session2.streams)) + + withListenerAdded(collector1, session1) { + assert(addedListeners(session1).nonEmpty) + + withListenerAdded(collector2, session2) { + assert(addedListeners(session2).nonEmpty) + + // query on session1 should send events only to collector1 + runQuery(session1) + assertEventsCollected(collector1) + assertEventsNotCollected(collector2) + + // query on session2 should send events only to collector2 + runQuery(session2) + assertEventsCollected(collector2) + assertEventsNotCollected(collector1) + } + } } testQuietly("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") { @@ -238,21 +361,23 @@ class StreamingQueryListenerSuite extends 
StreamTest with BeforeAndAfter { } } - private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = { + private def withListenerAdded( + listener: StreamingQueryListener, + session: SparkSession = spark)(body: => Unit): Unit = { try { failAfter(streamingTimeout) { - spark.streams.addListener(listener) + session.streams.addListener(listener) body } } finally { - spark.streams.removeListener(listener) + session.streams.removeListener(listener) } } - private def addedListeners(): Array[StreamingQueryListener] = { + private def addedListeners(session: SparkSession = spark): Array[StreamingQueryListener] = { val listenerBusMethod = PrivateMethod[StreamingQueryListenerBus]('listenerBus) - val listenerBus = spark.streams invokePrivate listenerBusMethod() + val listenerBus = session.streams invokePrivate listenerBusMethod() listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index 268b8ff7b4..8e16fd418a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.util.concurrent.CountDownLatch + import scala.concurrent.Future import scala.util.Random import scala.util.control.NonFatal @@ -30,6 +32,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.util.BlockingSource import org.apache.spark.util.Utils class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { @@ -213,13 +216,35 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { } } + 
test("SPARK-18811: Source resolution should not block main thread") { + failAfter(streamingTimeout) { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + // if source resolution was happening on the main thread, it would block the start call, + // now it should only be blocking the stream execution thread + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + eventually(Timeout(streamingTimeout)) { + assert(sq.status.message.contains("Initializing sources")) + } + BlockingSource.latch.countDown() + sq.stop() + } + } + } + /** Run a body of code by defining a query on each dataset */ private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = { failAfter(streamingTimeout) { val queries = withClue("Error starting queries") { datasets.zipWithIndex.map { case (ds, i) => - @volatile var query: StreamExecution = null + var query: StreamingQuery = null try { val df = ds.toDF val metadataRoot = @@ -231,7 +256,6 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { .option("checkpointLocation", metadataRoot) .outputMode("append") .start() - .asInstanceOf[StreamExecution] } catch { case NonFatal(e) => if (query != null) query.stop() @@ -279,7 +303,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { Thread.sleep(stopAfter.toMillis) if (withError) { logDebug(s"Terminating query ${queryToStop.name} with error") - queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect { + queryToStop.asInstanceOf[StreamingQueryWrapper].streamingQuery.logicalPlan.collect { case StreamingExecutionRelation(source, _) => source.asInstanceOf[MemoryStream[Int]].addData(0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala index 4da712fa0f..34bf3985ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala @@ -24,26 +24,33 @@ import scala.collection.JavaConverters._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._ -class StreamingQueryStatusAndProgressSuite extends SparkFunSuite { +class StreamingQueryStatusAndProgressSuite extends StreamTest { test("StreamingQueryProgress - prettyJson") { - val json = testProgress.prettyJson - assert(json === + val json1 = testProgress1.prettyJson + assert(json1 === s""" |{ - | "id" : "${testProgress.id.toString}", - | "name" : "name", - | "timestamp" : 1, + | "id" : "${testProgress1.id.toString}", + | "runId" : "${testProgress1.runId.toString}", + | "name" : "myName", + | "timestamp" : "2016-12-05T20:54:20.827Z", | "numInputRows" : 678, | "inputRowsPerSecond" : 10.0, | "durationMs" : { | "total" : 0 | }, - | "currentWatermark" : 3, + | "eventTime" : { + | "avg" : "2016-12-05T20:54:20.827Z", + | "max" : "2016-12-05T20:54:20.827Z", + | "min" : "2016-12-05T20:54:20.827Z", + | "watermark" : "2016-12-05T20:54:20.827Z" + | }, | "stateOperators" : [ { | "numRowsTotal" : 0, | "numRowsUpdated" : 1 @@ -60,16 +67,47 @@ class StreamingQueryStatusAndProgressSuite extends SparkFunSuite { | } |} """.stripMargin.trim) - assert(compact(parse(json)) === testProgress.json) - + assert(compact(parse(json1)) === testProgress1.json) + + val json2 = testProgress2.prettyJson + assert( + json2 === + s""" + |{ + | "id" : "${testProgress2.id.toString}", + | "runId" : 
"${testProgress2.runId.toString}", + | "name" : null, + | "timestamp" : "2016-12-05T20:54:20.827Z", + | "numInputRows" : 678, + | "durationMs" : { + | "total" : 0 + | }, + | "stateOperators" : [ { + | "numRowsTotal" : 0, + | "numRowsUpdated" : 1 + | } ], + | "sources" : [ { + | "description" : "source", + | "startOffset" : 123, + | "endOffset" : 456, + | "numInputRows" : 678 + | } ], + | "sink" : { + | "description" : "sink" + | } + |} + """.stripMargin.trim) + assert(compact(parse(json2)) === testProgress2.json) } test("StreamingQueryProgress - json") { - assert(compact(parse(testProgress.json)) === testProgress.json) + assert(compact(parse(testProgress1.json)) === testProgress1.json) + assert(compact(parse(testProgress2.json)) === testProgress2.json) } test("StreamingQueryProgress - toString") { - assert(testProgress.toString === testProgress.prettyJson) + assert(testProgress1.toString === testProgress1.prettyJson) + assert(testProgress2.toString === testProgress2.prettyJson) } test("StreamingQueryStatus - prettyJson") { @@ -91,16 +129,57 @@ class StreamingQueryStatusAndProgressSuite extends SparkFunSuite { test("StreamingQueryStatus - toString") { assert(testStatus.toString === testStatus.prettyJson) } + + test("progress classes should be Serializable") { + import testImplicits._ + + val inputData = MemoryStream[Int] + + val query = inputData.toDS() + .groupBy($"value") + .agg(count("*")) + .writeStream + .queryName("progress_serializable_test") + .format("memory") + .outputMode("complete") + .start() + try { + inputData.addData(1, 2, 3) + query.processAllAvailable() + + val progress = query.recentProgress + + // Make sure it generates the progress objects we want to test + assert(progress.exists { p => + p.sources.size >= 1 && p.stateOperators.size >= 1 && p.sink != null + }) + + val array = spark.sparkContext.parallelize(progress).collect() + assert(array.length === progress.length) + array.zip(progress).foreach { case (p1, p2) => + // Make sure we did 
serialize and deserialize the object + assert(p1 ne p2) + assert(p1.json === p2.json) + } + } finally { + query.stop() + } + } } object StreamingQueryStatusAndProgressSuite { - val testProgress = new StreamingQueryProgress( - id = UUID.randomUUID(), - name = "name", - timestamp = 1L, + val testProgress1 = new StreamingQueryProgress( + id = UUID.randomUUID, + runId = UUID.randomUUID, + name = "myName", + timestamp = "2016-12-05T20:54:20.827Z", batchId = 2L, - durationMs = Map("total" -> 0L).mapValues(long2Long).asJava, - currentWatermark = 3L, + durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), + eventTime = new java.util.HashMap(Map( + "max" -> "2016-12-05T20:54:20.827Z", + "min" -> "2016-12-05T20:54:20.827Z", + "avg" -> "2016-12-05T20:54:20.827Z", + "watermark" -> "2016-12-05T20:54:20.827Z").asJava), stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), sources = Array( new SourceProgress( @@ -115,6 +194,29 @@ object StreamingQueryStatusAndProgressSuite { sink = new SinkProgress("sink") ) + val testProgress2 = new StreamingQueryProgress( + id = UUID.randomUUID, + runId = UUID.randomUUID, + name = null, // should not be present in the json + timestamp = "2016-12-05T20:54:20.827Z", + batchId = 2L, + durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), + // empty maps should be handled correctly + eventTime = new java.util.HashMap(Map.empty[String, String].asJava), + stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), + sources = Array( + new SourceProgress( + description = "source", + startOffset = "123", + endOffset = "456", + numInputRows = 678, + inputRowsPerSecond = Double.NaN, // should not be present in the json + processedRowsPerSecond = Double.NegativeInfinity // should not be present in the json + ) + ), + sink = new SinkProgress("sink") + ) + val testStatus = new StreamingQueryStatus("active", true, false) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 56abe1201c..1525ad5fd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.sql.streaming +import java.util.concurrent.CountDownLatch + +import org.apache.commons.lang3.RandomStringUtils import org.scalactic.TolerantNumerics import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.internal.Logging -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ -import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.util.BlockingSource +import org.apache.spark.util.ManualClock class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { @@ -43,38 +48,77 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { sqlContext.streams.active.foreach(_.stop()) } - test("names unique across active queries, ids unique across all started queries") { - val inputData = MemoryStream[Int] - val mapped = inputData.toDS().map { 6 / _} + test("name unique in active queries") { + withTempDir { dir => + def startQuery(name: Option[String]): StreamingQuery = { + val writer = MemoryStream[Int].toDS.writeStream + name.foreach(writer.queryName) + writer + .foreach(new TestForeachWriter) + .start() + } - def startQuery(queryName: String): StreamingQuery = { - val metadataRoot = Utils.createTempDir(namePrefix = 
"streaming.checkpoint").getCanonicalPath - val writer = mapped.writeStream - writer - .queryName(queryName) - .format("memory") - .option("checkpointLocation", metadataRoot) - .start() - } + // No name by default, multiple active queries can have no name + val q1 = startQuery(name = None) + assert(q1.name === null) + val q2 = startQuery(name = None) + assert(q2.name === null) - val q1 = startQuery("q1") - assert(q1.name === "q1") + // Can be set by user + val q3 = startQuery(name = Some("q3")) + assert(q3.name === "q3") - // Verify that another query with same name cannot be started - val e1 = intercept[IllegalArgumentException] { - startQuery("q1") + // Multiple active queries cannot have same name + val e = intercept[IllegalArgumentException] { + startQuery(name = Some("q3")) + } + + q1.stop() + q2.stop() + q3.stop() } - Seq("q1", "already active").foreach { s => assert(e1.getMessage.contains(s)) } + } - // Verify q1 was unaffected by the above exception and stop it - assert(q1.isActive) - q1.stop() + test( + "id unique in active queries + persists across restarts, runId unique across start/restarts") { + val inputData = MemoryStream[Int] + withTempDir { dir => + var cpDir: String = null + + def startQuery(restart: Boolean): StreamingQuery = { + if (cpDir == null || !restart) cpDir = s"$dir/${RandomStringUtils.randomAlphabetic(10)}" + MemoryStream[Int].toDS().groupBy().count() + .writeStream + .format("memory") + .outputMode("complete") + .queryName(s"name${RandomStringUtils.randomAlphabetic(10)}") + .option("checkpointLocation", cpDir) + .start() + } - // Verify another query can be started with name q1, but will have different id - val q2 = startQuery("q1") - assert(q2.name === "q1") - assert(q2.id !== q1.id) - q2.stop() + // id and runId unique for new queries + val q1 = startQuery(restart = false) + val q2 = startQuery(restart = false) + assert(q1.id !== q2.id) + assert(q1.runId !== q2.runId) + q1.stop() + q2.stop() + + // id persists across restarts, runId 
unique across restarts + val q3 = startQuery(restart = false) + q3.stop() + + val q4 = startQuery(restart = true) + q4.stop() + assert(q3.id === q3.id) + assert(q3.runId !== q4.runId) + + // Only one query with same id can be active + val q5 = startQuery(restart = false) + val e = intercept[IllegalStateException] { + startQuery(restart = true) + } + } } testQuietly("isActive, exception, and awaitTermination") { @@ -98,19 +142,21 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { StartStream(), AssertOnQuery(_.isActive === true), AddData(inputData, 0), - ExpectFailure[SparkException], + ExpectFailure[SparkException](), AssertOnQuery(_.isActive === false), TestAwaitTermination(ExpectException[SparkException]), TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000), TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10), - AssertOnQuery( - q => q.exception.get.startOffset.get.offsets === - q.committedOffsets.toOffsetSeq(Seq(inputData), "{}").offsets, - "incorrect start offset on exception") + AssertOnQuery(q => { + q.exception.get.startOffset === + q.committedOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString && + q.exception.get.endOffset === + q.availableOffsets.toOffsetSeq(Seq(inputData), OffsetSeqMetadata()).toString + }, "incorrect start offset or end offset on exception") ) } - testQuietly("status, lastProgress, and recentProgresses") { + testQuietly("status, lastProgress, and recentProgress") { import StreamingQuerySuite._ clock = new StreamManualClock @@ -159,7 +205,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), - AssertOnQuery(_.recentProgresses.count(_.numInputRows > 0) === 0), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while 
offset is being fetched AddData(inputData, 1, 2), @@ -168,7 +214,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), - AssertOnQuery(_.recentProgresses.count(_.numInputRows > 0) === 0), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being fetched AdvanceManualClock(200), // time = 300 to unblock getOffset, will block on getBatch @@ -176,14 +222,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), - AssertOnQuery(_.recentProgresses.count(_.numInputRows > 0) === 0), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch is being processed AdvanceManualClock(300), // time = 600 to unblock getBatch, will block in Spark job AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), - AssertOnQuery(_.recentProgresses.count(_.numInputRows > 0) === 0), + AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed AdvanceManualClock(500), // time = 1100 to unblock job @@ -194,14 +240,14 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery { query => assert(query.lastProgress != null) - assert(query.recentProgresses.exists(_.numInputRows > 0)) - assert(query.recentProgresses.last.eq(query.lastProgress)) + assert(query.recentProgress.exists(_.numInputRows > 0)) + 
assert(query.recentProgress.last.eq(query.lastProgress)) val progress = query.lastProgress assert(progress.id === query.id) assert(progress.name === query.name) assert(progress.batchId === 0) - assert(progress.timestamp === 100) + assert(progress.timestamp === "1970-01-01T00:00:00.100Z") // 100 ms in UTC assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 2.0) @@ -232,7 +278,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery { query => - assert(query.recentProgresses.last.eq(query.lastProgress)) + assert(query.recentProgress.last.eq(query.lastProgress)) assert(query.lastProgress.batchId === 1) assert(query.lastProgress.sources(0).inputRowsPerSecond === 1.818) true @@ -260,19 +306,37 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { StartStream(ProcessingTime(100), triggerClock = clock), AddData(inputData, 0), AdvanceManualClock(100), - ExpectFailure[SparkException], + ExpectFailure[SparkException](), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === false), AssertOnQuery(_.status.message.startsWith("Terminated with exception")) ) } + test("lastProgress should be null when recentProgress is empty") { + BlockingSource.latch = new CountDownLatch(1) + withTempDir { tempDir => + val sq = spark.readStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .load() + .writeStream + .format("org.apache.spark.sql.streaming.util.BlockingSource") + .option("checkpointLocation", tempDir.toString) + .start() + // Creating source is blocked so recentProgress is empty and lastProgress should be null + assert(sq.lastProgress === null) + // Release the latch and stop the query + BlockingSource.latch.countDown() + sq.stop() + } + } + test("codahale metrics") { val inputData = MemoryStream[Int] /** Whether 
metrics of a query is registered for reporting */ def isMetricsRegistered(query: StreamingQuery): Boolean = { - val sourceName = s"spark.streaming.${query.name}" + val sourceName = s"spark.streaming.${query.id}" val sources = spark.sparkContext.env.metricsSystem.getSourcesByName(sourceName) require(sources.size <= 1) sources.nonEmpty @@ -327,25 +391,94 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { testQuietly("StreamExecution metadata garbage collection") { val inputData = MemoryStream[Int] val mapped = inputData.toDS().map(6 / _) + withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "1") { + // Run 3 batches, and then assert that only 2 metadata files is are at the end + // since the first should have been purged. + testStream(mapped)( + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + AddData(inputData, 1, 2), + CheckAnswer(6, 3, 6, 3), + AddData(inputData, 4, 6), + CheckAnswer(6, 3, 6, 3, 1, 1), + + AssertOnQuery("metadata log should contain only two files") { q => + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) + val toTest = logFileNames.filter(!_.endsWith(".crc")).sorted // Workaround for SPARK-17475 + assert(toTest.size == 2 && toTest.head == "1") + true + } + ) + } - // Run 3 batches, and then assert that only 2 metadata files is are at the end - // since the first should have been purged. - testStream(mapped)( - AddData(inputData, 1, 2), - CheckAnswer(6, 3), - AddData(inputData, 1, 2), - CheckAnswer(6, 3, 6, 3), - AddData(inputData, 4, 6), - CheckAnswer(6, 3, 6, 3, 1, 1), - - AssertOnQuery("metadata log should contain only two files") { q => - val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) - val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) - val toTest = logFileNames.filter(! 
_.endsWith(".crc")).sorted // Workaround for SPARK-17475 - assert(toTest.size == 2 && toTest.head == "1") - true + val inputData2 = MemoryStream[Int] + withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2") { + // Run 5 batches, and then assert that 3 metadata files is are at the end + // since the two should have been purged. + testStream(inputData2.toDS())( + AddData(inputData2, 1, 2), + CheckAnswer(1, 2), + AddData(inputData2, 1, 2), + CheckAnswer(1, 2, 1, 2), + AddData(inputData2, 3, 4), + CheckAnswer(1, 2, 1, 2, 3, 4), + AddData(inputData2, 5, 6), + CheckAnswer(1, 2, 1, 2, 3, 4, 5, 6), + AddData(inputData2, 7, 8), + CheckAnswer(1, 2, 1, 2, 3, 4, 5, 6, 7, 8), + + AssertOnQuery("metadata log should contain three files") { q => + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) + val toTest = logFileNames.filter(!_.endsWith(".crc")).sorted // Workaround for SPARK-17475 + assert(toTest.size == 3 && toTest.head == "2") + true + } + ) + } + } + + test("StreamingQuery should be Serializable but cannot be used in executors") { + def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = { + ds.writeStream + .queryName(queryName) + .format("memory") + .start() + } + + val input = MemoryStream[Int] + val q1 = startQuery(input.toDS, "stream_serializable_test_1") + val q2 = startQuery(input.toDS.map { i => + // Emulate that `StreamingQuery` get captured with normal usage unintentionally. + // It should not fail the query. + q1 + i + }, "stream_serializable_test_2") + val q3 = startQuery(input.toDS.map { i => + // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear + // error message. 
+ q1.explain() + i + }, "stream_serializable_test_3") + try { + input.addData(1) + + // q2 should not fail since it doesn't use `q1` in the closure + q2.processAllAvailable() + + // The user calls `StreamingQuery` in the closure and it should fail + val e = intercept[StreamingQueryException] { + q3.processAllAvailable() } - ) + assert(e.getCause.isInstanceOf[SparkException]) + assert(e.getCause.getCause.isInstanceOf[IllegalStateException]) + assert(e.getMessage.contains("StreamingQuery cannot be used in executors")) + } finally { + q1.stop() + q2.stop() + q3.stop() + } } /** Create a streaming DF that only execute one batch in which it returns the given static DF */ @@ -366,7 +499,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { try { val q = streamingDF.writeStream.format("memory").queryName("test").start() q.processAllAvailable() - q.recentProgresses.head + q.recentProgress.head } finally { spark.streams.active.map(_.stop()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 0eb95a0243..097dd6e367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -23,12 +23,14 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.mockito.Mockito._ -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} +import org.scalatest.PrivateMethodTester.PrivateMethod import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, StreamingQuery, StreamTest} 
+import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -104,7 +106,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } } -class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { +class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter with PrivateMethodTester { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -338,7 +340,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .start() q.stop() - assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) + assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.trigger == ProcessingTime(10000)) q = df.writeStream .format("org.apache.spark.sql.streaming.test") @@ -347,7 +349,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { .start() q.stop() - assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) + assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.trigger == ProcessingTime(100000)) } test("source metadataPath") { @@ -387,19 +389,40 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath - test("check outputMode(string) throws exception on unsupported modes") { - def testError(outputMode: String): Unit = { + test("supported strings in outputMode(string)") { + val outputModeMethod = PrivateMethod[OutputMode]('outputMode) + + def testMode(outputMode: String, expected: OutputMode): Unit = { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + val w = df.writeStream + w.outputMode(outputMode) + val setOutputMode = w invokePrivate outputModeMethod() + assert(setOutputMode === expected) + } + + testMode("append", OutputMode.Append) + testMode("Append", OutputMode.Append) + testMode("complete", 
OutputMode.Complete) + testMode("Complete", OutputMode.Complete) + testMode("update", OutputMode.Update) + testMode("Update", OutputMode.Update) + } + + test("unsupported strings in outputMode(string)") { + def testMode(outputMode: String): Unit = { + val acceptedModes = Seq("append", "update", "complete") val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .load() val w = df.writeStream val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) - Seq("output mode", "unknown", outputMode).foreach { s => + (Seq("output mode", "unknown", outputMode) ++ acceptedModes).foreach { s => assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) } } - testError("Update") - testError("Xyz") + testMode("Xyz") } test("check foreach() catches null writers") { @@ -469,24 +492,22 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { sq.stop() } - test("MemorySink can recover from a checkpoint in Complete Mode") { + private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) val df = ms.toDF().toDF("a") - val checkpointLoc = newMetadataDir - val checkpointDir = new File(checkpointLoc, "offsets") - checkpointDir.mkdirs() - assert(checkpointDir.exists()) val tableName = "test" def startQuery: StreamingQuery = { - df.groupBy("a") + val writer = df.groupBy("a") .count() .writeStream .format("memory") .queryName(tableName) - .option("checkpointLocation", checkpointLoc) .outputMode("complete") - .start() + if (provideInWriter) { + writer.option("checkpointLocation", chkLoc) + } + writer.start() } // no exception here val q = startQuery @@ -512,6 +533,24 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { q2.stop() } + test("MemorySink can recover from a checkpoint in Complete Mode") { + val checkpointLoc = newMetadataDir + val checkpointDir = new File(checkpointLoc, "offsets") + checkpointDir.mkdirs() + 
assert(checkpointDir.exists()) + testMemorySinkCheckpointRecovery(checkpointLoc, provideInWriter = true) + } + + test("SPARK-18927: MemorySink can recover from a checkpoint provided in conf in Complete Mode") { + val checkpointLoc = newMetadataDir + val checkpointDir = new File(checkpointLoc, "offsets") + checkpointDir.mkdirs() + assert(checkpointDir.exists()) + withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointLoc) { + testMemorySinkCheckpointRecovery(checkpointLoc, provideInWriter = false) + } + } + test("append mode memory sink's do not support checkpoint recovery") { import testImplicits._ val ms = new MemoryStream[Int](0, sqlContext) @@ -575,4 +614,59 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { sq.stop() } } + + test("user specified checkpointLocation precedes SQLConf") { + import testImplicits._ + withTempDir { checkpointPath => + withTempPath { userCheckpointPath => + assert(!userCheckpointPath.exists(), s"$userCheckpointPath should not exist") + withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointPath.getAbsolutePath) { + val queryName = "test_query" + val ds = MemoryStream[Int].toDS + ds.writeStream + .format("memory") + .queryName(queryName) + .option("checkpointLocation", userCheckpointPath.getAbsolutePath) + .start() + .stop() + assert(checkpointPath.listFiles().isEmpty, + "SQLConf path is used even if user specified checkpointLoc: " + + s"${checkpointPath.listFiles().toList} is not empty") + assert(userCheckpointPath.exists(), + s"The user specified checkpointLoc ($userCheckpointPath) is not created") + } + } + } + } + + test("use SQLConf checkpoint dir when checkpointLocation is not specified") { + import testImplicits._ + withTempDir { checkpointPath => + withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointPath.getAbsolutePath) { + val queryName = "test_query" + val ds = MemoryStream[Int].toDS + ds.writeStream.format("memory").queryName(queryName).start().stop() + // Should use query name to create a folder
in `checkpointPath` + val queryCheckpointDir = new File(checkpointPath, queryName) + assert(queryCheckpointDir.exists(), s"$queryCheckpointDir doesn't exist") + assert( + checkpointPath.listFiles().size === 1, + s"${checkpointPath.listFiles().toList} has 0 or more than 1 files ") + } + } + } + + test("use SQLConf checkpoint dir when checkpointLocation is not specified without query name") { + import testImplicits._ + withTempDir { checkpointPath => + withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointPath.getAbsolutePath) { + val ds = MemoryStream[Int].toDS + ds.writeStream.format("console").start().stop() + // Should create a random folder in `checkpointPath` + assert( + checkpointPath.listFiles().size === 1, + s"${checkpointPath.listFiles().toList} has 0 or more than 1 files ") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala new file mode 100644 index 0000000000..19ab2ff13e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/BlockingSource.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming.util + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.sql.{SQLContext, _} +import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source} +import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** Dummy provider: returns a SourceProvider with a blocking `createSource` call. */ +class BlockingSource extends StreamSourceProvider with StreamSinkProvider { + + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + spark: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + ("dummySource", fakeSchema) + } + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + BlockingSource.latch.await() + new Source { + override def schema: StructType = fakeSchema + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + import spark.implicits._ + Seq[Int]().toDS().toDF() + } + override def stop() {} + } + } + + override def createSink( + spark: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + new Sink { + override def addBatch(batchId: Long, data: DataFrame): Unit = {} + } + } +} + +object BlockingSource { + var latch: CountDownLatch = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index e0887e0f1c..4bec2e3fdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -108,16 +108,14 @@ class DefaultSourceWithoutUserSpecifiedSchema } class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { - + import testImplicits._ private val userSchema = new StructType().add("s", StringType) private val textSchema = new StructType().add("value", StringType) private val data = Seq("1", "2", "3") private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath - private implicit var enc: Encoder[String] = _ before { - enc = spark.implicits.newStringEncoder Utils.deleteRecursively(new File(dir)) } @@ -459,8 +457,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } test("column nullability and comment - write and then read") { - import testImplicits._ - Seq("json", "parquet", "csv").foreach { format => val schema = StructType( StructField("cl1", IntegerType, nullable = false).withComment("test") :: @@ -576,7 +572,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be test("SPARK-18510: use user specified types for partition columns in file sources") { import org.apache.spark.sql.functions.udf - import testImplicits._ withTempDir { src => val createArray = udf { (length: Long) => for (i <- 1 to length.toInt) yield i.toString @@ -609,4 +604,35 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be ) } } + + test("SPARK-18899: append to a bucketed table using DataFrameWriter with mismatched bucketing") { + withTable("t") { + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.bucketBy(2, "i").saveAsTable("t") + val e = intercept[AnalysisException] { + Seq(3 -> "c").toDF("i", "j").write.bucketBy(3, "i").mode("append").saveAsTable("t") + } + assert(e.message.contains("Specified bucketing does not match that of the existing table")) + } + } + + test("SPARK-18912: number of columns mismatch for non-file-based data source table") { 
+ withTable("t") { + sql("CREATE TABLE t USING org.apache.spark.sql.test.DefaultSource") + + val e = intercept[AnalysisException] { + Seq(1 -> "a").toDF("a", "b").write + .format("org.apache.spark.sql.test.DefaultSource") + .mode("append").saveAsTable("t") + } + assert(e.message.contains("The column number of the existing table")) + } + } + + test("SPARK-18913: append to a table with special column names") { + withTable("t") { + Seq(1 -> "a").toDF("x.x", "y.y").write.saveAsTable("t") + Seq(2 -> "b").toDF("x.x", "y.y").write.mode("append").saveAsTable("t") + checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index db24ee8b46..2239f10870 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -48,14 +48,18 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { */ protected implicit def sqlContext: SQLContext = _spark.sqlContext + protected def createSparkSession: TestSparkSession = { + new TestSparkSession( + sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + /** * Initialize the [[TestSparkSession]]. 
*/ protected override def beforeAll(): Unit = { SparkSession.sqlListener.set(null) if (_spark == null) { - _spark = new TestSparkSession( - sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + _spark = createSparkSession } // Ensure we have initialized the context before calling parent code super.beforeAll() diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 819897cd46..9c879218dd 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -85,6 +85,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + net.sf.jpam jpam diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 2be99cb104..9aedaf234e 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../../pom.xml @@ -60,6 +60,8 @@ org.apache.spark spark-tags_${scala.binary.version} + test-jar + test + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + com.google.guava diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala b/streaming/src/main/java/org/apache/spark/status/api/v1/streaming/BatchStatus.java similarity index 55% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala rename to streaming/src/main/java/org/apache/spark/status/api/v1/streaming/BatchStatus.java index c7139c588d..1bbca5a225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala +++ b/streaming/src/main/java/org/apache/spark/status/api/v1/streaming/BatchStatus.java @@ -15,21 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.streaming +package org.apache.spark.status.api.v1.streaming; -import org.apache.spark.sql.execution.streaming.StreamExecutionMetadata +import org.apache.spark.util.EnumUtil; -class StreamExecutionMetadataSuite extends StreamTest { +public enum BatchStatus { + COMPLETED, + QUEUED, + PROCESSING; - test("stream execution metadata") { - assert(StreamExecutionMetadata(0, 0) === - StreamExecutionMetadata("""{}""")) - assert(StreamExecutionMetadata(1, 0) === - StreamExecutionMetadata("""{"batchWatermarkMs":1}""")) - assert(StreamExecutionMetadata(0, 2) === - StreamExecutionMetadata("""{"batchTimestampMs":2}""")) - assert(StreamExecutionMetadata(1, 2) === - StreamExecutionMetadata( - """{"batchWatermarkMs":1,"batchTimestampMs":2}""")) + public static BatchStatus fromString(String str) { + return EnumUtil.parseIgnoreCase(BatchStatus.class, str); } } diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala new file mode 100644 index 0000000000..3a51ae6093 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllBatchesResource.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.streaming + +import java.util.{ArrayList => JArrayList, Arrays => JArrays, Date, List => JList} +import javax.ws.rs.{GET, Produces, QueryParam} +import javax.ws.rs.core.MediaType + +import org.apache.spark.status.api.v1.streaming.AllBatchesResource._ +import org.apache.spark.streaming.ui.StreamingJobProgressListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class AllBatchesResource(listener: StreamingJobProgressListener) { + + @GET + def batchesList(@QueryParam("status") statusParams: JList[BatchStatus]): Seq[BatchInfo] = { + batchInfoList(listener, statusParams).sortBy(- _.batchId) + } +} + +private[v1] object AllBatchesResource { + + def batchInfoList( + listener: StreamingJobProgressListener, + statusParams: JList[BatchStatus] = new JArrayList[BatchStatus]()): Seq[BatchInfo] = { + + listener.synchronized { + val statuses = + if (statusParams.isEmpty) JArrays.asList(BatchStatus.values(): _*) else statusParams + val statusToBatches = Seq( + BatchStatus.COMPLETED -> listener.retainedCompletedBatches, + BatchStatus.QUEUED -> listener.waitingBatches, + BatchStatus.PROCESSING -> listener.runningBatches + ) + + val batchInfos = for { + (status, batches) <- statusToBatches + batch <- batches if statuses.contains(status) + } yield { + val batchId = batch.batchTime.milliseconds + val firstFailureReason = batch.outputOperations.flatMap(_._2.failureReason).headOption + + new BatchInfo( + batchId = batchId, + batchTime = new Date(batchId), + status = status.toString, + batchDuration = listener.batchDuration, + inputSize = batch.numRecords, + schedulingDelay = batch.schedulingDelay, + processingTime = batch.processingDelay, + totalDelay = batch.totalDelay, + numActiveOutputOps = batch.numActiveOutputOp, + numCompletedOutputOps = batch.numCompletedOutputOp, + numFailedOutputOps = batch.numFailedOutputOp, + 
numTotalOutputOps = batch.outputOperations.size, + firstFailureReason = firstFailureReason + ) + } + + batchInfos + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala new file mode 100644 index 0000000000..0eb649f0e1 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllOutputOperationsResource.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status.api.v1.streaming + +import java.util.Date +import javax.ws.rs.{GET, PathParam, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.status.api.v1.streaming.AllOutputOperationsResource._ +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.ui.StreamingJobProgressListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class AllOutputOperationsResource(listener: StreamingJobProgressListener) { + + @GET + def operationsList(@PathParam("batchId") batchId: Long): Seq[OutputOperationInfo] = { + outputOperationInfoList(listener, batchId).sortBy(_.outputOpId) + } +} + +private[v1] object AllOutputOperationsResource { + + def outputOperationInfoList( + listener: StreamingJobProgressListener, + batchId: Long): Seq[OutputOperationInfo] = { + + listener.synchronized { + listener.getBatchUIData(Time(batchId)) match { + case Some(batch) => + for ((opId, op) <- batch.outputOperations) yield { + val jobIds = batch.outputOpIdSparkJobIdPairs + .filter(_.outputOpId == opId).map(_.sparkJobId).toSeq.sorted + + new OutputOperationInfo( + outputOpId = opId, + name = op.name, + description = op.description, + startTime = op.startTime.map(new Date(_)), + endTime = op.endTime.map(new Date(_)), + duration = op.duration, + failureReason = op.failureReason, + jobIds = jobIds + ) + } + case None => throw new NotFoundException("unknown batch: " + batchId) + } + }.toSeq + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala new file mode 100644 index 0000000000..5a276a9236 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/AllReceiversResource.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.streaming + +import java.util.Date +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.status.api.v1.streaming.AllReceiversResource._ +import org.apache.spark.streaming.ui.StreamingJobProgressListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class AllReceiversResource(listener: StreamingJobProgressListener) { + + @GET + def receiversList(): Seq[ReceiverInfo] = { + receiverInfoList(listener).sortBy(_.streamId) + } +} + +private[v1] object AllReceiversResource { + + def receiverInfoList(listener: StreamingJobProgressListener): Seq[ReceiverInfo] = { + listener.synchronized { + listener.receivedRecordRateWithBatchTime.map { case (streamId, eventRates) => + + val receiverInfo = listener.receiverInfo(streamId) + val streamName = receiverInfo.map(_.name) + .orElse(listener.streamName(streamId)).getOrElse(s"Stream-$streamId") + val avgEventRate = + if (eventRates.isEmpty) None else Some(eventRates.map(_._2).sum / eventRates.size) + + val (errorTime, errorMessage, error) = receiverInfo match { + case None => (None, None, None) + case Some(info) => + val someTime = + if (info.lastErrorTime >= 0) Some(new Date(info.lastErrorTime)) else None + val someMessage = + if 
(info.lastErrorMessage.length > 0) Some(info.lastErrorMessage) else None + val someError = + if (info.lastError.length > 0) Some(info.lastError) else None + + (someTime, someMessage, someError) + } + + new ReceiverInfo( + streamId = streamId, + streamName = streamName, + isActive = receiverInfo.map(_.active), + executorId = receiverInfo.map(_.executorId), + executorHost = receiverInfo.map(_.location), + lastErrorTime = errorTime, + lastErrorMessage = errorMessage, + lastError = error, + avgEventRate = avgEventRate, + eventRates = eventRates + ) + }.toSeq + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala new file mode 100644 index 0000000000..e64830a945 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingApp.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status.api.v1.streaming + +import javax.ws.rs.{Path, PathParam} + +import org.apache.spark.status.api.v1.UIRootFromServletContext + +@Path("/v1") +private[v1] class ApiStreamingApp extends UIRootFromServletContext { + + @Path("applications/{appId}/streaming") + def getStreamingRoot(@PathParam("appId") appId: String): ApiStreamingRootResource = { + uiRoot.withSparkUI(appId, None) { ui => + new ApiStreamingRootResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/streaming") + def getStreamingRoot( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): ApiStreamingRootResource = { + uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + new ApiStreamingRootResource(ui) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala new file mode 100644 index 0000000000..1ccd586c84 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/ApiStreamingRootResource.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status.api.v1.streaming + +import javax.ws.rs.Path + +import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.streaming.ui.StreamingJobProgressListener +import org.apache.spark.ui.SparkUI + +private[v1] class ApiStreamingRootResource(ui: SparkUI) { + + import org.apache.spark.status.api.v1.streaming.ApiStreamingRootResource._ + + @Path("statistics") + def getStreamingStatistics(): StreamingStatisticsResource = { + new StreamingStatisticsResource(getListener(ui)) + } + + @Path("receivers") + def getReceivers(): AllReceiversResource = { + new AllReceiversResource(getListener(ui)) + } + + @Path("receivers/{streamId: \\d+}") + def getReceiver(): OneReceiverResource = { + new OneReceiverResource(getListener(ui)) + } + + @Path("batches") + def getBatches(): AllBatchesResource = { + new AllBatchesResource(getListener(ui)) + } + + @Path("batches/{batchId: \\d+}") + def getBatch(): OneBatchResource = { + new OneBatchResource(getListener(ui)) + } + + @Path("batches/{batchId: \\d+}/operations") + def getOutputOperations(): AllOutputOperationsResource = { + new AllOutputOperationsResource(getListener(ui)) + } + + @Path("batches/{batchId: \\d+}/operations/{outputOpId: \\d+}") + def getOutputOperation(): OneOutputOperationResource = { + new OneOutputOperationResource(getListener(ui)) + } + +} + +private[v1] object ApiStreamingRootResource { + def getListener(ui: SparkUI): StreamingJobProgressListener = { + ui.getStreamingJobProgressListener match { + case Some(listener) => listener.asInstanceOf[StreamingJobProgressListener] + case None => throw new NotFoundException("no streaming listener attached to " + ui.getAppName) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala new file mode 100644 index 0000000000..d3c689c790 --- /dev/null +++ 
b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneBatchResource.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.streaming + +import javax.ws.rs.{GET, PathParam, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.streaming.ui.StreamingJobProgressListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneBatchResource(listener: StreamingJobProgressListener) { + + @GET + def oneBatch(@PathParam("batchId") batchId: Long): BatchInfo = { + val someBatch = AllBatchesResource.batchInfoList(listener) + .find { _.batchId == batchId } + someBatch.getOrElse(throw new NotFoundException("unknown batch: " + batchId)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala new file mode 100644 index 0000000000..aabcdb29b0 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneOutputOperationResource.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.streaming + +import javax.ws.rs.{GET, PathParam, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.streaming.ui.StreamingJobProgressListener +import org.apache.spark.streaming.ui.StreamingJobProgressListener._ + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneOutputOperationResource(listener: StreamingJobProgressListener) { + + @GET + def oneOperation( + @PathParam("batchId") batchId: Long, + @PathParam("outputOpId") opId: OutputOpId): OutputOperationInfo = { + + val someOutputOp = AllOutputOperationsResource.outputOperationInfoList(listener, batchId) + .find { _.outputOpId == opId } + someOutputOp.getOrElse(throw new NotFoundException("unknown output operation: " + opId)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala new file mode 100644 index 0000000000..c0cc99da3a --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/OneReceiverResource.scala @@ -0,0 +1,35 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.streaming + +import javax.ws.rs.{GET, PathParam, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.status.api.v1.NotFoundException +import org.apache.spark.streaming.ui.StreamingJobProgressListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class OneReceiverResource(listener: StreamingJobProgressListener) { + + @GET + def oneReceiver(@PathParam("streamId") streamId: Int): ReceiverInfo = { + val someReceiver = AllReceiversResource.receiverInfoList(listener) + .find { _.streamId == streamId } + someReceiver.getOrElse(throw new NotFoundException("unknown receiver: " + streamId)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala new file mode 100644 index 0000000000..6cff87be59 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/StreamingStatisticsResource.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status.api.v1.streaming + +import java.util.Date +import javax.ws.rs.{GET, Produces} +import javax.ws.rs.core.MediaType + +import org.apache.spark.streaming.ui.StreamingJobProgressListener + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class StreamingStatisticsResource(listener: StreamingJobProgressListener) { + + @GET + def streamingStatistics(): StreamingStatistics = { + listener.synchronized { + val batches = listener.retainedBatches + val avgInputRate = avgRate(batches.map(_.numRecords * 1000.0 / listener.batchDuration)) + val avgSchedulingDelay = avgTime(batches.flatMap(_.schedulingDelay)) + val avgProcessingTime = avgTime(batches.flatMap(_.processingDelay)) + val avgTotalDelay = avgTime(batches.flatMap(_.totalDelay)) + + new StreamingStatistics( + startTime = new Date(listener.startTime), + batchDuration = listener.batchDuration, + numReceivers = listener.numReceivers, + numActiveReceivers = listener.numActiveReceivers, + numInactiveReceivers = listener.numInactiveReceivers, + numTotalCompletedBatches = listener.numTotalCompletedBatches, + numRetainedCompletedBatches = listener.retainedCompletedBatches.size, + numActiveBatches = listener.numUnprocessedBatches, + numProcessedRecords = listener.numTotalProcessedRecords, + 
numReceivedRecords = listener.numTotalReceivedRecords, + avgInputRate = avgInputRate, + avgSchedulingDelay = avgSchedulingDelay, + avgProcessingTime = avgProcessingTime, + avgTotalDelay = avgTotalDelay + ) + } + } + + private def avgRate(data: Seq[Double]): Option[Double] = { + if (data.isEmpty) None else Some(data.sum / data.size) + } + + private def avgTime(data: Seq[Long]): Option[Long] = { + if (data.isEmpty) None else Some(data.sum / data.size) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/api.scala b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/api.scala new file mode 100644 index 0000000000..403b0eb0b5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/status/api/v1/streaming/api.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.status.api.v1.streaming + +import java.util.Date + +import org.apache.spark.streaming.ui.StreamingJobProgressListener._ + +class StreamingStatistics private[spark]( + val startTime: Date, + val batchDuration: Long, + val numReceivers: Int, + val numActiveReceivers: Int, + val numInactiveReceivers: Int, + val numTotalCompletedBatches: Long, + val numRetainedCompletedBatches: Long, + val numActiveBatches: Long, + val numProcessedRecords: Long, + val numReceivedRecords: Long, + val avgInputRate: Option[Double], + val avgSchedulingDelay: Option[Long], + val avgProcessingTime: Option[Long], + val avgTotalDelay: Option[Long]) + +class ReceiverInfo private[spark]( + val streamId: Int, + val streamName: String, + val isActive: Option[Boolean], + val executorId: Option[String], + val executorHost: Option[String], + val lastErrorTime: Option[Date], + val lastErrorMessage: Option[String], + val lastError: Option[String], + val avgEventRate: Option[Double], + val eventRates: Seq[(Long, Double)]) + +class BatchInfo private[spark]( + val batchId: Long, + val batchTime: Date, + val status: String, + val batchDuration: Long, + val inputSize: Long, + val schedulingDelay: Option[Long], + val processingTime: Option[Long], + val totalDelay: Option[Long], + val numActiveOutputOps: Int, + val numCompletedOutputOps: Int, + val numFailedOutputOps: Int, + val numTotalOutputOps: Int, + val firstFailureReason: Option[String]) + +class OutputOperationInfo private[spark]( + val outputOpId: OutputOpId, + val name: String, + val description: String, + val startTime: Option[Date], + val endTime: Option[Date], + val duration: Option[Long], + val failureReason: Option[String], + val jobIds: Seq[SparkJobId]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 444261da8d..0a4c141e5b 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -45,7 +45,8 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.{ExecutorAllocationManager, JobScheduler, StreamingListener} +import org.apache.spark.streaming.scheduler. + {ExecutorAllocationManager, JobScheduler, StreamingListener, StreamingListenerStreamingStarted} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} @@ -583,6 +584,8 @@ class StreamingContext private[streaming] ( scheduler.start() } state = StreamingContextState.ACTIVE + scheduler.listenerBus.post( + StreamingListenerStreamingStarted(System.currentTimeMillis())) } catch { case NonFatal(e) => logError("Error starting the context, marking it as stopped", e) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala index db0bae9958..28cb86c9f3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -21,6 +21,9 @@ import org.apache.spark.streaming.Time private[streaming] trait PythonStreamingListener{ + /** Called when the streaming has been started */ + def onStreamingStarted(streamingStarted: JavaStreamingListenerStreamingStarted) { } + /** Called when a receiver has been started */ def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted) { } @@ -51,6 +54,11 @@ private[streaming] trait PythonStreamingListener{ private[streaming] class 
PythonStreamingListenerWrapper(listener: PythonStreamingListener) extends JavaStreamingListener { + /** Called when the streaming has been started */ + override def onStreamingStarted(streamingStarted: JavaStreamingListenerStreamingStarted): Unit = { + listener.onStreamingStarted(streamingStarted) + } + /** Called when a receiver has been started */ override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { listener.onReceiverStarted(receiverStarted) @@ -99,6 +107,9 @@ private[streaming] class PythonStreamingListenerWrapper(listener: PythonStreamin */ private[streaming] class JavaStreamingListener { + /** Called when the streaming has been started */ + def onStreamingStarted(streamingStarted: JavaStreamingListenerStreamingStarted): Unit = { } + /** Called when a receiver has been started */ def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { } @@ -131,6 +142,9 @@ private[streaming] class JavaStreamingListener { */ private[streaming] sealed trait JavaStreamingListenerEvent +private[streaming] class JavaStreamingListenerStreamingStarted(val time: Long) + extends JavaStreamingListenerEvent + private[streaming] class JavaStreamingListenerBatchSubmitted(val batchInfo: JavaBatchInfo) extends JavaStreamingListenerEvent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala index b109b9f1cb..ee8370d262 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -77,6 +77,11 @@ private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: Jav ) } + override def onStreamingStarted(streamingStarted: StreamingListenerStreamingStarted): Unit = { + javaStreamingListener.onStreamingStarted( + 
new JavaStreamingListenerStreamingStarted(streamingStarted.time)) + } + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { javaStreamingListener.onReceiverStarted( new JavaStreamingListenerReceiverStarted(toJavaReceiverInfo(receiverStarted.receiverInfo))) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 58fc78d552..b57f9b772f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -29,6 +29,9 @@ import org.apache.spark.util.Distribution @DeveloperApi sealed trait StreamingListenerEvent +@DeveloperApi +case class StreamingListenerStreamingStarted(time: Long) extends StreamingListenerEvent + @DeveloperApi case class StreamingListenerBatchSubmitted(batchInfo: BatchInfo) extends StreamingListenerEvent @@ -66,6 +69,9 @@ case class StreamingListenerReceiverStopped(receiverInfo: ReceiverInfo) @DeveloperApi trait StreamingListener { + /** Called when the streaming has been started */ + def onStreamingStarted(streamingStarted: StreamingListenerStreamingStarted) { } + /** Called when a receiver has been started */ def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 39f6e711a6..5fb0bd057d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -65,6 +65,8 @@ private[streaming] class StreamingListenerBus(sparkListenerBus: LiveListenerBus) listener.onOutputOperationStarted(outputOperationStarted) 
case outputOperationCompleted: StreamingListenerOutputOperationCompleted => listener.onOutputOperationCompleted(outputOperationCompleted) + case streamingStarted: StreamingListenerStreamingStarted => + listener.onStreamingStarted(streamingStarted) case _ => } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 61f852a0d3..95f582106c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -27,7 +27,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.scheduler._ -private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) +private[spark] class StreamingJobProgressListener(ssc: StreamingContext) extends SparkListener with StreamingListener { private val waitingBatchUIData = new HashMap[Time, BatchUIData] @@ -39,6 +39,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) private var totalProcessedRecords = 0L private val receiverInfos = new HashMap[Int, ReceiverInfo] + private var _startTime = -1L + // Because onJobStart and onBatchXXX messages are processed in different threads, // we may not be able to get the corresponding BatchUIData when receiving onJobStart. So here we // cannot use a map of (Time, BatchUIData). 
@@ -66,6 +68,10 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) val batchDuration = ssc.graph.batchDuration.milliseconds + override def onStreamingStarted(streamingStarted: StreamingListenerStreamingStarted) { + _startTime = streamingStarted.time + } + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { synchronized { receiverInfos(receiverStarted.receiverInfo.streamId) = receiverStarted.receiverInfo @@ -152,6 +158,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } + def startTime: Long = _startTime + def numReceivers: Int = synchronized { receiverInfos.size } @@ -267,7 +275,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) } } -private[streaming] object StreamingJobProgressListener { +private[spark] object StreamingJobProgressListener { type SparkJobId = Int type OutputOpId = Int } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 46cd3092e9..7abafd6ba7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -143,7 +143,8 @@ private[ui] class StreamingPage(parent: StreamingTab) import StreamingPage._ private val listener = parent.listener - private val startTime = System.currentTimeMillis() + + private def startTime: Long = listener.startTime /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index c5f8aada3f..9d1b82a634 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -38,6 +38,7 @@ private[spark] 
class StreamingTab(val ssc: StreamingContext) ssc.addStreamingListener(listener) ssc.sc.addSparkListener(listener) + parent.setStreamingJobProgressListener(listener) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java index ff0be820e0..63fd6c4422 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -22,6 +22,11 @@ public class JavaStreamingListenerAPISuite extends JavaStreamingListener { + @Override + public void onStreamingStarted(JavaStreamingListenerStreamingStarted streamingStarted) { + super.onStreamingStarted(streamingStarted); + } + @Override public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStarted) { JavaReceiverInfo receiverInfo = receiverStarted.receiverInfo(); diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala index 0295e059f7..cfd4323531 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -29,6 +29,10 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { val listener = new TestJavaStreamingListener() val listenerWrapper = new JavaStreamingListenerWrapper(listener) + val streamingStarted = StreamingListenerStreamingStarted(1000L) + listenerWrapper.onStreamingStarted(streamingStarted) + assert(listener.streamingStarted.time === streamingStarted.time) + val receiverStarted = StreamingListenerReceiverStarted(ReceiverInfo( streamId = 2, name = "test", @@ 
-249,6 +253,7 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { class TestJavaStreamingListener extends JavaStreamingListener { + var streamingStarted: JavaStreamingListenerStreamingStarted = null var receiverStarted: JavaStreamingListenerReceiverStarted = null var receiverError: JavaStreamingListenerReceiverError = null var receiverStopped: JavaStreamingListenerReceiverStopped = null @@ -258,6 +263,10 @@ class TestJavaStreamingListener extends JavaStreamingListener { var outputOperationStarted: JavaStreamingListenerOutputOperationStarted = null var outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted = null + override def onStreamingStarted(streamingStarted: JavaStreamingListenerStreamingStarted): Unit = { + this.streamingStarted = streamingStarted + } + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { this.receiverStarted = receiverStarted } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 4e702bbb92..a3062ac946 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag +import org.scalatest.concurrent.Eventually.eventually + import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} @@ -657,48 +657,57 @@ class BasicOperationsSuite extends TestSuiteBase { .window(Seconds(4), Seconds(2)) } - val operatedStream = 
runCleanupTest(conf, operation _, - numExpectedOutput = cleanupTestInput.size / 2, rememberDuration = Seconds(3)) - val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]] - val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]] - val mappedStream = windowedStream1.dependencies.head - - // Checkpoint remember durations - assert(windowedStream2.rememberDuration === rememberDuration) - assert(windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration) - assert(mappedStream.rememberDuration === - rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration) - - // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7 - // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 - // MappedStream should remember till 2 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 - - // WindowedStream2 - assert(windowedStream2.generatedRDDs.contains(Time(10000))) - assert(windowedStream2.generatedRDDs.contains(Time(8000))) - assert(!windowedStream2.generatedRDDs.contains(Time(6000))) - - // WindowedStream1 - assert(windowedStream1.generatedRDDs.contains(Time(10000))) - assert(windowedStream1.generatedRDDs.contains(Time(4000))) - assert(!windowedStream1.generatedRDDs.contains(Time(3000))) - - // MappedStream - assert(mappedStream.generatedRDDs.contains(Time(10000))) - assert(mappedStream.generatedRDDs.contains(Time(2000))) - assert(!mappedStream.generatedRDDs.contains(Time(1000))) + runCleanupTest( + conf, + operation _, + numExpectedOutput = cleanupTestInput.size / 2, + rememberDuration = Seconds(3)) { operatedStream => + eventually(eventuallyTimeout) { + val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]] + val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]] + val mappedStream = windowedStream1.dependencies.head + + // Checkpoint remember durations + assert(windowedStream2.rememberDuration === rememberDuration) + assert( + 
windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration) + assert(mappedStream.rememberDuration === + rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration) + + // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7 + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 + // MappedStream should remember till 2 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 + + // WindowedStream2 + assert(windowedStream2.generatedRDDs.contains(Time(10000))) + assert(windowedStream2.generatedRDDs.contains(Time(8000))) + assert(!windowedStream2.generatedRDDs.contains(Time(6000))) + + // WindowedStream1 + assert(windowedStream1.generatedRDDs.contains(Time(10000))) + assert(windowedStream1.generatedRDDs.contains(Time(4000))) + assert(!windowedStream1.generatedRDDs.contains(Time(3000))) + + // MappedStream + assert(mappedStream.generatedRDDs.contains(Time(10000))) + assert(mappedStream.generatedRDDs.contains(Time(2000))) + assert(!mappedStream.generatedRDDs.contains(Time(1000))) + } + } } test("rdd cleanup - updateStateByKey") { val updateFunc = (values: Seq[Int], state: Option[Int]) => { Some(values.sum + state.getOrElse(0)) } - val stateStream = runCleanupTest( - conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3))) - - assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2) - assert(stateStream.generatedRDDs.contains(Time(10000))) - assert(!stateStream.generatedRDDs.contains(Time(4000))) + runCleanupTest( + conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3))) { stateStream => + eventually(eventuallyTimeout) { + assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2) + assert(stateStream.generatedRDDs.contains(Time(10000))) + assert(!stateStream.generatedRDDs.contains(Time(4000))) + } + } } test("rdd cleanup - input blocks and persisted RDDs") { @@ -779,13 +788,16 @@ class BasicOperationsSuite extends TestSuiteBase { } } - /** Test 
cleanup of RDDs in DStream metadata */ + /** + * Test cleanup of RDDs in DStream metadata. `assertCleanup` is the function that asserts the + * cleanup of RDDs is successful. + */ def runCleanupTest[T: ClassTag]( conf2: SparkConf, operation: DStream[Int] => DStream[T], numExpectedOutput: Int = cleanupTestInput.size, rememberDuration: Duration = null - ): DStream[T] = { + )(assertCleanup: (DStream[T]) => Unit): DStream[T] = { // Setup the stream computation assert(batchDuration === Seconds(1), @@ -794,7 +806,11 @@ class BasicOperationsSuite extends TestSuiteBase { val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] if (rememberDuration != null) ssc.remember(rememberDuration) - val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) + val output = runStreams[(Int, Int)]( + ssc, + cleanupTestInput.size, + numExpectedOutput, + () => assertCleanup(operatedStream)) val clock = ssc.scheduler.clock.asInstanceOf[Clock] assert(clock.getTimeMillis() === Seconds(10).milliseconds) assert(output.size === numExpectedOutput) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 9ecfa48091..6fb50a4052 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -67,42 +67,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val expectedOutput = input.map(_.toString) for (i <- input.indices) { testServer.send(input(i).toString + "\n") - Thread.sleep(500) clock.advance(batchDuration.milliseconds) } - // Make sure we finish all batches before "stop" - if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) { - fail("Timeout: cannot finish all batches in 30 seconds") + + eventually(eventuallyTimeout) { + clock.advance(batchDuration.milliseconds) + // Verify 
whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputQueue.size) + logInfo("output") + outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray + assert(output.length === expectedOutput.size) + for (i <- output.indices) { + assert(output(i) === expectedOutput(i)) + } } - // Ensure progress listener has been notified of all events - ssc.sparkContext.listenerBus.waitUntilEmpty(500) - - // Verify all "InputInfo"s have been reported - assert(ssc.progressListener.numTotalReceivedRecords === input.size) - assert(ssc.progressListener.numTotalProcessedRecords === input.size) - - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - ssc.stop() - - // Verify whether data received was as expected - logInfo("--------------------------------") - logInfo("output.size = " + outputQueue.size) - logInfo("output") - outputQueue.asScala.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") - expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("--------------------------------") - - // Verify whether all the elements received are as expected - // (whether the elements were received one in each interval is not verified) - val output: Array[String] = outputQueue.asScala.flatMap(x => x).toArray - assert(output.length === expectedOutput.size) - for (i <- output.indices) { - assert(output(i) === expectedOutput(i)) + eventually(eventuallyTimeout) { + 
assert(ssc.progressListener.numTotalReceivedRecords === input.length) + assert(ssc.progressListener.numTotalProcessedRecords === input.length) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index fa975a1462..dbab708861 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -359,14 +359,20 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. * * Returns a sequence of items for each RDD. + * + * @param ssc The StreamingContext + * @param numBatches The number of batches that should be run + * @param numExpectedOutput The number of expected outputs + * @param preStop The function to run before stopping StreamingContext */ def runStreams[V: ClassTag]( ssc: StreamingContext, numBatches: Int, - numExpectedOutput: Int + numExpectedOutput: Int, + preStop: () => Unit = () => {} ): Seq[Seq[V]] = { // Flatten each RDD into a single Seq - runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq) + runStreamsWithPartitions(ssc, numBatches, numExpectedOutput, preStop).map(_.flatten.toSeq) } /** @@ -376,11 +382,17 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { * * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each * representing one partition. 
+ * + * @param ssc The StreamingContext + * @param numBatches The number of batches should be run + * @param numExpectedOutput The number of expected output + * @param preStop The function to run before stopping StreamingContext */ def runStreamsWithPartitions[V: ClassTag]( ssc: StreamingContext, numBatches: Int, - numExpectedOutput: Int + numExpectedOutput: Int, + preStop: () => Unit = () => {} ): Seq[Seq[Seq[V]]] = { assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") @@ -424,6 +436,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") Thread.sleep(100) // Give some time for the forgetting old RDDs to complete + preStop() } finally { ssc.stop(stopSparkContext = true) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala index b49e579071..1d2bf35a6d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -36,11 +36,11 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite private val batchDurationMillis = 1000L private var allocationClient: ExecutorAllocationClient = null - private var clock: ManualClock = null + private var clock: StreamManualClock = null before { allocationClient = mock[ExecutorAllocationClient] - clock = new ManualClock() + clock = new StreamManualClock() } test("basic functionality") { @@ -57,10 +57,14 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite reset(allocationClient) when(allocationClient.getExecutorIds()).thenReturn(Seq("1", "2")) addBatchProcTime(allocationManager, 
batchProcTimeMs.toLong) - clock.advance(SCALING_INTERVAL_DEFAULT_SECS * 1000 + 1) + val advancedTime = SCALING_INTERVAL_DEFAULT_SECS * 1000 + 1 + val expectedWaitTime = clock.getTimeMillis() + advancedTime + clock.advance(advancedTime) + // Make sure ExecutorAllocationManager.manageAllocation is called eventually(timeout(10 seconds)) { - body + assert(clock.isStreamWaitingAt(expectedWaitTime)) } + body } /** Verify that the expected number of total executor were requested */ @@ -394,3 +398,27 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite } } } + +/** + * A special manual clock that provide `isStreamWaitingAt` to allow the user to check if the clock + * is blocking. + */ +class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { + private var waitStartTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + } + } + + /** + * Returns if the clock is blocking and the time it started to block is the parameter `time`. 
+ */ + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 46ab3ac8de..56b400850f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -62,6 +62,10 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { 0 -> StreamInputInfo(0, 300L), 1 -> StreamInputInfo(1, 300L, Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "test"))) + // onStreamingStarted + listener.onStreamingStarted(StreamingListenerStreamingStarted(100L)) + listener.startTime should be (100) + // onBatchSubmitted val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToInputInfo, 1000, None, None, Map.empty) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) diff --git a/tools/pom.xml b/tools/pom.xml index b9be8db684..938ba2f6ac 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala new file mode 100644 index 0000000000..ffa0b58ee7 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster + +import org.mockito.Mockito.when +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.serializer.JavaSerializer + +class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with LocalSparkContext { + + test("RequestExecutors reflects node blacklist and is serializable") { + sc = new SparkContext("local", "YarnSchedulerBackendSuite") + val sched = mock[TaskSchedulerImpl] + when(sched.sc).thenReturn(sc) + val yarnSchedulerBackend = new YarnSchedulerBackend(sched, sc) { + def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = { + this.hostToLocalTaskCount = hostToLocalTaskCount + } + } + val ser = new JavaSerializer(sc.conf).newInstance() + for { + blacklist <- IndexedSeq(Set[String](), Set("a", "b", "c")) + numRequested <- 0 until 10 + hostToLocalCount <- IndexedSeq( + Map[String, Int](), + Map("a" -> 1, "b" -> 2) + ) + } { + yarnSchedulerBackend.setHostToLocalTaskCount(hostToLocalCount) + when(sched.nodeBlacklist()).thenReturn(blacklist) + val req = yarnSchedulerBackend.prepareRequestExecutors(numRequested) + assert(req.requestedTotal === numRequested) + assert(req.nodeBlacklist === blacklist) + assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty) + // 
Serialize to make sure serialization doesn't throw an error + ser.serialize(req) + } + } + +}