from typing import Callable, List, Tuple, Dict, Optional
"""Handles the continuous-time recurrent neural network implementation."""
from ctneat.graphs import required_for_output
[docs]
class CTRNNNodeEval(object):
[docs]
def __init__(self, time_constant: float, activation: Callable, aggregation: Callable,
bias: float, response: float, links: List[Tuple[int, float]]):
"""
Initialize a CTRNN node evaluation.
Args:
time_constant: The time constant of the node.
activation: The activation function of the node.
aggregation: The aggregation function of the node.
bias: The bias term of the node.
response: The response term of the node.
links: The links from other nodes (incoming connections).
"""
self.time_constant = time_constant
self.activation = activation
self.aggregation = aggregation
self.bias = bias
self.response = response
self.links = links
[docs]
class CTRNN(object):
"""Sets up the ctrnn network itself."""
[docs]
def __init__(self, inputs: List[int], outputs: List[int], node_evals: Dict[int, CTRNNNodeEval],
custom_advance: Optional[Callable] = None):
"""
Initialize the CTRNN with the given input and output nodes, and node evaluations.
Args:
inputs: The input node IDs.
outputs: The output node IDs.
node_evals: A dictionary mapping node IDs to their evaluations (CTRNNNodeEval objects).
custom_advance: An optional custom advance function.
"""
self.input_nodes = inputs
self.output_nodes = outputs
self.node_evals = node_evals
self.custom_advance = custom_advance
self.values = [{}, {}]
for v in self.values:
# setting the initial value of all input and output nodes to 0.0
for k in inputs + outputs:
v[k] = 0.0
# for every node that is (possibly) not an input or output but is part of an
# active network - setting its initial value to 0.0
for node, ne in self.node_evals.items():
v[node] = 0.0
for i, w in ne.links:
v[i] = 0.0
self.active = 0
self.time_seconds = 0.0
[docs]
def reset(self):
"""
Reset the CTRNN to its initial state. (I.e. all node values to 0.0, and all time-related variables to 0.0)
"""
self.values = [dict((k, 0.0) for k in v) for v in self.values]
self.active = 0
self.time_seconds = 0.0
[docs]
def set_node_value(self, node_key: int, value: float):
"""
Set a value of a specific node.
Args:
node_key: The ID of the node to set the value for.
value: The value to set for the node.
"""
for v in self.values:
v[node_key] = value
[docs]
def get_max_time_step(self): # pragma: no cover
# TODO: Compute max time step that is known to be numerically stable for
# the current network configuration.
# pylint: disable=no-self-use
raise NotImplementedError()
[docs]
def advance(self, inputs: List[float], advance_time: float, time_step: Optional[float] = None):
"""
Advance the simulation by the given amount of time, assuming that inputs are
constant at the given values during the simulated time.
Args:
inputs: The input values to the network.
advance_time: The amount of time to advance the simulation.
time_step: The time step to use for the simulation.
Returns:
The output values of the network after the simulation.
"""
if self.custom_advance is not None:
return self.custom_advance(inputs, advance_time, time_step)
return self._simple_advance(inputs, advance_time, time_step)
def _simple_advance(self, inputs: List[float], advance_time: float, time_step: Optional[float] = None):
"""
Advance the simulation by the given amount of time, assuming that inputs are
constant at the given values during the simulated time.
Args:
inputs: The input values to the network.
advance_time: The amount of time to advance the simulation.
time_step: The time step to use for the simulation.
Returns:
The output values of the network after the simulation.
"""
final_time_seconds = self.time_seconds + advance_time
# Use half of the max allowed time step if none is given.
if time_step is None: # pragma: no cover
time_step = 0.5 * self.get_max_time_step()
if len(self.input_nodes) != len(inputs):
raise RuntimeError(f"Expected {len(self.input_nodes)} inputs, got {len(inputs)}")
while self.time_seconds < final_time_seconds:
# Ensure time_step is not None and is a float
assert time_step is not None, "time_step must be a float"
dt = min(time_step, final_time_seconds - self.time_seconds)
# self.values is a list containing two dictionaries, such that during the simulation step
# one is maintained and the other is updated.
ivalues = self.values[self.active]
ovalues = self.values[1 - self.active]
self.active = 1 - self.active
# all nodes that have an input are set to that value
for i, v in zip(self.input_nodes, inputs):
ivalues[i] = v
ovalues[i] = v
# for every node in the network, compute its new value
for node_key, ne in self.node_evals.items():
# the input for a given node is the weighted sum of its inputs
node_inputs = [ivalues[i] * w for i, w in ne.links]
# compute the node's new state by:
# aggregating the inputs
s = ne.aggregation(node_inputs)
# applying the activation function
z = ne.activation(ne.bias + ne.response * s)
# updating the output value (new value of that node)
ovalues[node_key] += dt / ne.time_constant * (-ovalues[node_key] + z)
self.time_seconds += dt
ovalues = self.values[1 - self.active]
return [ovalues[i] for i in self.output_nodes]
[docs]
@staticmethod
def create(genome, config, time_constant):
"""
Receives a genome and returns its phenotype (a CTRNN).
Args:
genome: The genome to create the CTRNN from.
config: The configuration object.
time_constant: The time constant to use for all nodes.
"""
genome_config = config.genome_config
required = required_for_output(genome_config.input_keys, genome_config.output_keys, genome.connections)
# Gather inputs and expressed connections.
node_inputs = {}
for cg in genome.connections.values():
if not cg.enabled:
continue
i, o = cg.key
if o not in required and i not in required:
continue
if o not in node_inputs:
node_inputs[o] = [(i, cg.weight)]
else:
node_inputs[o].append((i, cg.weight))
node_evals = {}
for node_key, inputs in node_inputs.items():
node = genome.nodes[node_key]
activation_function = genome_config.activation_defs.get(node.activation)
aggregation_function = genome_config.aggregation_function_defs.get(node.aggregation)
node_evals[node_key] = CTRNNNodeEval(time_constant,
activation_function,
aggregation_function,
node.bias,
node.response,
inputs)
return CTRNN(genome_config.input_keys, genome_config.output_keys, node_evals)