# Source code for mulberry.tree
import io
from .errors import NotConnectedError
class GenericTree:
    """
    Tree of named frames linked by child -> parent edges, with a transform
    stored per edge. All transform arithmetic is delegated to a backend.
    """

    def __init__(self, backend):
        """
        Creates an empty tree.

        `backend` must provide ``identity()``, ``compose(a, b)`` and
        ``invert(t)``; these are the only backend operations used here.
        """
        self._backend = backend
        self._keys = set()  # Set of all keys
        self._parents = {}  # Dictionary {child key -> parent key}
        self._transforms = {}  # Dictionary {(from key, to key) -> 4x4 np.array}

    def hasFrame(self, key):
        """
        Checks if there is a frame `key`.
        """
        return key in self._keys

    def _getAncestry(self, key):
        """
        Gets the path from `key` to the tree root (including `key`).

        A key with no recorded parent is treated as a root, so an unknown key
        yields ``[key]``.
        """
        path = []
        while key is not None:
            path.append(key)
            key = self._parents.get(key, None)
        return path

    def _getRoot(self, key):
        """
        Gets the root of the tree containing `key`.
        """
        return self._getAncestry(key)[-1]

    def _getDirectTransform(self, from_key, to_key):
        """
        If `from_key` and `to_key` are directly connected, gets the transform.

        A transform stored in the opposite direction is inverted via the
        backend. Returns None when no direct edge exists.
        """
        ret = self._transforms.get((from_key, to_key), None)
        if ret is not None:
            return ret
        ret = self._transforms.get((to_key, from_key), None)
        if ret is not None:
            return self._backend.invert(ret)
        return None

    def getPath(self, from_key, to_key):
        """
        Gets the shortest path between `from_key` and `to_key`.

        Returns a list of keys starting at `from_key` and ending at `to_key`
        (both included), going up to their lowest common ancestor and back
        down. Returns None when the two keys are in different trees.
        Identical keys short-circuit to ``[from_key]``.
        """
        if from_key == to_key:
            return [from_key]
        from_root_path = self._getAncestry(from_key)
        to_root_path = self._getAncestry(to_key)
        if from_root_path[-1] != to_root_path[-1]:
            # Not connected
            return None
        # Count how many trailing (root-side) nodes both ancestries share.
        i = 0
        for from_path_node, to_path_node in zip(
            reversed(from_root_path), reversed(to_root_path)
        ):
            if from_path_node != to_path_node:
                break
            i += 1
        # +1 on from_path to include connecting node
        from_path = from_root_path[: len(from_root_path) - i + 1]
        to_path = to_root_path[: len(to_root_path) - i]
        return from_path + list(reversed(to_path))

    def _getTransformFromPath(self, path):
        """
        Calculates the composition of transforms along a path.

        Each consecutive pair in `path` must be directly connected (as
        produced by ``getPath``). Otherwise ``_getDirectTransform`` returns
        None and that None is handed to ``backend.compose``. A path with fewer
        than two keys yields the backend identity.
        """
        if len(path) < 2:
            return self._backend.identity()
        T = self._backend.identity()
        for from_key, to_key in zip(path, path[1:]):
            this_transform = self._getDirectTransform(from_key, to_key)
            T = self._backend.compose(T, this_transform)
        return T

    def outputDOT(self, title=""):
        """
        Outputs the tree as an undirected graph in DOT format.

        One edge statement is written per parent/child link, e.g.::

            // Autogenerated by mulberry
            graph title {
             "a" -- "b";
             "b" -- "c";
             "b" -- "d";
            }

        Keys are converted with ``str()`` and double-quoted, with embedded
        ``"`` backslash-escaped, so non-string keys are supported. `title` is
        written verbatim when empty or a plain identifier. Otherwise it is
        quoted the same way, so titles with spaces, punctuation or DOT
        keywords still produce valid DOT. Frames with neither a parent nor
        children do not appear.
        """
        dot_keywords = {"node", "edge", "graph", "digraph", "subgraph", "strict"}

        def quote(value):
            # Inside DOT quoted strings only \" is an escape sequence.
            return '"%s"' % str(value).replace('"', '\\"')

        title = str(title)
        if title and not (title.isidentifier() and title.lower() not in dot_keywords):
            title = quote(title)

        s = io.StringIO()
        s.write("// Autogenerated by mulberry\ngraph %s {\n" % title)
        for child, parent in self._parents.items():
            s.write(" %s -- %s;\n" % (quote(parent), quote(child)))
        s.write("}\n")
        return s.getvalue()
class Tree(GenericTree):
    """
    A GenericTree preconfigured with ``NumpyBackend``.
    """

    def __init__(self, *args, **kwargs):
        """
        Creates an empty tree that uses ``NumpyBackend`` for transform maths.

        The backend is passed to ``GenericTree.__init__`` as its first
        positional argument. Any extra positional/keyword arguments are
        forwarded after it, so ``GenericTree.__init__`` as written accepts
        none, and ``Tree()`` is the supported call.
        """
        # NOTE(review): imported locally, presumably so the NumPy backend is
        # only loaded when a Tree is constructed -- confirm intent.
        from .backends.numpy_backend import NumpyBackend

        # Supply the backend positionally, ahead of any forwarded args, so a
        # forwarded positional argument can never collide with `backend`.
        super().__init__(NumpyBackend, *args, **kwargs)