Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions benchmarks/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,225 @@
from py_rete.conditions import Bind
from py_rete.conditions import Filter
from py_rete.network import ReteNetwork
import py_rete.network
from py_rete.negative_node import NegativeNode
from py_rete.pnode import PNode, Production
from py_rete.ncc_node import NccNode, NccPartnerNode
from py_rete.filter_node import FilterNode
from py_rete.bind_node import BindNode
from py_rete.beta import BetaMemory

from unittest.mock import MagicMock, Mock, sentinel, call
from pytest import mark, raises


def test_network_run(monkeypatch):
match = Mock(name="Match")
n = ReteNetwork()
p = property(Mock(side_effect=[[match], []]))
monkeypatch.setattr(ReteNetwork, 'matches', p)
n.run()
assert match.fire.mock_calls == [call()]


def test_network_repr():
n = ReteNetwork()
p = Mock(name="Production", id=sentinel.PROD_ID)
n.productions = set([p])

dup_fact = {
sentinel.SIMPLE: sentinel.VALUE,
sentinel.SUBFACT: Fact(name="subfact")
}
dup_fact[sentinel.SUBFACT].id = sentinel.SUBFACT_ID
f = Mock(name="Fact", id=sentinel.ID, duplicate=Mock(return_value=dup_fact))
n.facts = {
f.id: f
}
wme = Mock(name="Working Memory")
n.working_memory = set([wme])
text = repr(n)

assert list(dup_fact.items()) == [(sentinel.SIMPLE, sentinel.VALUE), (sentinel.SUBFACT, sentinel.SUBFACT_ID)]
assert "sentinel.PROD_ID: <Mock name='Production'" in text
assert "sentinel.ID: {sentinel.SIMPLE: sentinel.VALUE, sentinel.SUBFACT: sentinel.SUBFACT_ID}" in text
assert "<Mock name='Working Memory'" in text


def test_network_num_nodes():
n = ReteNetwork()

n.beta_root = Mock(children=[])
assert n.num_nodes() == 1

n.beta_root = Mock(children=[Mock(children=[])])
assert n.num_nodes() == 2


def test_network_add_remove_get_fact():
n = ReteNetwork()
subfact_id = Fact(name="sub_id")
n.add_fact(subfact_id) # Assigns ID.
subfact = Fact(name="sub")
top_fact = Fact(name="top", subfact=subfact, subfact_id=subfact_id)
n.add_fact(top_fact)
# EQ tests require comparing objects with the id's updated.
assert n.facts == {'f-0': subfact_id, 'f-1': subfact, 'f-2': top_fact}

with raises(ValueError):
n.add_fact(top_fact)
with raises(ValueError):
n.remove_fact(Fact(not_in_evidence=True))

assert n.get_fact_by_id('f-0') == subfact_id


def test_network_get_new_match():
n = ReteNetwork()
old = Mock(name="PNode", new=False)
n.pnodes = [old]
assert n.get_new_match() is None


def test_network_add_production():
n = ReteNetwork()
n.productions = {sentinel.EXISTING: Mock}
with raises(ValueError):
n.add_production(Mock(id=sentinel.EXISTING))
with raises(ValueError):
n.remove_production(Mock(id=None))


def test_network_add_remove_wme():
n = ReteNetwork()
wme_0 = Mock(identifier="#*#", attribute=sentinel.ATTR, value=sentinel.VALUE)
with raises(ValueError):
n.add_wme(wme_0)

jr = Mock(owner=Mock(node=None, join_results=[]))
jr.owner.join_results.append(jr)

wme_1 = Mock(amems=[], tokens=[], negative_join_results=[jr])
n.working_memory = set([wme_1])
n.remove_wme(wme_1)


def test_network_alpha_mem():
n = ReteNetwork()
with raises(ValueError):
cond_0 = Mock(identifier="#*#", attribute=sentinel.ATTR, value=sentinel.VALUE)
n.build_or_share_alpha_memory(cond_0)

cond_1 = Mock(
identifier=sentinel.ID,
attribute=V(sentinel.ATTR_V),
value=sentinel.VALUE
)
key = (sentinel.ID, "#*#", sentinel.VALUE)
n.alpha_hash = {key: sentinel.COND}
c = n.build_or_share_alpha_memory(cond_1)
assert c == sentinel.COND
# assert n.alpha_hash == {key: sentinel.COND}


def test_network_negative_node():
n = ReteNetwork()
amem = Mock(successors=[], reference_count=0)
other = Mock(name="ReteNode", items=[], amem=sentinel.AMEM, condition=sentinel.COND)
condition = Mock(name="condition", vars=[("name", sentinel.VALUE)])
negative = NegativeNode(items=[], amem=amem, condition=condition)
parent = Mock(name="Parent ReteNode", children=[other, negative])
new = n.build_or_share_negative_node(parent, amem, condition)
assert new == negative


