import jpype
import jpype.imports
from jpype.types import *

# Start the JVM with the Bayes Server jar on the classpath. The
# `com.bayesserver` imports below go through jpype.imports, so they must come
# after this call.
classpath = "lib/bayesserver-10.8.jar"
jpype.startJVM(classpath=[classpath])

from com.bayesserver import *
from com.bayesserver.inference import *
from jpype import java

# Node distribution kinds that copy_fragment copies for each distribution key,
# and also caches for copied children before a parent copy is removed.
kinds = [
    NodeDistributionKind.PROBABILITY,
    NodeDistributionKind.EXPERIENCE,
    NodeDistributionKind.FADING
]
class CacheItem:
    """Hold one (key, kind, distribution) entry of a node's distributions.

    copy_fragment stores these to snapshot a copied child's distributions
    before the copy of one of its parents is removed.
    """

    def __init__(self, key, kind, distribution):
        self.key, self.kind, self.distribution = key, kind, distribution
def map_variable_contexts(source, destination, source_to_dest):
    """
    Map each variable context in ``source`` to its position in ``destination``.

    For every context in ``source``, the context's variable is translated
    through ``source_to_dest``. The translated variable is then looked up in
    ``destination`` together with the context's time, via
    ``destination.indexOf(variable, time)``.

    :param source: sequence of variable contexts (objects exposing
        ``getVariable()`` and ``getTime()``).
    :param destination: collection of variable contexts supporting ``len()``
        and ``indexOf(variable, time)``.
    :param source_to_dest: mapping from source variables to destination
        variables.
    :return: list whose element ``i`` is the index in ``destination`` of the
        context corresponding to the ``i``-th context of ``source``.
    :raises ValueError: if ``source`` and ``destination`` differ in length.
    :raises KeyError: if a source variable is missing from ``source_to_dest``.
    """
    source_count = len(source)
    destination_count = len(destination)
    if source_count != destination_count:
        raise ValueError(
            f"Cannot map variable contexts: source has {source_count} "
            f"entries but destination has {destination_count}"
        )
    return [
        destination.indexOf(source_to_dest[context.getVariable()], context.getTime())
        for context in source
    ]
def copy_distribution(source, dest, source_to_dest):
    """
    Copy one distribution into another using a source->dest variable mapping.

    Variables in ``source`` are matched to variables in ``dest`` through
    ``source_to_dest``, so the two distributions may order their variables
    differently. Discrete table values are always copied. When ``source`` is a
    ``CLGaussian``, the means, covariances and weights are copied as well.

    :param source: distribution to read from.
    :param dest: distribution to write into; must have the mapped variables.
    :param source_to_dest: mapping from source variables to dest variables.
    :raises ValueError: (from map_variable_contexts) if the variable counts of
        the two distributions differ.
    """
    source_table = source.getTable()
    dest_table = dest.getTable()
    dest_sorted = dest_table.getSortedVariables()

    # source_to_dest_discrete[i] is the index in dest's sorted variables of the
    # variable that corresponds to source's i-th sorted variable.
    source_to_dest_discrete = map_variable_contexts(
        source_table.getSortedVariables(), dest_sorted, source_to_dest)

    # Walk the dest table in source's variable order, so that each increment
    # of the iterator lines up with the next source table entry.
    dest_order_discrete = [dest_sorted.get(index) for index in source_to_dest_discrete]
    iterator_dest = TableIterator(dest_table, java.util.Arrays.asList(dest_order_discrete))

    copy_gaussian = isinstance(source, CLGaussian)
    order_head = None
    order_tail = None
    if copy_gaussian:
        order_head = map_variable_contexts(
            source.getSortedContinuousHead(),
            dest.getSortedContinuousHead(),
            source_to_dest)
        order_tail = map_variable_contexts(
            source.getSortedContinuousTail(),
            dest.getSortedContinuousTail(),
            source_to_dest)

    for p in range(source_table.size()):
        iterator_dest.setValue(source_table.get(p))
        if copy_gaussian:
            # NOTE(review): the Gaussian setters below index dest with the
            # source table index p. That is only right if both tables list
            # their discrete combinations in the same order. Confirm.
            for h1, h1_dest in enumerate(order_head):
                dest.setMean(p, h1_dest, source.getMean(p, h1))
                for h2, h2_dest in enumerate(order_head):
                    dest.setCovariance(p, h1_dest, h2_dest, source.getCovariance(p, h1, h2))
                # Weights depend only on (head, tail), so this loop sits beside
                # the covariance loop rather than inside it.
                for t, t_dest in enumerate(order_tail):
                    dest.setWeight(p, h1_dest, t_dest, source.getWeight(p, h1, t))
        iterator_dest.increment()
