// --------------------------------------------------------------------------------------------------------------------
// <copyright file="MixtureModelDf.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

package com.bayesserver.spark.examples.parameterlearning

import com.bayesserver._
import com.bayesserver.inference.Evidence
import com.bayesserver.learning.parameters.{InitializationMethod, ParameterLearning, ParameterLearningOptions}
import com.bayesserver.spark.core.{BayesSparkDistributer, BroadcastNameValues, IteratorEvidenceReader}
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
 * Example that learns the parameters of a mixture model, also known as a probabilistic cluster model, using a DataFrame.
 */
object MixtureModelDf {

  /**
   * @param sc The SparkContext (see the Apache Spark user guide for help on this)
   * @param licenseKey An optional Bayes Server license key
   */
  def apply(sc: SparkContext, licenseKey: Option[String] = None): Unit = {

    licenseKey.foreach(key => License.validate(key))

    val sqlContext = new SQLContext(sc)

    // Hard-coded test data. Normally you would read data from your cluster.
    val data = createDf(sqlContext).cache()
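    // A minimal sketch of reading real data instead (the path is hypothetical, and this
    // assumes the Spark 1.4+ DataFrameReader API and double columns named "x" and "y"):
    // val data = sqlContext.read.parquet("hdfs:///data/points.parquet").cache()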

    // A network could be loaded from a file or stream.
    // We create it manually here to keep the example self-contained.
    val network = createNetwork
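    // A sketch of loading an existing network instead, via Network.load
    // ("model.bayes" is a hypothetical path):
    // val network = new Network()
    // network.load("model.bayes")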

    val parameterLearningOptions = new ParameterLearningOptions

    // Bayes Server supports multi-threaded learning,
    // which we turn off here as Spark provides the parallelism.
    parameterLearningOptions.setMaximumConcurrency(1)
    parameterLearningOptions.getInitialization.setMethod(InitializationMethod.CLUSTERING)
    // parameterLearningOptions.setMaximumIterations(...) // this can be useful to limit the number of iterations
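    // For example (50 is a hypothetical limit; learning may converge and stop earlier):
    // parameterLearningOptions.setMaximumIterations(50)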

    // Name/value pairs broadcast from the driver to the worker tasks.
    val driverToWorker = new BroadcastNameValues(sc)

    val output = ParameterLearning.learnDistributed(network, parameterLearningOptions,
      new BayesSparkDistributer[Row](
        data.rdd,
        driverToWorker,
        (ctx, iterator) => new MixtureModelEvidenceReader(ctx.getNetwork, iterator),
        licenseKey
      ))

    // We could now call network.save(...) to a file or stream,
    // and the file could be opened in the Bayes Server user interface.
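    // For example ("mixture-model.bayes" is a hypothetical output path):
    // network.save("mixture-model.bayes")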

    println("Mixture model parameter learning complete")
    println("Case count = " + output.getCaseCount)
    println("Log-likelihood = " + output.getLogLikelihood)
    println("Converged = " + output.getConverged)
    println("Iterations = " + output.getIterationCount)
  }

  case class Point(x: Double, y: Double)

  /**
   * Some test data. Normally you would load the data from the cluster.
   *
   * We have hard-coded it here to keep the example self-contained.
   * @param sqlContext The SQLContext used to create the DataFrame.
   * @return A DataFrame.
   */
  def createDf(sqlContext: SQLContext): DataFrame = {

    sqlContext.createDataFrame(Seq(
      Point(0.176502224, 7.640580199),
      Point(1.308020831, 8.321963251),
      Point(7.841271129, 3.34044587),
      Point(2.623799516, 6.667664279),
      Point(8.617288623, 3.319091539),
      Point(0.292639161, 9.070469416),
      Point(1.717525934, 6.509707265),
      Point(0.347388367, 9.144193334),
      Point(4.332228381, 0.129103276),
      Point(0.550570479, 9.925610034),
      Point(10.18819907, 3.414009144),
      Point(9.796154937, 4.335498562),
      Point(4.492011746, 0.527572356),
      Point(8.793496377, 3.811848391),
      Point(0.479689038, 8.041976487),
      Point(0.460045193, 10.74481444),
      Point(3.249955813, 5.58667984),
      Point(1.677468832, 8.742639202),
      Point(2.567398263, 3.338528008),
      Point(8.507535409, 3.358378353),
      Point(8.863647208, 3.533757566),
      Point(-0.612339597, 11.27289689),
      Point(10.38075113, 3.657256133),
      Point(9.443691262, 3.561824026),
      Point(1.589644185, 7.936062309),
      Point(7.680055137, 2.541577306),
      Point(1.047477704, 6.382052946),
      Point(0.735659679, 8.029083014),
      Point(0.489446685, 11.40715477),
      Point(3.258072314, 1.451124598),
      Point(0.140278917, 7.78885888),
      Point(9.237538442, 2.647543473),
      Point(2.28453948, 5.836716478),
      Point(7.22011534, 1.51979264),
      Point(1.474811913, 1.942052919),
      Point(1.674889251, 5.601765101),
      Point(1.30742068, 6.137114076),
      Point(6.957133145, 3.957540541),
      Point(10.87472856, 5.598949484),
      Point(1.110499364, 9.241584372),
      Point(7.233905739, 2.322237847),
      Point(7.474329505, 2.920099189),
      Point(0.455631413, 7.356350266),
      Point(1.234318558, 6.592203772),
      Point(10.72837103, 5.371838788),
      Point(0.655168407, 6.713544957),
      Point(2.001307579, 5.30283356),
      Point(0.061834893, 2.071499561),
      Point(1.86460938, 6.013710897)
    ))
  }

  /**
   * Creates a network in code. An existing network could also be read from a file or stream using Network.load.
   * @return A Bayes Server network.
   */
  def createNetwork: Network = {

    val network = new Network

    // A discrete latent node with 2 states, one per mixture component (cluster).
    val mixture = new Node("Mixture", 2)
    network.getNodes.add(mixture)

    // A child node holding the two continuous variables X and Y.
    val gaussian = new Node()
    gaussian.setName("Gaussian")
    val x = new Variable("X", VariableValueType.CONTINUOUS)
    gaussian.getVariables.add(x)
    val y = new Variable("Y", VariableValueType.CONTINUOUS)
    gaussian.getVariables.add(y)
    network.getNodes.add(gaussian)

    // Linking the discrete node to the Gaussian node yields a mixture of Gaussians.
    network.getLinks.add(new Link(mixture, gaussian))

    network
  }

  /**
   * Implements the Bayes Server EvidenceReader interface, for reading our data.
   * @param network The network.
   * @param iterator The iterator, which will be generated by DataFrame.rdd.mapPartitions.
   */
  class MixtureModelEvidenceReader(val network: Network, val iterator: Iterator[Row])
    extends IteratorEvidenceReader[Row] {

    val x = network.getVariables.get("X")
    val y = network.getVariables.get("Y")

    // Column indexes are resolved lazily from the first row and then cached.
    var xIndex = Option.empty[Int]
    var yIndex = Option.empty[Int]

    private def setContinuousEvidence(row: Row, evidence: Evidence, variable: Variable, fieldIndex: Int): Unit = {

      // A null field means the value is missing, so clear any evidence on the variable.
      if (row.isNullAt(fieldIndex)) {
        evidence.clear(variable)
      } else {
        evidence.set(variable, row.getDouble(fieldIndex))
      }
    }

    override def setEvidence(row: Row, evidence: Evidence): Unit = {

      if (xIndex.isEmpty) {
        xIndex = Some(row.fieldIndex("x"))
        yIndex = Some(row.fieldIndex("y"))
      }

      setContinuousEvidence(row, evidence, x, xIndex.get)
      setContinuousEvidence(row, evidence, y, yIndex.get)
    }
  }

}