def test_network_beta_memory():
n = ReteNetwork()
other = Mock(name="not BetaMemory")
beta = BetaMemory(items=[])
parent = Mock(name="Parent ReteNode", children=[other, beta])
new = n.build_or_share_beta_memory(parent)
assert new == beta


def test_network_pnode():
n = ReteNetwork()
other = Mock(name="ReteNode")
pnode = PNode(production=sentinel.PROD, items=[])
parent = Mock(name="Parent ReteNode", children=[other, pnode])
new = n.build_or_share_p(parent, sentinel.PROD)
assert new == pnode


def test_network_ncc_node(monkeypatch):
monkeypatch.setattr(ReteNetwork, 'build_or_share_network_for_conditions', Mock(return_value=sentinel.BOTTOM))
n = ReteNetwork()
other = Mock(name="ReteNode")
nccnode = NccNode(partner=Mock(name="NccPartnerNode", parent=None), items=[])
nccnode.partner.parent = sentinel.BOTTOM
parent = Mock(name="Parent JoinNode", children=[other, nccnode])
new = n.build_or_share_ncc_nodes(parent, sentinel.CANDIDATE_NCC, earlier_conds=[])
assert new == nccnode


def test_network_filter_node():
n = ReteNetwork()
other = Mock(name="ReteNode")
filternode = FilterNode(children=[], parent=Mock(), func=sentinel.FUNC, rete=sentinel.RETE)
parent = Mock(name="Parent ReteNode", children=[other, filternode])
new = n.build_or_share_filter_node(parent, Mock(func=sentinel.FUNC))
assert new == filternode


def test_network_bind_node():
n = ReteNetwork()
other = Mock(name="ReteNode")
bindnode = BindNode(children=[], parent=Mock(), func=sentinel.FUNC, to=sentinel.TO, rete=sentinel.RETE)
parent = Mock(name="Parent ReteNode", children=[other, bindnode])
new = n.build_or_share_bind_node(parent, Mock(func=sentinel.FUNC, to=sentinel.TO))
assert new == bindnode


def test_network_delete_unused():
n = ReteNetwork()
node = NccPartnerNode()
def cleanup():
# Side-effect of a complex bit of processing in Token.
node.new_result_buffer = []
token = Mock(network=n, delete_token_and_descendents=Mock(side_effect=cleanup))
node.new_result_buffer = [token]
n.delete_node_and_any_unused_ancestors(node)
assert token.delete_token_and_descendents.mock_calls == [call()]


def test_network_update_beta_memory():
n = ReteNetwork()
parent = BetaMemory(items=[sentinel.TOKEN])
new_node = Mock(left_activation=Mock())
new_node.parent = parent
n.update_new_node_with_matches_from_above(new_node)
assert new_node.left_activation.mock_calls == [call(token=sentinel.TOKEN)]


def test_network_update_negative_node():
n = ReteNetwork()
token_0 = Mock(binding=sentinel.BINDING, join_results=True)
token_1 = Mock(binding=sentinel.BINDING, join_results=False)
parent = NegativeNode(items=[token_0, token_1], amem=sentinel.AMEM, condition=Mock(vars=[(sentinel.NAME, sentinel.VALUE)]))
new_node = Mock(left_activation=Mock())
new_node.parent = parent
n.update_new_node_with_matches_from_above(new_node)
assert new_node.left_activation.mock_calls == [call(token_1, None, sentinel.BINDING)]


def test_network_update_ncc_node():
n = ReteNetwork()
token_0 = Mock(binding=sentinel.BINDING, ncc_results=True)
token_1 = Mock(binding=sentinel.BINDING, ncc_results=False)
parent = NccNode(items=[token_0, token_1], partner=sentinel.PARTNER)
new_node = Mock(left_activation=Mock())
new_node.parent = parent
n.update_new_node_with_matches_from_above(new_node)


def init_network():
Expand Down
5 changes: 4 additions & 1 deletion py_rete/alpha.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""
Alpha Memory structure.
"""
from __future__ import annotations
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
from typing import List
from typing import Optional
from typing import List
from py_rete.join_node import JoinNode
from py_rete.common import WME

Expand Down
15 changes: 10 additions & 5 deletions py_rete/beta.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""
Rete Node and Beta Memory -- a kind of Rete Node.
"""
from __future__ import annotations
from typing import TYPE_CHECKING

from py_rete.common import Token

if TYPE_CHECKING: # pragma: no cover
from typing import Any
from typing import List
from typing import Dict
from typing import Optional
from typing import Any, Optional
# 3.8 and below.
from typing import List, Dict
# Otherwise, use list and dict
from py_rete.common import V
from py_rete.common import WME
from py_rete.alpha import AlphaMemory
Expand Down Expand Up @@ -45,7 +48,9 @@ def __init__(self, items: Optional[List[Token]] = None, **kwargs):

