Source code for ctneat.iznn

from typing import Tuple, List, Dict, Callable, Optional, Union
"""
This module implements a spiking neural network.
Neurons are based on the model described by:

Izhikevich, E. M.
Simple Model of Spiking Neurons
IEEE TRANSACTIONS ON NEURAL NETWORKS, VOL. 14, NO. 6, NOVEMBER 2003

http://www.izhikevich.org/publications/spikes.pdf
"""
import numpy as np
from scipy.integrate import solve_ivp

from ctneat.attributes import FloatAttribute
from ctneat.genes import BaseGene, DefaultConnectionGene
from ctneat.genome import DefaultGenomeConfig, DefaultGenome
from ctneat.graphs import required_for_output

# a, b, c, d are the parameters of the Izhikevich model.
# a: the time scale of the recovery variable
# b: the sensitivity of the recovery variable
# c: the after-spike reset value of the membrane potential
# d: after-spike reset of the recovery variable
# The following parameter sets produce some known spiking behaviors:
# pylint: disable=bad-whitespace
REGULAR_SPIKING_PARAMS        = {'a': 0.02, 'b': 0.20, 'c': -65.0, 'd': 8.00}
INTRINSICALLY_BURSTING_PARAMS = {'a': 0.02, 'b': 0.20, 'c': -55.0, 'd': 4.00}
CHATTERING_PARAMS             = {'a': 0.02, 'b': 0.20, 'c': -50.0, 'd': 2.00}
FAST_SPIKING_PARAMS           = {'a': 0.10, 'b': 0.20, 'c': -65.0, 'd': 2.00}
THALAMO_CORTICAL_PARAMS       = {'a': 0.02, 'b': 0.25, 'c': -65.0, 'd': 0.05}
RESONATOR_PARAMS              = {'a': 0.10, 'b': 0.25, 'c': -65.0, 'd': 2.00}
LOW_THRESHOLD_SPIKING_PARAMS  = {'a': 0.02, 'b': 0.25, 'c': -65.0, 'd': 2.00}


# TODO: Add mechanisms analogous to axon & dendrite propagation delay.


