
[SPARK-16421][EXAMPLES][ML] Improve ML Example Outputs

## What changes were proposed in this pull request?
Improve example outputs to better reflect the functionality being presented. This mostly consisted of modifying what is printed at the end of each example, such as calling show() with truncate=False, but sometimes required minor tweaks to the example data to produce relevant output. Parameters are now set explicitly when they are used as part of an example. Also fixed Java examples that failed to run because they used old-style MLlib Vectors or had schema problems, and synced examples across the different language APIs.
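As a concrete illustration of the recurring pattern, here is a minimal Java sketch in the style of the Bucketizer change below (variable names and the dataFrame, assumed to hold a numeric "features" column, are illustrative rather than lifted verbatim from the commit):

    double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
    Bucketizer bucketizer = new Bucketizer()
      .setInputCol("features")
      .setOutputCol("bucketedFeatures")
      .setSplits(splits);  // parameters set explicitly so the example is self-describing

    // Echo the parameter that drives the result, then print untruncated output;
    // show(false) in Java corresponds to show(truncate=False) in Python.
    System.out.println("Bucketizer output with " + (bucketizer.getSplits().length - 1) + " buckets");
    bucketizer.transform(dataFrame).show(false);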

## How was this patch tested?
Ran each example in Scala, Python, and Java and verified that the output was legible on a terminal 100 columns wide.

Author: Bryan Cutler <[email protected]>

Closes #14308 from BryanCutler/ml-examples-improve-output-SPARK-16260.
1 parent 2460f03 · commit 180fd3e0a3426db200c97170926afb60751dfd0e · BryanCutler committed with srowen on Aug 5, 2016
Showing with 427 additions and 2,757 deletions.
  1. +0 −1,000 data/mllib/lr-data/random.data
  2. +0 −1,000 data/mllib/lr_data.txt
  3. +0 −569 data/mllib/sample_tree_data.csv
  4. +5 −0 examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
  5. +3 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
  6. +6 −5 examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
  7. +6 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
  8. +4 −0 examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
  9. +1 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java
  10. +5 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
  11. +2 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
  12. +14 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
  13. +2 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java
  14. +23 −5 examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java
  15. +25 −5 examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
  16. +7 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
  17. +7 −11 examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
  18. +10 −3 examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
  19. +21 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
  20. +3 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
  21. +1 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
  22. +1 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java
  23. +5 −9 examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java
  24. +1 −1 examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java
  25. +3 −0 examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java
  26. +5 −7 examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java
  27. +22 −11 examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java
  28. +4 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java
  29. +2 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java
  30. +7 −2 examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java
  31. +6 −4 examples/src/main/python/ml/binarizer_example.py
  32. +3 −1 examples/src/main/python/ml/bucketizer_example.py
  33. +2 −0 examples/src/main/python/ml/chisq_selector_example.py
  34. +3 −1 examples/src/main/python/ml/count_vectorizer_example.py
  35. +1 −2 examples/src/main/python/ml/dct_example.py
  36. +3 −3 examples/src/main/python/ml/gaussian_mixture_example.py
  37. +11 −3 examples/src/main/python/ml/index_to_string_example.py
  38. +2 −2 examples/src/main/python/ml/isotonic_regression_example.py
  39. +10 −2 examples/src/main/python/ml/linear_regression_with_elastic_net.py
  40. +8 −2 examples/src/main/python/ml/max_abs_scaler_example.py
  41. +8 −2 examples/src/main/python/ml/min_max_scaler_example.py
  42. +1 −1 examples/src/main/python/ml/multilayer_perceptron_classification.py
  43. +4 −5 examples/src/main/python/ml/n_gram_example.py
  44. +8 −4 examples/src/main/python/ml/naive_bayes_example.py
  45. +8 −1 examples/src/main/python/ml/normalizer_example.py
  46. +2 −2 examples/src/main/python/ml/onehot_encoder_example.py
  47. +3 −2 examples/src/main/python/ml/pipeline_example.py
  48. +5 −6 examples/src/main/python/ml/polynomial_expansion_example.py
  49. +1 −1 examples/src/main/python/ml/stopwords_remover_example.py
  50. +4 −5 examples/src/main/python/ml/tf_idf_example.py
  51. +9 −5 examples/src/main/python/ml/tokenizer_example.py
  52. +4 −3 examples/src/main/python/ml/train_validation_split.py
  53. +2 −1 examples/src/main/python/ml/vector_assembler_example.py
  54. +4 −0 examples/src/main/python/ml/vector_indexer_example.py
  55. +3 −2 examples/src/main/python/ml/word2vec_example.py
  56. +5 −2 examples/src/main/python/pagerank.py
  57. +5 −0 examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
  58. +3 −2 examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala
  59. +5 −3 examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala
  60. +4 −1 examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
  61. +3 −0 examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala
  62. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
  63. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala
  64. +2 −2 examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala
  65. +13 −1 examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala
  66. +2 −2 examples/src/main/scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala
  67. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala
  68. +2 −1 examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala
  69. +8 −2 examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala
  70. +8 −2 examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala
  71. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala
  72. +4 −3 examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala
  73. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
  74. +8 −1 examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala
  75. +2 −1 examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala
  76. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
  77. +4 −3 examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala
  78. +7 −5 examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala
  79. +1 −1 examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala
  80. +4 −4 examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala
  81. +8 −3 examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala
  82. +2 −0 examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
  83. +2 −1 examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala
  84. +5 −2 examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala
  85. +4 −1 examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala
examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -45,6 +45,11 @@
*
* This is an example implementation for learning how to use Spark. For more conventional use,
* please refer to org.apache.spark.graphx.lib.PageRank
+ *
+ * Example Usage:
+ * <pre>
+ * bin/run-example JavaPageRank data/mllib/pagerank_data.txt 10
+ * </pre>
*/
public final class JavaPageRank {
private static final Pattern SPACES = Pattern.compile("\\s+");
examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java
@@ -71,8 +71,9 @@ public static void main(String[] args) {
AFTSurvivalRegressionModel model = aft.fit(training);
// Print the coefficients, intercept and scale parameter for AFT survival regression
- System.out.println("Coefficients: " + model.coefficients() + " Intercept: "
- + model.intercept() + " Scale: " + model.scale());
+ System.out.println("Coefficients: " + model.coefficients());
+ System.out.println("Intercept: " + model.intercept());
+ System.out.println("Scale: " + model.scale());
model.transform(training).show(false);
// $example off$
examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java
@@ -51,17 +51,18 @@ public static void main(String[] args) {
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> continuousDataFrame = spark.createDataFrame(data, schema);
+
Binarizer binarizer = new Binarizer()
.setInputCol("feature")
.setOutputCol("binarized_feature")
.setThreshold(0.5);
+
Dataset<Row> binarizedDataFrame = binarizer.transform(continuousDataFrame);
- Dataset<Row> binarizedFeatures = binarizedDataFrame.select("binarized_feature");
- for (Row r : binarizedFeatures.collectAsList()) {
- Double binarized_value = r.getDouble(0);
- System.out.println(binarized_value);
- }
+
+ System.out.println("Binarizer output with Threshold = " + binarizer.getThreshold());
+ binarizedDataFrame.show();
// $example off$
+
spark.stop();
}
}
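For context on what the new show() call prints: Binarizer maps values strictly greater than the threshold to 1.0 and all others to 0.0. With the threshold of 0.5 set above, hypothetical inputs (the example's actual data rows fall outside this hunk) would binarize as:

    0.1 -> 0.0
    0.8 -> 1.0
    0.2 -> 0.0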
examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
@@ -44,10 +44,12 @@ public static void main(String[] args) {
double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY};
List<Row> data = Arrays.asList(
+ RowFactory.create(-999.9),
RowFactory.create(-0.5),
RowFactory.create(-0.3),
RowFactory.create(0.0),
- RowFactory.create(0.2)
+ RowFactory.create(0.2),
+ RowFactory.create(999.9)
);
StructType schema = new StructType(new StructField[]{
new StructField("features", DataTypes.DoubleType, false, Metadata.empty())
@@ -61,8 +63,11 @@ public static void main(String[] args) {
// Transform original data into its bucket index.
Dataset<Row> bucketedData = bucketizer.transform(dataFrame);
+
+ System.out.println("Bucketizer output with " + (bucketizer.getSplits().length-1) + " buckets");
bucketedData.show();
// $example off$
+
spark.stop();
}
}
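A hand-computed sketch of what bucketedData.show() should print for these six rows (bucket i covers the half-open interval [splits[i], splits[i+1]); the output column name is assumed from the example's setOutputCol, which falls outside this hunk):

    +--------+----------------+
    |features|bucketedFeatures|
    +--------+----------------+
    |  -999.9|             0.0|
    |    -0.5|             1.0|
    |    -0.3|             1.0|
    |     0.0|             2.0|
    |     0.2|             2.0|
    |   999.9|             3.0|
    +--------+----------------+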
examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java
@@ -63,7 +63,11 @@ public static void main(String[] args) {
.setOutputCol("selectedFeatures");
Dataset<Row> result = selector.fit(df).transform(df);
+
+ System.out.println("ChiSqSelector output with top " + selector.getNumTopFeatures()
+ + " features selected");
result.show();
+
// $example off$
spark.stop();
}
examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java
@@ -61,7 +61,7 @@ public static void main(String[] args) {
.setInputCol("text")
.setOutputCol("feature");
- cvModel.transform(df).show();
+ cvModel.transform(df).show(false);
// $example off$
spark.stop();
examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java
@@ -51,13 +51,17 @@ public static void main(String[] args) {
new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
Dataset<Row> df = spark.createDataFrame(data, schema);
+
DCT dct = new DCT()
.setInputCol("features")
.setOutputCol("featuresDCT")
.setInverse(false);
+
Dataset<Row> dctDf = dct.transform(df);
- dctDf.select("featuresDCT").show(3);
+
+ dctDf.select("featuresDCT").show(false);
// $example off$
+
spark.stop();
}
}
examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
@@ -54,8 +54,8 @@ public static void main(String[] args) {
// Output the parameters of the mixture model
for (int i = 0; i < model.getK(); i++) {
- System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
- model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());
+ System.out.printf("Gaussian %d:\nweight=%f\nmu=%s\nsigma=\n%s\n\n",
+ i, model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());
}
// $example off$
examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java
@@ -24,6 +24,7 @@
import java.util.Arrays;
import java.util.List;
+import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
@@ -63,11 +64,23 @@ public static void main(String[] args) {
.fit(df);
Dataset<Row> indexed = indexer.transform(df);
+ System.out.println("Transformed string column '" + indexer.getInputCol() + "' " +
+ "to indexed column '" + indexer.getOutputCol() + "'");
+ indexed.show();
+
+ StructField inputColSchema = indexed.schema().apply(indexer.getOutputCol());
+ System.out.println("StringIndexer will store labels in output column metadata: " +
+ Attribute.fromStructField(inputColSchema).toString() + "\n");
+
IndexToString converter = new IndexToString()
.setInputCol("categoryIndex")
.setOutputCol("originalCategory");
Dataset<Row> converted = converter.transform(indexed);
- converted.select("id", "originalCategory").show();
+
+ System.out.println("Transformed indexed column '" + converter.getInputCol() + "' back to " +
+ "original string column '" + converter.getOutputCol() + "' using labels in metadata");
+ converted.select("id", "categoryIndex", "originalCategory").show();
+
// $example off$
spark.stop();
}
examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java
@@ -50,8 +50,8 @@ public static void main(String[] args) {
IsotonicRegression ir = new IsotonicRegression();
IsotonicRegressionModel model = ir.fit(dataset);
- System.out.println("Boundaries in increasing order: " + model.boundaries());
- System.out.println("Predictions associated with the boundaries: " + model.predictions());
+ System.out.println("Boundaries in increasing order: " + model.boundaries() + "\n");
+ System.out.println("Predictions associated with the boundaries: " + model.predictions() + "\n");
// Makes predictions.
model.transform(dataset).show();
examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java
@@ -18,10 +18,20 @@
package org.apache.spark.examples.ml;
// $example on$
+import java.util.Arrays;
+import java.util.List;
+
import org.apache.spark.ml.feature.MaxAbsScaler;
import org.apache.spark.ml.feature.MaxAbsScalerModel;
+import org.apache.spark.ml.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
// $example off$
import org.apache.spark.sql.SparkSession;
@@ -34,10 +44,17 @@ public static void main(String[] args) {
.getOrCreate();
// $example on$
- Dataset<Row> dataFrame = spark
- .read()
- .format("libsvm")
- .load("data/mllib/sample_libsvm_data.txt");
+ List<Row> data = Arrays.asList(
+ RowFactory.create(0, Vectors.dense(1.0, 0.1, -8.0)),
+ RowFactory.create(1, Vectors.dense(2.0, 1.0, -4.0)),
+ RowFactory.create(2, Vectors.dense(4.0, 10.0, 8.0))
+ );
+ StructType schema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
+ });
+ Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
+
MaxAbsScaler scaler = new MaxAbsScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures");
@@ -47,8 +64,9 @@ public static void main(String[] args) {
// rescale each feature to range [-1, 1].
Dataset<Row> scaledData = scalerModel.transform(dataFrame);
- scaledData.show();
+ scaledData.select("features", "scaledFeatures").show();
// $example off$
+
spark.stop();
}
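MaxAbsScaler divides each column by its maximum absolute value (here 4.0, 10.0, and 8.0), so the scaled vectors can be checked by hand (a sketch, not captured output):

    [1.0,  0.1, -8.0] -> [0.25, 0.01, -1.0]
    [2.0,  1.0, -4.0] -> [0.5,  0.1,  -0.5]
    [4.0, 10.0,  8.0] -> [1.0,  1.0,   1.0]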
examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
@@ -20,10 +20,20 @@
import org.apache.spark.sql.SparkSession;
// $example on$
+import java.util.Arrays;
+import java.util.List;
+
import org.apache.spark.ml.feature.MinMaxScaler;
import org.apache.spark.ml.feature.MinMaxScalerModel;
+import org.apache.spark.ml.linalg.Vectors;
+import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
// $example off$
public class JavaMinMaxScalerExample {
@@ -34,10 +44,17 @@ public static void main(String[] args) {
.getOrCreate();
// $example on$
- Dataset<Row> dataFrame = spark
- .read()
- .format("libsvm")
- .load("data/mllib/sample_libsvm_data.txt");
+ List<Row> data = Arrays.asList(
+ RowFactory.create(0, Vectors.dense(1.0, 0.1, -1.0)),
+ RowFactory.create(1, Vectors.dense(2.0, 1.1, 1.0)),
+ RowFactory.create(2, Vectors.dense(3.0, 10.1, 3.0))
+ );
+ StructType schema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
+ new StructField("features", new VectorUDT(), false, Metadata.empty())
+ });
+ Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
+
MinMaxScaler scaler = new MinMaxScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures");
@@ -47,8 +64,11 @@ public static void main(String[] args) {
// rescale each feature to range [min, max].
Dataset<Row> scaledData = scalerModel.transform(dataFrame);
- scaledData.show();
+ System.out.println("Features scaled to range: [" + scaler.getMin() + ", "
+ + scaler.getMax() + "]");
+ scaledData.select("features", "scaledFeatures").show();
// $example off$
+
spark.stop();
}
}
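MinMaxScaler rescales each column to [getMin(), getMax()] (defaults 0.0 and 1.0) via (x - colMin) / (colMax - colMin), so the printed range line and the scaled vectors can be checked by hand (a sketch, not captured output):

    [1.0,  0.1, -1.0] -> [0.0, 0.0, 0.0]
    [2.0,  1.1,  1.0] -> [0.5, 0.1, 0.5]
    [3.0, 10.1,  3.0] -> [1.0, 1.0, 1.0]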
examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
@@ -41,28 +41,34 @@ public static void main(String[] args) {
// Load training data
String path = "data/mllib/sample_multiclass_classification_data.txt";
Dataset<Row> dataFrame = spark.read().format("libsvm").load(path);
+
// Split the data into train and test
Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];
+
// specify layers for the neural network:
// input layer of size 4 (features), two intermediate of size 5 and 4
// and output of size 3 (classes)
int[] layers = new int[] {4, 5, 4, 3};
+
// create the trainer and set its parameters
MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(128)
.setSeed(1234L)
.setMaxIter(100);
+
// train the model
MultilayerPerceptronClassificationModel model = trainer.fit(train);
+
// compute accuracy on the test set
Dataset<Row> result = model.transform(test);
Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setMetricName("accuracy");
- System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
+
+ System.out.println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels));
// $example off$
spark.stop();
examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
@@ -42,29 +42,25 @@ public static void main(String[] args) {
// $example on$
List<Row> data = Arrays.asList(
- RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
- RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
- RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
+ RowFactory.create(0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
+ RowFactory.create(1, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
+ RowFactory.create(2, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
);
StructType schema = new StructType(new StructField[]{
- new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
new StructField(
"words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
Dataset<Row> wordDataFrame = spark.createDataFrame(data, schema);
- NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams");
+ NGram ngramTransformer = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams");
Dataset<Row> ngramDataFrame = ngramTransformer.transform(wordDataFrame);
-
- for (Row r : ngramDataFrame.select("ngrams", "label").takeAsList(3)) {
- java.util.List<String> ngrams = r.getList(0);
- for (String ngram : ngrams) System.out.print(ngram + " --- ");
- System.out.println();
- }
+ ngramDataFrame.select("ngrams").show(false);
// $example off$
+
spark.stop();
}
}
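With n pinned to 2 by setN(2), each words row becomes its list of bigrams, so the show(false) output is easy to verify by hand, e.g. for the first row (a sketch, not captured output):

    [Hi, I, heard, about, Spark] -> [Hi I, I heard, heard about, about Spark]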
examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
@@ -48,14 +48,21 @@ public static void main(String[] args) {
// create the trainer and set its parameters
NaiveBayes nb = new NaiveBayes();
+
// train the model
NaiveBayesModel model = nb.fit(train);
+
+ // Select example rows to display.
+ Dataset<Row> predictions = model.transform(test);
+ predictions.show();
+
// compute accuracy on the test set
- Dataset<Row> result = model.transform(test);
- Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+ .setLabelCol("label")
+ .setPredictionCol("prediction")
.setMetricName("accuracy");
- System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
+ double accuracy = evaluator.evaluate(predictions);
+ System.out.println("Test set accuracy = " + accuracy);
// $example off$
spark.stop();