namespace BayesServer.HelpSamples
{
using BayesServer.Analysis;
using BayesServer.Data;
using BayesServer.Inference;
using BayesServer.Inference.RelevanceTree;
using BayesServer.Learning.Parameters;
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Text;
public static class CrossValidationExample
{
/// <summary>
/// Entry point for the sample: builds the example network, cross validates it
/// over three partitions and writes the resulting score to the console.
/// </summary>
public static void Main()
{
    const int partitionCount = 3;

    var network = LoadNetwork();
    var factory = new RelevanceTreeInferenceFactory();

    // The same inference factory is used both for learning and for scoring.
    var crossValidated = Score(
        network: network,
        partitions: partitionCount,
        inferenceFactory: factory,
        createEvidenceReaderCommand: NewEvidenceReaderCommand,
        learn: (copy, command) => learn(copy, command, factory));

    Console.WriteLine("Cross validated Log-likelihood: " + crossValidated.Score);
}
/// <summary>
/// Runs parameter learning on <paramref name="network"/> with default
/// <see cref="ParameterLearningOptions"/>, reading the training cases from
/// <paramref name="evidenceReaderCommand"/>.
/// </summary>
/// <param name="network">The network whose parameters are learned.</param>
/// <param name="evidenceReaderCommand">Source of the training cases.</param>
/// <param name="inferenceFactory">Factory handed to the <see cref="ParameterLearning"/> instance.</param>
private static void learn(Network network, IEvidenceReaderCommand evidenceReaderCommand, IInferenceFactory inferenceFactory)
{
    var learningOptions = new ParameterLearningOptions();

    // The value returned by Learn is not needed by this sample, so it is discarded.
    new ParameterLearning(network, inferenceFactory).Learn(evidenceReaderCommand, learningOptions);
}
/// <summary>
/// Cross validates <paramref name="network"/> with <c>CrossValidation.kFold</c> using a single
/// test metric (summed log-likelihood) and returns the one resulting score.
/// </summary>
/// <param name="network">The network to validate. Each fold trains a copy, not this instance.</param>
/// <param name="partitions">The number of partitions passed to <c>kFold</c>.</param>
/// <param name="inferenceFactory">Creates the inference engine, query options and query output used for testing.</param>
/// <param name="createEvidenceReaderCommand">Builds the evidence source for a given network and data partitioning.</param>
/// <param name="learn">Trains the supplied network copy from the supplied evidence source.</param>
/// <returns>The only score returned by <c>kFold</c>.</returns>
/// <exception cref="InvalidOperationException"><c>kFold</c> did not return exactly one score.</exception>
private static ICrossValidationScore Score(
    Network network,
    int partitions,
    IInferenceFactory inferenceFactory,
    Func<Network, DataPartitioning, IEvidenceReaderCommand> createEvidenceReaderCommand,
    Action<Network, IEvidenceReaderCommand> learn)
{
    var scores = CrossValidation.kFold(
        partitionCount: partitions,
        testMetricCount: 1,
        learn: trainingData =>
        {
            // Train a copy so that the network passed in by the caller is left untouched.
            var trained = network.Copy();
            learn(trained, createEvidenceReaderCommand(trained, trainingData));
            return new CrossValidationNetwork(trained);
        },
        test: (testData, trainedNetwork) =>
            TestLogLikelihood(trainedNetwork.Network, testData, inferenceFactory, createEvidenceReaderCommand),
        combine: (metric, partitionResults) =>
        {
            double combined = CrossValidation.Combine(
                partitionResults,
                CrossValidationCombineMethod.UnweightedSum);
            return new CrossValidationScore(combined);
        });

    // Exactly one metric was requested, so exactly one score is expected back.
    if (scores.Length != 1)
    {
        throw new InvalidOperationException();
    }

    return scores[0];
}

/// <summary>
/// Reads every case of a test partitioning, queries its log-likelihood against
/// <paramref name="trainedNetwork"/> and returns the totals as a single test result.
/// </summary>
/// <param name="trainedNetwork">The network produced by the learning step of the fold.</param>
/// <param name="testPartitioning">Identifies the data to test against.</param>
/// <param name="inferenceFactory">Creates the inference engine, query options and query output.</param>
/// <param name="createEvidenceReaderCommand">Builds the evidence source for the test data.</param>
/// <returns>A one-element array holding the result for the single test metric.</returns>
private static ICrossValidationTestResult[] TestLogLikelihood(
    Network trainedNetwork,
    DataPartitioning testPartitioning,
    IInferenceFactory inferenceFactory,
    Func<Network, DataPartitioning, IEvidenceReaderCommand> createEvidenceReaderCommand)
{
    var engine = inferenceFactory.CreateInferenceEngine(trainedNetwork);
    var queryOptions = inferenceFactory.CreateQueryOptions();
    queryOptions.LogLikelihood = true;
    var queryOutput = inferenceFactory.CreateQueryOutput();

    var totalLogLikelihood = 0.0;
    var totalWeight = 0.0;

    var command = createEvidenceReaderCommand(trainedNetwork, testPartitioning);
    using (var reader = command.ExecuteReader())
    {
        var readOptions = new ReadOptions();

        // Each successful Read loads the next case into the engine's evidence.
        while (reader.Read(engine.Evidence, readOptions))
        {
            engine.Query(queryOptions, queryOutput);
            totalWeight += engine.Evidence.Weight;

            // NOTE(review): the log-likelihood is added as-is, not multiplied by the case
            // weight accumulated above - confirm this is intended for weighted data.
            totalLogLikelihood += queryOutput.LogLikelihood.Value;
        }
    }

    // NOTE(review): the summed log-likelihood is supplied for both the second and third
    // constructor arguments - verify against the CrossValidationTestResult documentation.
    return new ICrossValidationTestResult[]
    {
        new CrossValidationTestResult(totalWeight, totalLogLikelihood, totalLogLikelihood)
    };
}
/// <summary>
/// Builds the sample network in code: two nodes, "A" and "B", each with the
/// states "False" and "True", joined by a single link from A to B.
/// </summary>
/// <returns>The newly constructed network.</returns>
private static Network LoadNetwork()
{
    var result = new Network();

    var a = new Node("A", new[] { "False", "True" });
    result.Nodes.Add(a);

    var b = new Node("B", new[] { "False", "True" });
    result.Nodes.Add(b);

    var aToB = new Link(a, b);
    result.Links.Add(aToB);

    return result;
}
/// <summary>
/// Creates an evidence reader command over the sample data selected by
/// <paramref name="dataPartitioning"/>, with one variable reference per variable in
/// <paramref name="network"/>, each bound to the column that has the variable's name.
/// </summary>
/// <param name="network">The network whose variables are referenced.</param>
/// <param name="dataPartitioning">Selects which portion of the sample data is read.</param>
/// <returns>A new <see cref="EvidenceReaderCommand"/> with default <see cref="ReaderOptions"/>.</returns>
private static IEvidenceReaderCommand NewEvidenceReaderCommand(Network network, DataPartitioning dataPartitioning)
{
    var data = NewDataReaderCommand(dataPartitioning);

    // Column names in the sample table match the variable names in the network.
    var references =
        (from variable in network.Variables
         select new VariableReference(variable, ColumnValueType.Name, variable.Name))
        .ToArray();

    return new EvidenceReaderCommand(data, references, new ReaderOptions());
}
/// <summary>
/// Builds an in-memory table with string columns "A" and "B", fills it with the hard-coded
/// sample cases selected by <paramref name="dataPartitioning"/>, logs the selection to the
/// console and wraps the table in a data reader command.
/// </summary>
/// <param name="dataPartitioning">Supplies the partition number and the include/exclude method.</param>
/// <returns>A <see cref="DataTableDataReaderCommand"/> over the populated table.</returns>
private static IDataReaderCommand NewDataReaderCommand(DataPartitioning dataPartitioning)
{
    // Sample cases grouped by source partition (index 0, 1, 2).
    // Each case is { value for column A, value for column B }.
    var casesByPartition = new string[][][]
    {
        new string[][]
        {
            new[] { "False", "True" },
            new[] { "True", "True" },
            new[] { "True", "True" },
            new[] { "False", "False" },
            new[] { "False", "True" },
            new[] { "True", "False" },
            new[] { "True", "True" },
        },
        new string[][]
        {
            new[] { "True", "False" },
            new[] { "True", "False" },
            new[] { "False", "False" },
            new[] { "True", "False" },
            new[] { "False", "True" },
            new[] { "False", "False" },
        },
        new string[][]
        {
            new[] { "True", "True" },
            new[] { "True", "True" },
            new[] { "True", "True" },
            new[] { "False", "True" },
            new[] { "True", "True" },
            new[] { "False", "False" },
        },
    };

    var table = new DataTable();
    table.Columns.Add("A", typeof(string));
    table.Columns.Add("B", typeof(string));

    var method = dataPartitioning.Method;
    var partition = dataPartitioning.PartitionNumber;

    for (var sourcePartition = 0; sourcePartition < casesByPartition.Length; sourcePartition++)
    {
        if (!IncludeData(sourcePartition, partition, method))
        {
            continue;
        }

        foreach (var sampleCase in casesByPartition[sourcePartition])
        {
            table.Rows.Add(sampleCase[0], sampleCase[1]);
        }
    }

    Console.WriteLine("Method = {0}, partition = {1}, count = {2}", method, partition, table.Rows.Count);
    return new DataTableDataReaderCommand(table);
}
/// <summary>
/// Decides whether the cases belonging to <paramref name="sourcePartition"/> should be
/// part of the data set for <paramref name="currentPartition"/>.
/// </summary>
/// <param name="sourcePartition">The partition the candidate cases belong to.</param>
/// <param name="currentPartition">The partition number currently being requested.</param>
/// <param name="method">Whether the current partition's data is included or excluded.</param>
/// <returns>
/// For <see cref="DataPartitionMethod.IncludePartitionData"/>, <c>true</c> only when the two
/// partitions are equal; for <see cref="DataPartitionMethod.ExcludePartitionData"/>,
/// <c>true</c> only when they differ.
/// </returns>
/// <exception cref="InvalidOperationException"><paramref name="method"/> is any other value.</exception>
private static bool IncludeData(int sourcePartition, int currentPartition, DataPartitionMethod method)
{
    var samePartition = sourcePartition == currentPartition;

    if (method == DataPartitionMethod.IncludePartitionData)
    {
        return samePartition;
    }

    if (method == DataPartitionMethod.ExcludePartitionData)
    {
        return !samePartition;
    }

    throw new InvalidOperationException();
}
}
}