Source code for ctneat.ctrnn.ctrnn_visualize

"""
This file contains functions used to visualize the CTRNN model.
"""
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import numpy as np
import graphviz
from sklearn.decomposition import PCA
from typing import Optional, Union
import warnings
from sklearn.preprocessing import StandardScaler

[docs] def draw_ctrnn_net(node_list: list, node_inputs: dict, iznn: Optional[bool] = False, dir_name: Optional[str] = 'ctneat_outputs', file_name: Optional[str] = None) -> None: """ This function draws the CTRNN network structure. Args: node_list: A list of node IDs in the network. node_inputs: A dictionary where keys are node IDs and values are lists of input connections for each node. (I.e. each list contains tuples of (input_node_id, weight) for the node with the corresponding ID) iznn: Whether the network is an Izhikevich spiking neural network (IZNN). If True, the function gives correct labels. dir_name: Optional directory name to save the output file. If None, saves in the current directory. file_name: Optional file name to save the output file. If None, defaults to 'ctrnn_network'. Returns: None """ dot = graphviz.Digraph() for node_id in node_list: dot.node(str(node_id)) for node_id, inputs in node_inputs.items(): for input_node_id, weight in inputs: dot.edge(str(input_node_id), str(node_id), label=str(weight)) dot.render(file_name or 'iznn_network' if iznn else 'ctrnn_network', format='png', cleanup=True, directory=dir_name or '.')
[docs] def draw_ctrnn_dynamics(states: np.ndarray, normalize: bool = False, uniform_time: bool = True, times: Optional[Union[np.ndarray, list]] = None, iznn: Optional[bool] = False, save: bool = False, show: bool = True, dir_name: Optional[str] = 'ctneat_outputs', file_name: Optional[str] = None) -> None: """ This function draws the dynamics of the CTRNN over time. Args: states: A 2D numpy array where each row corresponds to the state of the network at a given time step, and each column corresponds to a specific node's state. normalize: Whether to normalize the states before plotting. If True, each node's state is scaled to [0, 1]. uniform_time: Whether the time steps are uniform. If True, the function generates a uniform time array. times: If uniform_time is False, a list or numpy array of time points corresponding to each row in states. If uniform_time is True, this parameter is ignored. iznn: Whether the network is an Izhikevich spiking neural network (IZNN). If True, the function gives correct labels. save: Whether to save the plot as a file. If False, the plot is shown interactively. show: Whether to display the plot interactively. If False, the plot is only saved to a file if 'save' is True. dir_name: Optional directory name to save the output file. If None, saves in the current directory. file_name: Optional file name to save the output file. If None, defaults to 'ctrnn_dynamics'. Returns: None """ if uniform_time: times = np.arange(states.shape[0]) else: if times is None or len(times) != states.shape[0]: raise ValueError("Invalid times array. Must be provided and match the number of time steps in states.") if normalize: scaler = StandardScaler() states = scaler.fit_transform(states) plt.figure() plt.title(f"{'IZNN' if iznn else 'CTRNN'} Dynamics") plt.xlabel("Time") plt.ylabel("Output") plt.grid() for i in range(states.shape[1]): plt.plot(times, states[:, i], label=f"Node {i+1}") plt.legend(loc="best") if save: plt.savefig(f"{dir_name + '/' if dir_name else '.'}{file_name or 'iznn_dynamics' if iznn else 'ctrnn_dynamics'}.png") if show: plt.show()
[docs] def draw_ctrnn_trajectory(states: np.ndarray, n_components: int = 2, iznn: Optional[bool] = False, save: bool = False, show: bool = True, dir_name: Optional[str] = 'ctneat_outputs', file_name: Optional[str] = None) -> None: """ This function draws a trajectory of the CTRNN's state space. If there are more than 'n_components' nodes, the PCA is used to reduce the dimensionality to 'n_components'. Args: states: A 2D numpy array where each row corresponds to the state of the network at a given time step, and each column corresponds to a specific node's state. n_components: The number of components to reduce to (default is 2, max is 3). Note that if n_components is equal to 1 and the number of nodes is also 1, this function is equivalent to draw_ctrnn_dynamics. iznn: Whether the network is an Izhikevich spiking neural network (IZNN). If True, the function gives correct labels. save: Whether to save the plot as a file. If False, the plot is shown interactively. show: Whether to display the plot interactively. If False, the plot is only saved to a file if 'save' is True. dir_name: Optional directory name to save the output file. If None, saves in the current directory. file_name: Optional file name to save the output file. If None, defaults to 'ctrnn_face_portrait'. Returns: None Raises: ValueError: If n_components is not 1, 2, or 3. """ if n_components < 1 or n_components > 3: raise ValueError("Invalid number of components. Must be 1, 2, or 3.") if states.shape[1] > n_components: pca = PCA(n_components=n_components) reduced_states = pca.fit_transform(states) elif states.shape[1] < n_components: reduced_states = states.flatten() n_components = states.shape[1] warnings.warn(f"Cannot reduce to {n_components} components from {states.shape[1]} nodes. Falling back to {states.shape[1]}D.", UserWarning) else: reduced_states = states if n_components == 1: plt.figure(figsize=(10,5)) plt.title(f"{'IZNN' if iznn else 'CTRNN'} Trajectory ({n_components}D)") plt.xlabel("Time") plt.ylabel("Principal Component 1") plt.grid() plt.plot(range(len(reduced_states)), reduced_states, color='b', marker='o', markersize=3) # denoting the start and end points plt.text(range(len(reduced_states))[0], reduced_states[0], 'Start', fontsize=12, color='green') plt.text(range(len(reduced_states))[-1], reduced_states[-1], 'End', fontsize=12, color='red') elif n_components == 2: plt.figure(figsize=(10,10)) plt.title(f"{'IZNN' if iznn else 'CTRNN'} Trajectory ({n_components}D)") plt.xlabel("Principal Component 1") plt.ylabel("Principal Component 2") plt.grid() plt.plot(reduced_states[:, 0], reduced_states[:, 1], color='b', marker='o', markersize=3) # denoting the start and end points plt.text(reduced_states[0][0], reduced_states[0][1], 'Start', fontsize=12, color='green') plt.text(reduced_states[-1][0], reduced_states[-1][1], 'End', fontsize=12, color='red') elif n_components == 3: plt.figure(figsize=(10,10)) ax = plt.axes(projection='3d') ax.grid() ax.set_title(f"{'IZNN' if iznn else 'CTRNN'} Trajectory ({n_components}D)") ax.set_xlabel("Principal Component 1") ax.set_ylabel("Principal Component 2") ax.set_zlabel("Principal Component 3") ax.plot3D(reduced_states[:, 0], reduced_states[:, 1], reduced_states[:, 2], color='b', marker='o', markersize=3) # denoting the start and end points ax.text(reduced_states[0][0], reduced_states[0][1], reduced_states[0][2], 'Start', fontsize=12, color='green') ax.text(reduced_states[-1][0], reduced_states[-1][1], reduced_states[-1][2], 'End', fontsize=12, color='red') if save: plt.savefig(f"{dir_name + '/' if dir_name else ''}{file_name or 'iznn_trajectory' if iznn else 'ctrnn_trajectory'}.png") if show: plt.show()