def copy_fragment(source_network, source_nodes, dest_network, suffix, migrate_distributions):
    """
    Copy nodes from a source network into a destination network.

    The source and destination can be the same network. Each node in
    ``source_nodes`` is copied together with:

    - its direct parents;
    - the node groups it belongs to;
    - the links whose two ends are both among the copied nodes.

    Copied nodes and their variables are renamed by appending ``suffix``.
    When source and destination are the same network, copies are offset by
    50 in x and y.

    Only nodes in ``source_nodes`` get their distributions copied. The copies
    of parent-only nodes are removed from ``dest_network`` again at the end.

    :param source_network: network that owns ``source_nodes``.
    :param source_nodes: nodes to copy.
    :param dest_network: network that receives the copies (may be
        ``source_network``).
    :param suffix: text appended to copied node and variable names.
    :param migrate_distributions: when true, the distributions of copied
        children are cached before each parent copy is removed. The migration
        step itself is unfinished; see the TODO near the end.
    :return: ``dest_network``.
    """
    source_nodes_lookup = set(source_nodes)

    # 1. Collect the requested nodes plus their direct parents, in first-seen
    #    order and without duplicates, and the node groups of the requested
    #    nodes.
    source_nodes_and_parents = []
    source_nodes_and_parents_lookup = set()
    source_node_groups_to_add = []
    source_node_groups_lookup = set()
    for source_node in source_nodes:
        if source_node not in source_nodes_and_parents_lookup:
            source_nodes_and_parents.append(source_node)
            source_nodes_and_parents_lookup.add(source_node)
        for link in source_node.getLinksIn():
            parent = link.getFrom()
            if parent not in source_nodes_and_parents_lookup:
                source_nodes_and_parents.append(parent)
                source_nodes_and_parents_lookup.add(parent)
        for group_name in source_node.getGroups():
            group = source_network.getNodeGroups().get(group_name)
            # set.add() returns None, so membership must be tested explicitly.
            if group is not None and group not in source_node_groups_lookup:
                source_node_groups_lookup.add(group)
                source_node_groups_to_add.append(group)

    # 2. Collect every link whose two ends are both among the collected nodes.
    source_links_to_add = []
    source_links_to_add_lookup = set()
    for source_node in source_nodes_and_parents:
        for link in source_node.getLinks():
            if (link.getFrom() in source_nodes_and_parents_lookup
                    and link.getTo() in source_nodes_and_parents_lookup
                    and link not in source_links_to_add_lookup):
                source_links_to_add.append(link)
                source_links_to_add_lookup.add(link)

    # Order nodes and links by getIndex() (presumably their position in the
    # source network).
    source_nodes_and_parents.sort(key=lambda node: node.getIndex())
    source_links_to_add.sort(key=lambda link: link.getIndex())

    # 3. Make sure every referenced node group exists in the destination.
    for source_group in source_node_groups_to_add:
        if dest_network.getNodeGroups().get(source_group.getName()) is None:
            dest_network.getNodeGroups().add(source_group.copy())

    # 4. Copy the nodes, renaming each copy and its variables.
    source_node_to_copy_node = {}
    variable_map = {}
    same_network = source_network == dest_network
    for source_node in source_nodes_and_parents:
        dest_node = source_node.copy()
        dest_node.setName(dest_node.getName() + suffix)
        if same_network:
            # Shift the copy, presumably so it is not laid out on top of the
            # original.
            bounds = source_node.getBounds()
            dest_node.setBounds(Bounds(
                bounds.getX() + 50,
                bounds.getY() + 50,
                bounds.getWidth(),
                bounds.getHeight()))
        source_variables = source_node.getVariables()
        dest_variables = dest_node.getVariables()
        for v in range(source_variables.size()):
            source_variable = source_variables.get(v)
            dest_variable = dest_variables.get(v)
            dest_variable.setName(dest_variable.getName() + suffix)
            # The map is filled in both directions (source<->copy).
            variable_map[source_variable] = dest_variable
            variable_map[dest_variable] = source_variable
        if source_node not in source_nodes_lookup:
            # Parent-only copies do not join any node group.
            dest_node.getGroups().clear()
        dest_network.getNodes().add(dest_node)
        source_node_to_copy_node[source_node] = dest_node

    # 5. Recreate the links between the copies.
    for source_link in source_links_to_add:
        from_dest = source_node_to_copy_node[source_link.getFrom()]
        to_dest = source_node_to_copy_node[source_link.getTo()]
        dest_network.getLinks().add(
            source_link.copy(from_dest, to_dest, source_link.getTemporalOrder()))

    # 6. Copy the distributions of the requested nodes (parents are skipped).
    for source_node, node_dest in source_node_to_copy_node.items():
        if source_node not in source_nodes_lookup:
            continue
        for source_key in source_node.getDistributions().getKeys():
            related_node = source_key.getRelatedNode()
            related_node_dest = (
                None if related_node is None else source_node_to_copy_node[related_node])
            key_dest = NodeDistributionKey(source_key.getOrder(), related_node_dest)
            if not node_dest.getDistributions().canUpdate(key_dest):
                continue
            for kind in kinds:
                source_distribution = source_node.getDistributions().get(source_key, kind)
                if source_distribution is None:
                    continue
                distribution_dest = node_dest.newDistribution(key_dest, kind)
                copy_distribution(source_distribution, distribution_dest, variable_map)
                node_dest.getDistributions().set(key_dest, kind, distribution_dest)

    # 7. Remove the copies of parent-only nodes, one parent at a time.
    for source_parent in source_nodes_and_parents:
        if source_parent in source_nodes_lookup:
            continue
        parent_dest = source_node_to_copy_node[source_parent]

        # Snapshot the distributions of the copied children of this parent
        # before the parent copy is removed.
        old_distributions_dest = []
        if migrate_distributions:
            for source_link in source_parent.getLinksOut():
                source_child = source_link.getTo()
                if source_child not in source_nodes_lookup:
                    continue
                child_dest = source_node_to_copy_node[source_child]
                for key_dest in child_dest.getDistributions().getKeys():
                    for kind in kinds:
                        distribution_dest = child_dest.getDistributions().get(key_dest, kind)
                        if distribution_dest is not None:
                            old_distributions_dest.append(
                                (child_dest, CacheItem(key_dest, kind, distribution_dest)))

        dest_network.getNodes().remove(parent_dest)

        if migrate_distributions:
            for child_dest, cache_item in old_distributions_dest:
                if cache_item.key not in child_dest.getDistributions().getKeys():
                    continue
                # TODO(review): migration is unfinished. A new distribution is
                # created here but never populated from cache_item.distribution
                # and never assigned to child_dest. Kept as-is pending a
                # decision on how the removed parent should be remapped.
                new_distribution_dest = child_dest.newDistribution(
                    cache_item.key, cache_item.kind)

    return dest_network
# Demo: duplicate part of the Asia network inside the same network and print
# the result.
network_path = 'networks/Asia.bayes'
network = Network()
network.load(network_path)
nodes = network.getNodes()

# Fetch every node by name. The `True` argument is passed straight to
# nodes.get; presumably it makes a missing node raise. Confirm against the
# Bayes Server API.
(
    visit_to_asia,
    has_lung_cancer,
    tuberculosis_or_cancer,
    smoker,
    has_tuberculosis,
    dyspnea,
    xray_result,
    has_bronchitis,
) = [
    nodes.get(node_name, True)
    for node_name in (
        'Visit to Asia',
        'Has Lung Cancer',
        'Tuberculosis or Cancer',
        'Smoker',
        'Has Tuberculosis',
        'Dyspnea',
        'XRay Result',
        'Has Bronchitis',
    )
]

# Copy into the very same network.
dest_network = network
source_nodes = [has_lung_cancer, tuberculosis_or_cancer, dyspnea, has_bronchitis]
copy_fragment(network, source_nodes, dest_network, suffix='_copy', migrate_distributions=True)
print(network.saveToString())