[docs] class IZNodeGene(BaseGene): """Contains attributes for the iznn node genes and determines genomic distances.""" _gene_attributes = [FloatAttribute('bias'), FloatAttribute('a'), FloatAttribute('b'), FloatAttribute('c'), FloatAttribute('d')] # @property # def a(self): # return self.__getattribute__('a') # @property # def b(self): # return self.__getattribute__('b') # @property # def c(self): # return self.__getattribute__('c') # @property # def d(self): # return self.__getattribute__('d')
[docs] def distance(self, other, config): s = abs(self.a - other.a) + abs(self.b - other.b) \ + abs(self.c - other.c) + abs(self.d - other.d) return s * config.compatibility_weight_coefficient
[docs] class IZGenome(DefaultGenome):
[docs] @classmethod def parse_config(cls, param_dict): param_dict['node_gene_type'] = IZNodeGene param_dict['connection_gene_type'] = DefaultConnectionGene return DefaultGenomeConfig(param_dict)
[docs] class IZNeuron(object): """Sets up and simulates the iznn nodes (neurons)."""
[docs] def __init__(self, bias: float, a: float, b: float, c: float, d: float, inputs: List[Tuple[int, float]]): """ a, b, c, d are the parameters of the Izhikevich model. Args: bias (float): The bias of the neuron. a (float): The time-scale of the recovery variable. b (float): The sensitivity of the recovery variable. c (float): The after-spike reset value of the membrane potential. d (float): The after-spike reset value of the recovery variable. inputs (list(tuple(int, float))): A list of (input key, weight) pairs for incoming connections. """ self.a = a self.b = b self.c = c self.d = d self.bias = bias self.inputs = inputs # Membrane potential (millivolts). self.v = self.c # Membrane recovery variable. self.u = self.b * self.v # 1.0 if the neuron has fired, 0.0 otherwise. self.fired = 0.0 # Input current (milliamps). self.current = self.bias
def _spike_event(self, t, y): """Event function: returns 0 when v crosses 30mV.""" v, u = y return v - 30.0 _spike_event.terminal = True # Stop integration when the event is found _spike_event.direction = 1 # Trigger only when crossing from below (v is increasing) def _derivatives(self, state: np.ndarray, current: float) -> np.ndarray: """ Calculates the derivatives dv/dt and du/dt for a given state. Args: state (np.array): A numpy array [v, u]. current (float): The input current I. Returns: np.array: A numpy array [dv/dt, du/dt]. """ v, u = state dv_dt = 0.04 * v**2 + 5 * v + 140 - u + current du_dt = self.a * (self.b * v - u) return np.array([dv_dt, du_dt]) def _derivatives_scipy(self, t, y): """ Calculates the derivatives for SciPy's solve_ivp. Signature must be f(t, y). Args: t (float): Current time (not used in this model). y (list): A list [v, u] where v is the membrane potential and u is the recovery variable. Returns: list: A list [dv/dt, du/dt]. """ v, u = y dv_dt = 0.04 * v**2 + 5 * v + 140 - u + self.current du_dt = self.a * (self.b * v - u) return [dv_dt, du_dt]
[docs] def advance(self, dt_msec: float): """ This is a default advance method which is simply a wrapper to the advance_scipy method. Args: dt_msec (float): The time step in milliseconds. """ self.advance_scipy(dt_msec)
[docs] def advance_rk4(self, dt_msec: float): """ Advances simulation time using 4th-Order Runge-Kutta. if v >= 30 then v <- c, u <- u + d else v' = 0.04 * v^2 + 5v + 140 - u + I u' = a * (b * v - u) Args: dt_msec (float): The time step in milliseconds. """ # The spike detection and reset logic must happen *after* the integration step. self.fired = 0.0 if self.v >= 30.0: # Output spike and reset. self.fired = 1.0 self.v = self.c self.u += self.d return # End the step here after a reset try: y = np.array([self.v, self.u]) h = dt_msec k1 = self._derivatives(y, self.current) k2 = self._derivatives(y + 0.5 * h * k1, self.current) k3 = self._derivatives(y + 0.5 * h * k2, self.current) k4 = self._derivatives(y + h * k3, self.current) # Update state variables v and u y_new = y + (h / 6.0) * (k1 + 2*k2 + 2*k3 + k4) self.v, self.u = y_new except (OverflowError, ValueError): # If integration fails (e.g., due to instability), reset without spiking. # This is more robust than just catching OverflowError. self.v = self.c self.u = self.b * self.v
[docs] def advance_scipy(self, dt_msec: float, method: str = 'LSODA'): """ Advances simulation time using a solver from SciPy. Args: dt_msec (float): The time step in milliseconds. method (str): The integration method to use (e.g., 'RK45', 'LSODA'). Other options are listed in the SciPy documentation for solve_ivp (https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html). """ self.fired = 0.0 if self.v >= 30.0: self.fired = 1.0 self.v = self.c self.u += self.d return try: y0 = [self.v, self.u] t_span = [0, dt_msec] # Call the solver sol = solve_ivp( fun=self._derivatives_scipy, t_span=t_span, y0=y0, method=method, t_eval=[dt_msec] # We only need the state at the end of the interval ) # Update the neuron's state from the solution self.v, self.u = sol.y[:, -1] except (OverflowError, ValueError): self.v = self.c self.u = self.b * self.v
[docs] def advance_scipy_events(self, dt_msec: float, method: str = 'LSODA'): """ Advances the simulation using SciPy's solve_ivp with event detection. This method detects spikes (when v crosses 30mV) during the integration step. Args: dt_msec (float): The time step in milliseconds. method (str): The integration method to use (e.g., 'RK45', 'LSODA'). """ self.fired = 0.0 try: y0 = [self.v, self.u] t_span = [0, dt_msec] sol = solve_ivp( fun=self._derivatives_scipy, t_span=t_span, y0=y0, method=method, events=self._spike_event ) # Check if a spike event was triggered if sol.t_events[0].size > 0: self.fired = 1.0 self.v = self.c self.u += self.d else: # No spike, just update to the final state self.v, self.u = sol.y[:, -1] except (OverflowError, ValueError): self.v = self.c self.u = self.b * self.v
[docs] def solve_for_interval(self, dt_msec: float, method: str = 'LSODA'): """ Solves the neuron's ODE for a given interval and reports the solution and any spike events. This method DOES NOT change the neuron's state. """ if self.v >= 30.0: # Already in a spiked state from a previous step return None, 0.0 # Spike at the very beginning of the interval y0 = [self.v, self.u] t_span = [0, dt_msec] sol = solve_ivp( fun=self._derivatives_scipy, t_span=t_span, y0=y0, method=method, events=self._spike_event, dense_output=True # Needed to evaluate the solution at any time ) spike_time = sol.t_events[0][0] if sol.t_events[0].size > 0 else None return sol, spike_time
[docs] def reset(self): """Resets all state variables.""" self.v = self.c self.u = self.b * self.v self.fired = 0.0 self.current = self.bias
[docs] class IZNN(object): """Basic iznn network object."""
[docs] def __init__(self, neurons: Dict[int, IZNeuron], inputs: List[int], outputs: List[int], event_driven: bool = False): """ Initializes the IZNN with the given neurons, inputs, and outputs. Args: neurons (dict): A dictionary mapping neuron IDs to IZNeuron instances. inputs (list): A list of input neuron IDs. outputs (list): A list of output neuron IDs. event_driven (bool): If True, uses event-driven simulation for spike timing. """ self.neurons = neurons self.inputs = inputs self.outputs = outputs self.input_values = {} self.time_ms = 0.0 self.event_driven = event_driven
[docs] def set_inputs(self, inputs: List[float]): """ Assigns input voltages. Args: inputs (list): A list of input voltages. (in millivolts, where each voltage corresponds to an input neuron) """ if len(inputs) != len(self.inputs): raise RuntimeError( "Number of inputs {0:d} does not match number of input nodes {1:d}".format( len(inputs), len(self.inputs))) # Set the input values for the input neurons (in the same order as input_nodes). for i, v in zip(self.inputs, inputs): self.input_values[i] = v
[docs] def reset(self): """Resets all neurons to their default state.""" for n in self.neurons.values(): n.reset() self.time_ms = 0.0
[docs] def get_time_step_msec(self): """ Returns a safe time step in milliseconds for the current network configuration. This is a placeholder implementation and should be replaced with a proper calculation. """ return 0.05
[docs] def advance(self, dt: float, method: Optional[str] = 'LSODA', events: bool = False, ret: Optional[Union[List[str], str]] = None) -> Union[List[float], List[List[float]]]: """ Advances the simulation by the given time step in milliseconds. Args: dt_msec (float): The time step in milliseconds. method (str): The integration method to use. If None, uses manually written RK4, otherwise defaults to SciPy's LSODA. If specified, uses SciPy's solve_ivp with the given method. Valid methods are listed in the SciPy documentation for solve_ivp (https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html). And here is a summary of available methods: - 'RK45' (Default): An adaptive Runge-Kutta method of order 5(4). It's a great general-purpose choice and a good starting point. - 'RK23': A lower-order adaptive Runge-Kutta method. Faster but less accurate than RK45. - 'DOP853': A high-order (8th) adaptive Runge-Kutta method for when you need very high precision. - 'LSODA': This is a particularly important one for spiking neurons. It's a solver that automatically switches between methods for non-stiff and stiff problems. A "stiff" ODE is one where some parts of the solution change very rapidly while others change slowly (like the membrane potential during a spike!). LSODA is often very efficient and stable for these kinds of systems. - 'BDF' and 'Radau': Other excellent methods for stiff problems. events (bool): Whether to use event detection for spikes. Only applicable if 'method' is specified. ret (list(str) or str or None): Specifies what to return. If a list of strings, returns a list of lists, where each inner list corresponds to the requested attribute for all output neurons. If a single string, returns a list corresponding to the requested attribute for all output neurons. If None, returns a list of firing states for all output neurons. Valid strings are: 'fired' - returns the firing states (1.0 if fired, 0.0 otherwise) 'voltages' - returns the membrane potentials (in millivolts) 'recovery' - returns the recovery variables 'all' - returns a list of lists: [fired states, voltages, recovery variables] Returns: A list or a list of lists as specified by the 'ret' parameter. Raises: ValueError: If an invalid integration method is specified. """ if self.event_driven: return self.advance_event_driven(dt, method=method or 'LSODA', ret=ret) else: return self.advance_simple(dt, method=method, events=events, ret=ret)
[docs] def advance_simple(self, dt_msec, method: Optional[str] = 'LSODA', events: bool = False, ret: Optional[Union[List[str], str]] = None) -> Union[List[float], List[List[float]]]: """ Advances the simulation by the given time step in milliseconds. Args: dt_msec (float): The time step in milliseconds. method (str): The integration method to use. If None, uses manually written RK4, otherwise defaults to SciPy's LSODA. If specified, uses SciPy's solve_ivp with the given method. Valid methods are listed in the SciPy documentation for solve_ivp (https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html). And here is a summary of available methods: - 'RK45' (Default): An adaptive Runge-Kutta method of order 5(4). It's a great general-purpose choice and a good starting point. - 'RK23': A lower-order adaptive Runge-Kutta method. Faster but less accurate than RK45. - 'DOP853': A high-order (8th) adaptive Runge-Kutta method for when you need very high precision. - 'LSODA': This is a particularly important one for spiking neurons. It's a solver that automatically switches between methods for non-stiff and stiff problems. A "stiff" ODE is one where some parts of the solution change very rapidly while others change slowly (like the membrane potential during a spike!). LSODA is often very efficient and stable for these kinds of systems. - 'BDF' and 'Radau': Other excellent methods for stiff problems. events (bool): Whether to use event detection for spikes. Only applicable if 'method' is specified. ret (list(str) or str or None): Specifies what to return. If a list of strings, returns a list of lists, where each inner list corresponds to the requested attribute for all output neurons. If a single string, returns a list corresponding to the requested attribute for all output neurons. If None, returns a list of firing states for all output neurons. Valid strings are: 'fired' - returns the firing states (1.0 if fired, 0.0 otherwise) 'voltages' - returns the membrane potentials (in millivolts) 'recovery' - returns the recovery variables 'all' - returns a list of lists: [fired states, voltages, recovery variables] Returns: A list or a list of lists as specified by the 'ret' parameter. Raises: ValueError: If an invalid integration method is specified. """ if method not in ['RK45', 'RK23', 'DOP853', 'LSODA', 'BDF', 'Radau']: raise ValueError(f"Invalid integration method '{method}'. Valid methods are 'RK45', 'RK23', 'DOP853', 'LSODA', 'BDF', 'Radau'.") if method is None and events: raise ValueError("Event detection requires a valid integration method.") for n in self.neurons.values(): n.current = n.bias # In the advance step, we compute the new current for each neuron. # Each input contributes its value * weight to the current. # Where value is 1.0 if the input neuron fired, and 0.0 otherwise. # In case the input is not a neuron, we use the externally set input value. for i, w in n.inputs: ineuron = self.neurons.get(i) if ineuron is not None: ivalue = ineuron.fired else: ivalue = self.input_values.get(i, 0.0) n.current += ivalue * w for n in self.neurons.values(): if method is None: n.advance_rk4(dt_msec) elif events: n.advance_scipy_events(dt_msec, method=method) else: n.advance_scipy(dt_msec, method=method) self.time_ms += dt_msec out_neurons_firing = [self.neurons[i].fired for i in self.outputs] out_neurons_voltages = [self.neurons[i].v for i in self.outputs] out_neurons_recovery = [self.neurons[i].u for i in self.outputs] ret_keys = {'fired': out_neurons_firing, 'voltages': out_neurons_voltages, 'recovery': out_neurons_recovery} if isinstance(ret, list): return [ret_keys[k] for k in ret if k in ret_keys] elif isinstance(ret, str): if ret == 'all': return [out_neurons_firing, out_neurons_voltages, out_neurons_recovery] return ret_keys.get(ret, []) else: return out_neurons_firing
[docs] def advance_event_driven(self, dt_msec: float, method: str = 'LSODA', ret: Optional[Union[List[str], str]] = None) -> Union[List[float], List[List[float]]]: """ Advances the simulation by at most dt_msec using a true event-driven approach. The simulation advances to the time of the earliest spike event in the network, or by the full dt_msec if no spikes occur in that interval. This ensures that spike timing is captured with high precision. Args: dt_msec (float): The maximum time step to advance in milliseconds. method (str): The integration method to use. If None, uses manually written RK4, otherwise defaults to SciPy's LSODA. If specified, uses SciPy's solve_ivp with the given method. Valid methods are listed in the SciPy documentation for solve_ivp (https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html). And here is a summary of available methods: - 'RK45' (Default): An adaptive Runge-Kutta method of order 5(4). It's a great general-purpose choice and a good starting point. - 'RK23': A lower-order adaptive Runge-Kutta method. Faster but less accurate than RK45. - 'DOP853': A high-order (8th) adaptive Runge-Kutta method for when you need very high precision. - 'LSODA': This is a particularly important one for spiking neurons. It's a solver that automatically switches between methods for non-stiff and stiff problems. A "stiff" ODE is one where some parts of the solution change very rapidly while others change slowly (like the membrane potential during a spike!). LSODA is often very efficient and stable for these kinds of systems. - 'BDF' and 'Radau': Other excellent methods for stiff problems. ret (list(str) or str or None): Specifies what to return. If a list of strings, returns a list of lists, where each inner list corresponds to the requested attribute for all output neurons. If a single string, returns a list corresponding to the requested attribute for all output neurons. If None, returns a list of firing states for all output neurons. Valid strings are: 'fired' - returns the firing states (1.0 if fired, 0.0 otherwise) 'voltages' - returns the membrane potentials (in millivolts) 'recovery' - returns the recovery variables 'all' - returns a list of lists: [fired states, voltages, recovery variables] Returns: A list or a list of lists as specified by the 'ret' parameter, representing the state of the output neurons after the time step. Raises: ValueError: If an invalid integration method is specified. """ if method not in ['RK45', 'RK23', 'DOP853', 'LSODA', 'BDF', 'Radau']: raise ValueError(f"Invalid integration method '{method}'. Valid methods are 'RK45', 'RK23', 'DOP853', 'LSODA', 'BDF', 'Radau'.") # Calculate input currents for all neurons based on the current state. for n in self.neurons.values(): n.current = n.bias for i, w in n.inputs: ineuron = self.neurons.get(i) if ineuron is not None: # Input from another neuron is based on its 'fired' state from the previous step. ivalue = ineuron.fired else: # Input from an external source. ivalue = self.input_values.get(i, 0.0) n.current += ivalue * w # Poll all neurons to get their solutions and potential spike times. # This step does NOT change the state of any neuron. solutions = {} event_times = {} for nid, n in self.neurons.items(): sol, spike_time = n.solve_for_interval(dt_msec, method=method) if sol: solutions[nid] = sol if spike_time is not None: event_times[nid] = spike_time # Determine the actual time to advance the simulation. # This is the time of the earliest spike, or the full dt_msec if no spikes occur. if not event_times: time_to_advance = dt_msec else: min_event_time = min(event_times.values()) time_to_advance = min(min_event_time, dt_msec) # Update all neuron states to the new global time. # We use the 'dense_output' from the solution to find the precise state at 'time_to_advance'. for nid, n in self.neurons.items(): if nid in solutions: new_state = solutions[nid].sol(time_to_advance) n.v, n.u = new_state # Reset the 'fired' flag for all neurons before processing the new spikes. n.fired = 0.0 # Process the spike(s): reset the neuron(s) that fired at this exact moment. for nid, t in event_times.items(): # Use a small tolerance for floating point comparison. if abs(t - time_to_advance) < 1e-9: n = self.neurons[nid] n.fired = 1.0 n.v = n.c n.u += n.d # FUTURE TODO: This is where you would add the spike to a delivery # queue if implementing axonal propagation delays. # Advance the global clock. self.time_ms += time_to_advance # Return the requested output values, consistent with the other advance method. out_neurons_firing = [self.neurons[i].fired for i in self.outputs] out_neurons_voltages = [self.neurons[i].v for i in self.outputs] out_neurons_recovery = [self.neurons[i].u for i in self.outputs] ret_keys = {'fired': out_neurons_firing, 'voltages': out_neurons_voltages, 'recovery': out_neurons_recovery} if isinstance(ret, list): return [ret_keys[k] for k in ret if k in ret_keys] elif isinstance(ret, str): if ret == 'all': return [out_neurons_firing, out_neurons_voltages, out_neurons_recovery] return ret_keys.get(ret, []) else: return out_neurons_firing
@property def state(self) -> Dict[int, Tuple[float, float, float]]: """ Returns the current state of the network as a dictionary mapping neuron IDs to their (v, u, fired) state. """ return {nid: (n.v, n.u, n.fired) for nid, n in self.neurons.items()} @property def fired(self) -> List[float]: """Returns a list of firing states for all output neurons.""" return [self.neurons[i].fired for i in self.outputs] @property def voltages(self) -> List[float]: """Returns a list of voltage states for all output neurons.""" return [self.neurons[i].v for i in self.outputs] @property def recovery(self) -> List[float]: """Returns a list of recovery variable states for all output neurons.""" return [self.neurons[i].u for i in self.outputs]
[docs] @staticmethod def create(genome, config): """ Receives a genome and returns its phenotype (a neural network). """ # Get the genome configuration. genome_config = config.genome_config # Get the set of node keys that are required for the output. required = required_for_output(genome_config.input_keys, genome_config.output_keys, genome.connections) # Gather inputs and expressed connections. node_inputs = {} for cg in genome.connections.values(): # Skip disabled connections. if not cg.enabled: continue # i, o = cg.key if o not in required and i not in required: continue if o not in node_inputs: node_inputs[o] = [(i, cg.weight)] else: node_inputs[o].append((i, cg.weight)) neurons = {} for node_key in required: ng = genome.nodes[node_key] inputs = node_inputs.get(node_key, []) neurons[node_key] = IZNeuron(ng.bias, ng.a, ng.b, ng.c, ng.d, inputs) genome_config = config.genome_config return IZNN(neurons, genome_config.input_keys, genome_config.output_keys)