def find_nearest_ancestor_with_same_amem(self, amem: AlphaMemory
) -> Optional[JoinNode]:
return self.parent.find_nearest_ancestor_with_same_amem(amem)
if self.parent:
return self.parent.find_nearest_ancestor_with_same_amem(amem)
return None # pragma: no cover

def left_activation(self, token: Optional[Token] = None,
wme: Optional[WME] = None,
Expand Down
30 changes: 21 additions & 9 deletions py_rete/bind_node.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,41 @@
"""
Bind Node: a kind of Rete Node with a code block and a variable binding.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import inspect

from py_rete.beta import ReteNode
from py_rete.beta import ReteNode, BetaMemory
from py_rete.common import V

if TYPE_CHECKING: # pragma: no cover
# Mypy for 3.8...
from typing import Any
from typing import Dict
from typing_extensions import TypeAlias
from typing import Dict, List
from typing import Callable
from py_rete.network import ReteNetwork

ProdFunc: TypeAlias = Callable[[], Any]


class BindNode(ReteNode):
"""
A beta network class. This class stores a code snipit, with variables in
A beta network class. This class stores a code snippet, with variables in
it. It gets all the bindings from the incoming token, updates them with the
current bindings, binds the result to the target variable (to), then
activates its children with the updated bindings.
"""

def __init__(self, children, parent, func, to, rete: ReteNetwork):
def __init__(self, children: List[BetaMemory], parent: BetaMemory,
func: ProdFunc, to: str, rete: ReteNetwork
) -> None:
"""
:type children:
:type parent: BetaNode
:type to: str
:param children: list of ReteNodes
:param parent: BetaNode
:param func: The Production function
:param to: Name of variable to bind to
:param rete: Overall ReteNetwork
"""
super().__init__(children=children, parent=parent)
self.func = func
Expand All @@ -35,11 +47,11 @@ def get_function_result(self, binding: Dict[V, Any]):
Given a binding that maps variables to values, this instantiates the
arguments for the function and executes it.
"""
args = inspect.getfullargspec(self.func)[0]
argspec = inspect.getfullargspec(self.func)[0]
args = {arg: self._rete_net if arg == 'net' else
self._rete_net.facts[binding[V(arg)]] if
binding[V(arg)] in self._rete_net.facts else
binding[V(arg)] for arg in args}
binding[V(arg)] for arg in argspec}
return self.func(**args)

def left_activation(self, token, wme, binding):
Expand Down
18 changes: 12 additions & 6 deletions py_rete/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class WME:
__slots__ = ['identifier', 'attribute', 'value', 'amems', 'tokens',
'negative_join_results']

def __init__(self, identifier: Hashable, attribute: Hashable, value:
Hashable) -> None:
def __init__(self, identifier: Hashable, attribute: Hashable,
value: Hashable) -> None:
"""
Identifier, attribute, and value can be any kind of object except V
objects (i.e., variables).
Expand All @@ -80,7 +80,7 @@ def __eq__(self, other: object) -> bool:
:type other: WME
"""
if not isinstance(other, WME):
return False
return NotImplemented
return self.identifier == other.identifier and \
self.attribute == other.attribute and \
self.value == other.value
Expand Down Expand Up @@ -134,7 +134,10 @@ def __hash__(self):
def is_root(self) -> bool:
return not self.parent and not self.wme

def render_tokens(self):
def render_tokens(self): # pragma: no cover
"""
.. todo:: Consider refactoring as a function **outside** the class.
"""
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -186,6 +189,9 @@ def delete_token_and_descendents(self) -> None:
- Add optimization for right unlinking (pg 87 of Doorenbois
thesis).

- Would introducing weakref help break the circularity
and simplify this?

:type token: Token
"""
from py_rete.ncc_node import NccNode
Expand Down Expand Up @@ -218,7 +224,7 @@ def delete_token_and_descendents(self) -> None:
bmchild.amem.successors.remove(bmchild)

if isinstance(self.node, NegativeNode):
if not self.node.items:
if not self.node.items: # pragma: no cover
self.node.amem.successors.remove(self.node)
for jr in self.join_results:
jr.wme.negative_join_results.remove(jr)
Expand All @@ -227,7 +233,7 @@ def delete_token_and_descendents(self) -> None:
for result_tok in self.ncc_results:
if result_tok.wme:
result_tok.wme.tokens.remove(result_tok)
if result_tok.parent:
if result_tok.parent: # pragma: no cover
result_tok.parent.children.remove(result_tok)

elif isinstance(self.node, NccPartnerNode):
Expand Down
Loading