import jpype
import jpype.imports
from jpype.types import *
# Location of the Bayes Server jar. It is a relative path, so it is resolved
# against the current working directory. The JVM must be running with this
# jar on its classpath before the com.bayesserver imports below can resolve.
classpath = "lib/bayesserver-10.8.jar"
# JPype can start the JVM only once per process; starting it again raises.
# Guarding here lets the script be re-run in the same interpreter (e.g. a
# notebook cell) without failing.
if not jpype.isJVMStarted():
    jpype.startJVM(classpath=[classpath])
from com.bayesserver import *
from com.bayesserver.inference import *
# Build the network structure:
#   - D: discrete variable with states D1 and D2
#   - C: continuous variable
#   - F1, F2, F3: function variables whose values come from expressions
# The expression aliases ('d', 'c', 'f1') are the identifiers the function
# expressions refer to these variables by.
network = Network("Functions")
# Fetch the node collection once and use it for every addition. The original
# mixed this cached collection with repeated network.getNodes() calls.
nodes = network.getNodes()

d1 = State('D1')
d2 = State('D2')
d = Variable('D', [d1, d2])
d.setExpressionAlias('d')
nodes.add(Node(d))

c = Variable('C', VariableValueType.CONTINUOUS)
c.setExpressionAlias('c')
nodes.add(Node(c))

f1 = Variable('F1', VariableValueType.FUNCTION)
f1.setExpressionAlias('f1')
nodes.add(Node(f1))

# F2 and F3 are not referenced by any other function expression, so they are
# not given an expression alias.
f2 = Variable('F2', VariableValueType.FUNCTION)
nodes.add(Node(f2))

f3 = Variable('F3', VariableValueType.FUNCTION)
nodes.add(Node(f3))
# Directed links (parent -> child):
#   D  -> C  : C's distribution below is specified per state of D
#   F1 -> F2 : F2's expression reads F1's value
links = network.getLinks()
for parent, child in ((d, c), (f1, f2)):
    links.add(Link(parent.getNode(), child.getNode()))
# Prior table for D, one probability per state.
nodeD = d.getNode()
tableD = nodeD.newDistribution().getTable()
for stateOfD, probability in ((d1, 0.6961106280502056), (d2, 0.3038893719497943)):
    tableD.set(probability, stateOfD)
nodeD.setDistribution(tableD)

# Distribution for C given its parent D: one (mean, variance) pair per state
# of D. This is presumably a CLGaussian, matching the CLGaussian query below.
nodeC = c.getNode()
gaussianC = nodeC.newDistribution()
for stateOfD, (meanOfC, varianceOfC) in ((d1, (100.0, 20.0)), (d2, (-50.0, 45.0))):
    gaussianC.setMean(c, meanOfC, stateOfD)
    gaussianC.setVariance(c, varianceOfC, stateOfD)
nodeC.setDistribution(gaussianC)
# Expression sources handed to FunctionVariableExpression. They are not Python
# code; Bayes Server evaluates them. Each string is passed verbatim, including
# its leading and trailing newline.

# F1: presumably P(D = D1) times the mean of C (via ctx.TableValue / ctx.Mean).
f1Expression = '''
var d1 = d.States.Get("D1", true);
var probD1 = ctx.TableValue(d1);
var meanC1 = ctx.Mean(c);
return probD1 * meanC1;
'''

# F2: twice the value of function variable F1 (referenced by its alias 'f1').
f2Expression = '''
var f1Val = ctx.FunctionValue(f1);
return f1Val * 2.0;
'''

# F3: a constant string result.
f3Expression = '''
return "PI = " + Math.PI;
'''

for functionVariable, source, returnType in (
    (f1, f1Expression, ExpressionReturnType.DOUBLE),
    (f2, f2Expression, ExpressionReturnType.DOUBLE),
    (f3, f3Expression, ExpressionReturnType.STRING),
):
    functionVariable.setFunction(FunctionVariableExpression(source, returnType))
# Create a relevance-tree inference engine for the network, plus the options
# and output objects that query() takes.
factory = RelevanceTreeInferenceFactory()
inference = factory.createInferenceEngine(network)
queryOptions = factory.createQueryOptions()
queryOutput = factory.createQueryOutput()

# Distributions to compute: a table over D and a CLGaussian over C.
queryDistributions = inference.getQueryDistributions()
queryD = Table(d)
queryDistributions.add(queryD)
queryC = CLGaussian(c)
queryDistributions.add(queryC)

# Function-variable values to compute.
queryFunctions = inference.getQueryFunctions()
queryF1 = QueryFunctionOutput(f1)
queryF2 = QueryFunctionOutput(f2)
queryF3 = QueryFunctionOutput(f3)
for registeredQuery in (queryF1, queryF2, queryF3):
    queryFunctions.add(registeredQuery)

# Run the query; results are written into the query objects registered above.
inference.query(queryOptions, queryOutput)
# Report results. No evidence is set anywhere above, hence the "|-" notation.
probabilityD1, probabilityD2 = queryD.get(d1), queryD.get(d2)
print(f"P(D|-) = {{ {probabilityD1}, {probabilityD2} }}.")

posteriorMeanC, posteriorVarianceC = queryC.getMean(c), queryC.getVariance(c)
print(f"P(C|-) = {{ {posteriorMeanC}, {posteriorVarianceC} }}.")

for label, reportedQuery in (("F1", queryF1), ("F2", queryF2), ("F3", queryF3)):
    print(f"{label} = {{ {reportedQuery.getValue()} }}.")