// --------------------------------------------------------------------------------------------------------------------
// <copyright file="BayesSparkDistributer.scala" company="Bayes Server">
// Copyright (C) Bayes Server. All rights reserved.
// </copyright>
// <version>0.13</version>
// <dependencies>
// <dependency>
// <name>bayes-server</name>
// <version>7.x</version>
// </dependency>
// <dependency>
// <name>spark-core</name>
// <version>1.x.x</version>
// </dependency>
// <dependency>
// <name>scala</name>
// <version>2.10.4</version>
// </dependency>
// </dependencies>
// --------------------------------------------------------------------------------------------------------------------
package com.bayesserver.spark.core
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import com.bayesserver._
import com.bayesserver.inference._
import com.bayesserver.data.DefaultReadOptions
import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.slf4j.LoggerFactory
import scala.util.Try
import com.bayesserver.learning.parameters._
import com.bayesserver.Network
import com.bayesserver.Table
import com.bayesserver.VariableValueType
import org.apache.spark.rdd.RDD
import com.bayesserver.data._
import com.bayesserver.data.distributed._ // Remove this line if using Bayes Server 6.x
import com.bayesserver.{WriteStreamAction, NameValuesReader, NameValuesWriter, Distributer}
import scala.collection.JavaConversions
import java.io._
import scala.language.implicitConversions
/**
* Implements the Bayes Server Distributer interface for distributed parameter learning of Bayesian networks.
*
* To use, call com.bayesserver.learning.parameters.ParameterLearning.learnDistributed(...)
*
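* A minimal construction sketch (hedged: `rows`, `MyRow`, `MyEvidenceReader`, `sc` and `licenseKey` are hypothetical
* names, not defined in this file; consult the Bayes Server documentation for the exact learnDistributed overload):
* {{{
* val distributer = new BayesSparkDistributer[MyRow](
*   data = rows, // RDD[MyRow]
*   driverToWorker = new BroadcastNameValues(sc),
*   newEvidenceReader = (ctx, it) => new MyEvidenceReader(it), // an EvidenceReader, e.g. built on IteratorEvidenceReader
*   licenseKey = Some(licenseKey))
* // then pass `distributer` to ParameterLearning.learnDistributed(...) together with the network and learning options
* }}}
*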
* @param data The Spark RDD containing the training data.
* @param driverToWorker The configuration name/value pairs to be passed from the driver to the mappers and reducers.
* @param newEvidenceReader A factory method for creating evidence readers when required.
* @param licenseKey An optional Bayes Server license key, validated on each worker node.
* @param workerToDriver A factory for the name/value store used to pass results from the workers back to the driver. Defaults to MemoryNameValues.
* @tparam T The RDD element type.
*/
class BayesSparkDistributer[T]
(
val data: RDD[T],
val driverToWorker: NameValuesReader with NameValuesWriter,
val newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader,
val licenseKey: Option[String] = None,
val workerToDriver: () => NameValuesReader with NameValuesWriter = () => new MemoryNameValues()
) extends Distributer[DistributerContext] with Serializable {
private val logger = LoggerFactory.getLogger(classOf[BayesSparkDistributer[T]])
/**
* @inheritdoc
*/
override def getConfiguration: NameValuesWriter = this.driverToWorker
/**
* @inheritdoc
*/
override def distribute(ctx: DistributerContext): NameValuesReader = {
logger.info("Distributer stage: " + ctx.getName)
// We use mapPartitions because the Bayes Server library pulls the data itself via its EvidenceReader interface.
// The Bayes Server calls are wrapped in the serializable Mapper and Reducer helpers below, because Spark requires the closures passed to mapPartitions/reduce to be serializable.
data
.mapPartitions(iterator => {
// this code executes on the worker nodes, so a license will not yet have been validated
licenseKey.foreach(s => License.validate(s))
new Mapper(driverToWorker, workerToDriver).call(
iterator,
newEvidenceReader
)
})
.reduce((a, b) => {
// this code executes on the worker nodes, so a license will not yet have been validated
licenseKey.foreach(s => License.validate(s))
new Reducer(driverToWorker).call(a, b, workerToDriver()) // this calls the Bayes Server ParameterLearning.learnDistributedReducer method
})
}
/**
* Calls the mapper phase of Bayes Server distributed parameter learning.
* @param driverToWorker Configuration information which is required on all nodes.
* @param workerToDriver A factory that creates the name/value store which Bayes Server uses to pass mapper output back to the driver.
*/
class Mapper(val driverToWorker: NameValuesReader, val workerToDriver: () => NameValuesReader with NameValuesWriter) extends Serializable {
/**
* Performs the Bayes Server map operation on an iterator generated from the RDD.mapPartitions.
* @param iterator The iterator generated from RDD.mapPartitions.
* @param newEvidenceReader A factory for creating a new EvidenceReader.
* @return An iterator of outputs from the map operation.
*/
def call(
iterator: Iterator[T],
newEvidenceReader: (DistributedMapperContext, Iterator[T]) => EvidenceReader
): Iterator[NameValuesReader] = {
if (iterator.isEmpty)
return Iterator.empty // Spark RDD.mapPartitions can pass in an empty Iterator
val output = workerToDriver()
ParameterLearning.learnDistributedMapper(
new EvidencePartition[DistributedMapperContext] {
override def createEvidenceReader(ctx: DistributedMapperContext): EvidenceReader = {
newEvidenceReader(ctx, iterator)
}
},
driverToWorker,
output,
new RelevanceTreeInferenceFactory)
Iterator(output)
}
}
/**
* Calls the reducer phase of Bayes Server distributed parameter learning.
*
* @param configuration Configuration information which is required on all nodes.
*/
class Reducer(val configuration: NameValuesReader) extends Serializable {
def call(a: NameValuesReader, b: NameValuesReader, output: NameValuesReader with NameValuesWriter): NameValuesReader = {
val inputs = Iterable(a, b)
ParameterLearning.learnDistributedReducer(JavaConversions.asJavaIterable(inputs), configuration, output)
output
}
}
}
/**
* Methods to compress and decompress values from a store such as MemoryNameValues or BroadcastNameValues.
*/
object CompressedNameValues {
def write(writeStreamAction: WriteStreamAction, output: OutputStream) = {
val zipped = new GZIPOutputStream(output)
writeStreamAction.write(zipped)
zipped.finish() // required to write final bytes
}
def read(input: InputStream): InputStream = {
new GZIPInputStream(input)
}
}
/**
* An adapter that can be used to add compression to an existing store.
* MemoryNameValues and BroadcastNameValues both have compression options already.
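*
* A minimal sketch (usually unnecessary, since both built-in stores can compress directly):
* {{{
* val store = new CompressedNameValues(new MemoryNameValues(compress = false))
* }}}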
* @param wrapped The underlying store
*/
class CompressedNameValues(val wrapped: NameValuesReader with NameValuesWriter with Serializable) extends NameValuesReader with NameValuesWriter with Serializable {
/**
* @inheritdoc
*/
def write(name: String, writeStreamAction: WriteStreamAction) = {
this.wrapped.write(name, new WriteStreamAction {
override def write(output: OutputStream) = CompressedNameValues.write(writeStreamAction, output)
})
}
/**
* @inheritdoc
*/
def contains(name: String): Boolean = this.wrapped.contains(name)
/**
* @inheritdoc
*/
override def read(name: String): InputStream = CompressedNameValues.read(this.wrapped.read(name))
}
/**
* A Spark broadcast variable based implementation of both NameValuesReader and NameValuesWriter.
* Note that values can only be written in driver code, not on the workers.
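*
* A minimal sketch (sc being the driver-side SparkContext):
* {{{
* val driverToWorker = new BroadcastNameValues(sc)
* }}}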
*/
class BroadcastNameValues(@transient val sc: SparkContext, val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {
val map = new scala.collection.mutable.HashMap[String, Broadcast[Array[Byte]]]
/**
* @inheritdoc
*/
def write(name: String, writeStreamAction: WriteStreamAction) {
if (sc == null)
throw new IllegalStateException("SparkContext is null. BroadcastNameValues.write is not supported on worker nodes")
val output: ByteArrayOutputStream = new ByteArrayOutputStream
if (compress) CompressedNameValues.write(writeStreamAction, output)
else writeStreamAction.write(output)
this.map.put(name, sc.broadcast(output.toByteArray))
}
/**
* @inheritdoc
*/
def contains(name: String): Boolean = {
this.map.contains(name)
}
/**
* @inheritdoc
*/
override def read(name: String): InputStream = {
val bytes = this.map(name).value
if (compress) CompressedNameValues.read(new ByteArrayInputStream(bytes))
else new ByteArrayInputStream(bytes)
}
}
/**
* An in-memory implementation of both NameValuesReader and NameValuesWriter.
*
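* A minimal sketch, matching the default workerToDriver factory used by BayesSparkDistributer:
* {{{
* val workerToDriver = () => new MemoryNameValues() // compression enabled by default
* }}}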
*/
class MemoryNameValues(val compress: Boolean = true) extends NameValuesReader with NameValuesWriter with Serializable {
val map = new scala.collection.mutable.HashMap[String, Array[Byte]]
/**
* @inheritdoc
*/
def write(name: String, writeStreamAction: WriteStreamAction) {
val output: ByteArrayOutputStream = new ByteArrayOutputStream
if (compress) CompressedNameValues.write(writeStreamAction, output)
else writeStreamAction.write(output)
this.map.put(name, output.toByteArray)
}
/**
* @inheritdoc
*/
def contains(name: String): Boolean = {
this.map.contains(name)
}
/**
* @inheritdoc
*/
override def read(name: String): InputStream = {
val bytes = this.map(name)
if (compress) CompressedNameValues.read(new ByteArrayInputStream(bytes))
else new ByteArrayInputStream(bytes)
}
}
/**
* Trait which can be used to help implement the Bayes Server EvidenceReader interface.
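*
* A minimal implementation sketch (hypothetical: a single continuous predictor variable named "X";
* check the Evidence API of your Bayes Server version for the exact set/setState overloads):
* {{{
* class DoubleArrayReader(network: Network, val iterator: Iterator[Array[Double]])
*   extends IteratorEvidenceReader[Array[Double]] {
*
*   // variables you intend to predict are deliberately left without evidence
*   private val x = network.getVariables.get("X", true)
*
*   override def setEvidence(item: Array[Double], evidence: Evidence): Unit = {
*     evidence.set(x, item(0): java.lang.Double)
*   }
* }
* }}}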
* @tparam T The type of data contained in the RDD. This could be a class, an Array[Double] or anything else which is convenient.
*/
trait IteratorEvidenceReader[T] extends EvidenceReader {
/** The iterator over this partition's RDD elements. Implementations should supply a non-null iterator;
* validation is deliberately not performed in the trait body, because it would run before the
* implementing class has initialized this value. */
val iterator: Iterator[T]
/**
* Maps information from the current RDD element to Bayes Server variables using the Bayes Server evidence instance.
* @param item The RDD element.
* @param evidence The evidence to be updated.
*/
def setEvidence(item: T, evidence: Evidence)
/**
* @inheritdoc
*/
override def read(evidence: Evidence, readOptions: ReadOptions): Boolean = {
readOption(evidence, readOptions).isDefined
}
/**
* Converts the next RDD element into evidence.
* @param evidence The evidence instance which is to be updated.
* @param readOptions Options affecting the read.
* @return The RDD element, or None if no more records.
*/
def readOption(evidence: Evidence, readOptions: ReadOptions): Option[T] = {
if (!iterator.hasNext)
return None
if (!readOptions.getCleared) {
evidence.clear()
}
val current = iterator.next()
setEvidence(current, evidence)
Some(current)
}
/**
* @inheritdoc
*/
override def close(): Unit = {}
}
object TimeMode extends Enumeration {
type TimeMode = Value
/**
* Query times are absolute and zero based.
*/
val Absolute = Value
/**
* Query times are zero based but relative to the maximum evidence time.
*/
val Relative = Value
}
import TimeMode._
case class PredictTime(time: Int, timeMode: TimeMode)
/**
* Base class for predictions.
*/
sealed abstract class PredictValue
/**
* Predicts the probability of a discrete variable state. When no state is specified, the probability of the most likely (modal) state is returned.
* @param variable The name of the variable to predict.
* @param state The state to predict, or None to return the probability of the most likely (modal) state.
* @param time The time at which to predict. Only required for temporal variables (time series predictions).
*/
final case class PredictState(variable: String, state: Option[String] = None, time: Option[PredictTime] = None) extends PredictValue
/**
* Predicts the index of the most likely (modal) state for discrete variables, or the mean for continuous variables.
* @param name The variable name.
* @param time The time at which to predict. Only required for temporal variables (time series predictions).
*/
final case class PredictVariable(name: String, time: Option[PredictTime] = None) extends PredictValue
/**
* Predicts the variance of a continuous variable.
* @param name The variable name.
* @param time The time at which to predict. Only required for temporal variables (time series predictions).
*/
final case class PredictVariance(name: String, time: Option[PredictTime] = None) extends PredictValue
/**
* Predicts the log-likelihood of the case.
*/
final case class PredictLogLikelihood() extends PredictValue
/**
* Helper to make predictions easier.
*/
object Prediction {
private final class Reader[T](
val network: Network,
val factory: InferenceFactory,
val reader: IteratorEvidenceReader[T],
val predictions: Iterable[PredictValue]
) extends Iterator[(T, Try[Seq[Double]])] {
private val inference = factory.createInferenceEngine(network)
private val queryOptions = factory.createQueryOptions()
private val queryOutput = factory.createQueryOutput()
private val readOptions = new DefaultReadOptions()
private val predictionQueries: Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = addQueries(predictions.toSeq, inference, queryOptions)
private val relativePredictions = predictionQueries.filter(_._3.exists(_.timeMode == Relative))
private var current: Option[T] = read()
/**
* Adjusts the query times for predictions that use relative times.
* @param shift The amount to shift the query times (may be null, in which case no adjustment is made).
* @param plus True to apply the shift, false to reverse it.
*/
private def adjustRelativeTimes(shift: Integer, plus: Boolean) = {
if (shift != null) {
this.relativePredictions.foreach(p => {
p._2.get.timeShift(if (plus) shift else -shift)
})
}
}
override def hasNext: Boolean = current.isDefined
/**
* @inheritdoc
*/
override def next(): (T, Try[Seq[Double]]) = {
require(current.isDefined)
val result = (this.current.get, Try({
val maxEvidenceTime = this.inference.getEvidence.getMaxTime
adjustRelativeTimes(maxEvidenceTime, plus = true)
try {
this.inference.query(this.queryOptions, this.queryOutput)
} finally {
adjustRelativeTimes(maxEvidenceTime, plus = false) // reset the shift even if the query fails
}
this.predictionQueries.map({
case (prediction, query, _) =>
prediction match {
case PredictState(variableName, stateName, time) =>
val variable = network.getVariables.get(variableName, true)
stateName match {
case Some(name) => query.get.getTable.get(variable.getStates.get(name, true))
case None => query.get.getTable.getMaxValue.getValue
}
case PredictVariable(variableName, time) =>
val variable = network.getVariables.get(variableName, true)
variable.getValueType match {
case VariableValueType.DISCRETE => query.get.getTable.getMaxValue.getIndex
case VariableValueType.CONTINUOUS => query.get.asInstanceOf[CLGaussian].getMean(0, 0)
}
case PredictVariance(variable, time) => query.get.asInstanceOf[CLGaussian].getVariance(0, 0)
case PredictLogLikelihood() =>
this.queryOutput.getLogLikelihood.doubleValue()
}
})
}))
this.current = read()
result
}
/**
* Reads the next element in the underlying evidence reader.
* @return The original RDD element or None if no further elements are available in this partition.
*/
private def read(): Option[T] = reader.readOption(this.inference.getEvidence, this.readOptions)
/**
* Adds queries to the inference engine to cover all predictions, without creating duplicate queries.
* @param predictions The predictions.
* @param inference The inference engine.
* @param queryOptions The query options, updated when the log-likelihood is requested.
* @return The predictions with their associated queries and prediction times. Note that the same query may be used for multiple predictions.
*/
private def addQueries(predictions: Seq[PredictValue], inference: Inference, queryOptions: QueryOptions): Seq[(PredictValue, Option[Distribution], Option[PredictTime])] = {
case class VariableTime(variable: String, time: Option[PredictTime])
val variableTimes: Seq[(PredictValue, Option[VariableTime])] = for (prediction <- predictions) yield {
prediction match {
case PredictState(variableName, state, time) => (prediction, Some(VariableTime(variableName, time)))
case PredictVariance(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
case PredictVariable(variableName, time) => (prediction, Some(VariableTime(variableName, time)))
case PredictLogLikelihood() => queryOptions.setLogLikelihood(true); (prediction, None)
}
}
val distinctVariableTimes = variableTimes.filter(_._2.isDefined).map(_._2.get).distinct
implicit def toJavaInt(time: Option[PredictTime]): Integer = time match {
case Some(PredictTime(t, _)) => t // Relative times will be adjusted per case
case None => null
}
val queries = distinctVariableTimes.map(vt => {
val variable = network.getVariables.get(vt.variable, true)
variable.getValueType match {
case VariableValueType.DISCRETE => (vt, new Table(variable, vt.time).asInstanceOf[Distribution])
case VariableValueType.CONTINUOUS => (vt, new CLGaussian(variable, vt.time).asInstanceOf[Distribution])
}
})
for (query <- queries)
inference.getQueryDistributions.add(query._2)
val queryMap: Map[VariableTime, Distribution] = queries.toMap
variableTimes.map(pvt => (pvt._1, pvt._2.map(queryMap(_)), pvt._2.map(vt => vt.time).flatten))
}
}
/**
* Make predictions from data.
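*
* A minimal usage sketch (hypothetical: testData is an RDD[Array[Double]], "Y" is the target variable,
* and DoubleArrayReader is an IteratorEvidenceReader implementation such as the one sketched above):
* {{{
* val results = Prediction.predict(
*   network,
*   testData,
*   Seq(PredictVariable("Y"), PredictLogLikelihood()),
*   (net, it) => new DoubleArrayReader(net, it),
*   licenseKey = Some(licenseKey))
* // results: RDD[(Array[Double], Try[Seq[Double]])], one value per requested PredictValue
* // for temporal networks, a prediction such as PredictVariable("Y", Some(PredictTime(1, TimeMode.Relative))) can be used
* }}}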
* @param network The trained network.
* @param data The data to make predictions on. Typically test data.
* @param predictions The predictions to be made.
* @param newReader Creates an evidence reader for a partition. Typically the variables being predicted are not mapped to evidence.
* @param licenseKey An optional Bayes Server license key, validated on each worker node.
* @tparam T The RDD element type.
* @return An RDD of pairs containing the original RDD element and the predictions as Double values.
*/
def predict[T](
network: Network,
data: RDD[T],
predictions: Iterable[PredictValue],
newReader: (Network, Iterator[T]) => IteratorEvidenceReader[T],
licenseKey: Option[String] = None): RDD[(T, Try[Seq[Double]])] = {
// The Network class is not Serializable, so save it to a string and broadcast that; each partition reloads it below.
val networkString = data.sparkContext.broadcast(network.saveToString())
data.mapPartitions(iterator => {
licenseKey.foreach(s => License.validate(s))
if (iterator.isEmpty)
Iterator.empty // as of Spark 1.1 mapPartitions can pass empty iterators
else {
val networkPartition = new Network
networkPartition.loadFromString(networkString.value)
// TODO allow configuration of Inference engine
new Reader(networkPartition, new RelevanceTreeInferenceFactory, newReader(networkPartition, iterator), predictions)
}
})
}
}