// --------------------------------------------------------------------------------------------------------------------
// <copyright file="MultivariateTimeSeries.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------
package com.bayesserver.spark.examples.parameterlearning
import org.apache.spark.{SparkContext, SparkConf}
import com.bayesserver._
import com.bayesserver.learning.parameters.{ParameterLearningOptions, ParameterLearning}
import com.bayesserver.inference.Evidence
import com.bayesserver.spark.core.{BroadcastNameValues, IteratorEvidenceReader, BayesSparkDistributer}
/**
* Example that learns the parameters of a multivariate time series model.
*/
object MultivariateTimeSeries {
/**
 * Runs the example.
 *
 * @param sc The SparkContext (see the Apache Spark user guide for help on this)
 * @param licenseKey An optional Bayes Server license key
 */
def apply(sc: SparkContext, licenseKey: Option[String] = None): Unit = {
licenseKey.foreach(key => License.validate(key))
// Hard-coded test data. Normally you would read data from your cluster.
val data = createRDD(sc).cache()
// A network could be loaded from a file or stream.
// We create it manually here to keep the example self-contained.
val network = createNetwork
val parameterLearningOptions = new ParameterLearningOptions
// Bayes Server supports multi-threaded learning
// which we want to turn off as Spark takes care of this
parameterLearningOptions.setMaximumConcurrency(1)
// parameterLearningOptions.setMaximumIterations(...) // this can be useful to limit the number of iterations
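// BroadcastNameValues makes name/value state computed on the driver available
// to the workers via a Spark broadcast (MemoryNameValues is its in-memory counterpart).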
val driverToWorker = new BroadcastNameValues(sc)
val output = ParameterLearning.learnDistributed(network, parameterLearningOptions,
new BayesSparkDistributer[Seq[(Double, Double)]](
data,
driverToWorker,
(ctx, iterator) => new TimeSeriesEvidenceReader(ctx.getNetwork, iterator),
licenseKey
))
// We could now call network.save(...) to persist the learned network to a file or stream.
// The saved file can then be opened in the Bayes Server User Interface.
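// For example (the file name here is illustrative only):
// network.save("multivariate-time-series.bayes")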
println("Time series parameter learning complete")
println("Case count = " + output.getCaseCount)
println("Log-likelihood = " + output.getLogLikelihood)
println("Converged = " + output.getConverged)
println("Iterations = " + output.getIterationCount)
}
/**
* Creates some test data. Normally you would load the data from your cluster.
*
* We have hard-coded it here to keep the example self-contained.
* @return An RDD of time series cases.
*/
def createRDD(sc: SparkContext) = {
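// Each inner Seq is one time series case; each (Double, Double) tuple holds the
// values of X and Y at consecutive time steps 0, 1, 2, ... Cases can have different lengths.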
sc.parallelize(Seq(
Seq((1.0, 2.3), (2.3, 4.5), (6.2, 7.2), (4.2, 6.6)),
Seq((3.3, -1.2), (3.2, 4.4), (-3.3, -2.3), (4.15, 1.2), (8.8, 2.2), (4.1, 9.9)),
Seq((1.0, 2.0), (3.3, 4.1)),
Seq((5.0, 21.3), (4.3, 6.6), (-2.1, 4.5)),
Seq((4.35, -3.25), (13.44, 12.4), (-1.3, 3.33), (4.2, 2.15), (12.8, 4.25)),
Seq((1.46, 2.22), (1.37, 3.15), (2.2, 2.25))
))
}
/**
* Create a network in code. An existing network could also be read from a file or stream using Network.load.
* @return A Bayes Server network.
*/
def createNetwork: Network = {
val network = new Network
val series = new Node()
series.setName("Series")
series.setTemporalType(TemporalType.TEMPORAL)
val x = new Variable("X", VariableValueType.CONTINUOUS)
series.getVariables.add(x)
val y = new Variable("Y", VariableValueType.CONTINUOUS)
series.getVariables.add(y)
network.getNodes.add(series)
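// Temporal links of orders 1, 2 and 3 mean the distribution at time t depends on
// the values at t-1, t-2 and t-3 (an order-3 auto-regressive structure).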
network.getLinks.add(new Link(series, series, 1))
network.getLinks.add(new Link(series, series, 2))
network.getLinks.add(new Link(series, series, 3))
network
}
/**
* Implements the Bayes Server EvidenceReader interface for reading our data.
* @param network The network
* @param iterator The iterator, which will be generated by RDD.mapPartitions.
*/
class TimeSeriesEvidenceReader(val network: Network, val iterator: Iterator[Seq[(Double, Double)]])
extends IteratorEvidenceReader[Seq[(Double, Double)]] {
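// look up the variables once here, rather than once per evidence record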
val x = network.getVariables.get("X")
val y = network.getVariables.get("Y")
override def setEvidence(item: Seq[(Double, Double)], evidence: Evidence): Unit = {
// each tuple supplies the continuous values of X and Y at one time step
for (((xValue, yValue), time) <- item.zipWithIndex) {
  evidence.set(x, xValue, time)
  evidence.set(y, yValue, time)
}
}
}
}
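
// A minimal, hypothetical entry point sketching one way to launch the example.
// The app name and the local master below are illustrative assumptions; in a real
// deployment the master is normally supplied by spark-submit rather than hard coded.
object MultivariateTimeSeriesApp {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("MultivariateTimeSeries").setMaster("local[*]")
    val sc = new SparkContext(conf)

    try {
      MultivariateTimeSeries(sc)
    } finally {
      sc.stop()
    }
  }
}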