"""
This module contains the class which is used to discretize continuous network dynamics.
"""
import numpy as np
from sklearn.cluster import KMeans, DBSCAN
from scipy.optimize import linear_sum_assignment
from typing import Callable, Optional, Tuple, Union, List, Dict, Iterable, Collection
from ctneat.iznn.dynamic_attractors import resample_data, dynamic_attractors_pipeline
[docs]
class Discretizer:
    """
    This class is used to discretize continuous network dynamics.

    The pipeline (see `discretize`) runs the network on every input, reduces each
    run to an attractor state, clusters those states, and finally assigns an
    expected output value to every cluster.
    """

    def __init__(self, network, inputs: Collection[Collection], outputs: Collection[Union[Collection, int, float]],
                 max_time: float = 20.0, dt: float = 0.05,
                 force_cluster_num: bool = False, epsilon: float = 0.5, min_samples: int = 1,
                 random_state: Optional[int] = 3, verbose: bool = False, printouts: bool = True,
                 advance_args: Optional[Dict] = None, resample_data_args: Optional[Dict] = None,
                 dynamics_args: Optional[Dict] = None, kmeans_args: Optional[Dict] = None,
                 dbscan_args: Optional[Dict] = None):
        """
        Initializes the Discretizer with the given parameters.

        Args:
            network: The continuous network to be discretized. The type is not strictly defined here,
                however `run_network` calls `reset()`, `set_inputs(...)` and `advance(...)` on it and
                reads its `time_ms`, `voltages` and `fired` attributes.
            inputs (List[Union[Tuple, List]]): List of input vectors to the network.
            outputs (List[Union[int, float]]): List of expected output values corresponding to the inputs.
                The values are collected into a set and sorted, so they must be hashable and
                mutually orderable.
            max_time (float): Maximum time to run the network for each input (in ms).
            dt (float): Time step for the simulation (in ms).
            force_cluster_num (bool): If True, forces KMeans clustering with number of clusters equal to
                number of unique outputs.
            epsilon (float): Epsilon parameter for DBSCAN clustering. This is the maximum distance between
                two samples for one to be considered as in the neighborhood of the other.
            min_samples (int): Minimum samples parameter for DBSCAN clustering. This is the number of samples
                in a neighborhood for a point to be considered as a core point.
            random_state (Optional[int]): Random state for reproducibility. If None, randomness is not controlled.
            verbose (bool): If True, prints detailed logs during processing.
            printouts (bool): If True, prints summary information after processing.
            advance_args (Optional[Dict]): Additional arguments for the network's advance method.
                For reference, see the `advance` method of the network being used.
            resample_data_args (Optional[Dict]): Additional arguments for the resample_data function.
                For reference, see the `resample_data` function in ctneat.iznn.dynamic_attractors.
            dynamics_args (Optional[Dict]): Additional arguments for the network's dynamics method.
                For reference, see the `dynamic_attractors_pipeline` function in ctneat.iznn.dynamic_attractors.
            kmeans_args (Optional[Dict]): Additional arguments for KMeans clustering.
                For reference, see sklearn.cluster.KMeans.
            dbscan_args (Optional[Dict]): Additional arguments for DBSCAN clustering.
                For reference, see sklearn.cluster.DBSCAN.

        Note:
            The option dictionaries are copied before the explicitly handled keys are
            removed from them, so the dictionaries owned by the caller are never modified.
        """
        self.network = network
        self.inputs = inputs
        self.outputs = outputs
        self.max_time = max_time
        self.dt = dt
        self.force_cluster_num = force_cluster_num
        self.epsilon = epsilon
        self.min_samples = min_samples
        self.random_state = random_state
        self.verbose = verbose
        self.printouts = printouts

        # Some options are passed to the callees as explicit keyword arguments, so they
        # are popped out of (a copy of) the option dict; the remainder is forwarded with **.
        self.advance_args = dict(advance_args) if advance_args is not None else {}
        self._advance_args_ret = self.advance_args.pop('ret', ['voltages', 'fired'])

        # similarly for resample_data_args
        self.resample_data_args = dict(resample_data_args) if resample_data_args is not None else {}
        self._dt_uniform_ms = self.resample_data_args.pop('dt_uniform_ms', 'min')
        self._using_simulation = self.resample_data_args.pop('using_simulation', True)
        self._events = self.resample_data_args.pop('events', False)

        # similarly for dynamics_args
        self.dynamics_args = dict(dynamics_args) if dynamics_args is not None else {}
        self._variable_burn_in = self.dynamics_args.pop('variable_burn_in', True)

        # clustering options are forwarded untouched
        self.kmeans_args = dict(kmeans_args) if kmeans_args is not None else {}
        self.dbscan_args = dict(dbscan_args) if dbscan_args is not None else {}

        # unique outputs in a fixed order: non-strings first, then strings, each group sorted
        self.unique_outputs = sorted(set(self.outputs), key=lambda x: (isinstance(x, str), x))
        self.num_unique_outputs = len(self.unique_outputs)
        if self.verbose:
            print(f"Unique outputs identified: {self.unique_outputs}")

        # placeholder for network attractors produced by each input (input index -> attractor state)
        self.network_attractors = {}

    def run_network(self):
        """
        Run the network for each input and the specified max_time.

        Network dynamics is measured and used to find the attractor state which is stored in
        self.network_attractors under the index of the input.
        If an attractor state cannot be found, None is stored for that input.
        """
        for i, input_vector in enumerate(self.inputs):
            if self.verbose:
                print(f"Running network for input {i+1}/{len(self.inputs)}: {input_vector}")
            self.network.reset()
            self.network.set_inputs(input_vector)

            # the histories start with the state right after reset / input assignment
            times = [self.network.time_ms]
            voltage_history = [self.network.voltages]
            fired_history = [self.network.fired]
            while self.network.time_ms < self.max_time:
                # Shorten the step so the run does not overshoot max_time; the 0.0001 floor
                # keeps the step strictly positive.
                step = min(self.dt, max(self.max_time - self.network.time_ms, 0.0001))
                voltages, fired = self.network.advance(dt=step, ret=self._advance_args_ret, **self.advance_args)
                times.append(self.network.time_ms)
                voltage_history.append(voltages)
                fired_history.append(fired)
            times = np.array(times)
            voltage_history = np.array(voltage_history)
            fired_history = np.array(fired_history)

            # resample to uniform time steps; any remaining resample_data_args are forwarded,
            # matching how advance_args and dynamics_args are handled
            uniform_time_steps, uniform_voltage_history = resample_data(
                times, voltage_history, dt_uniform_ms=self._dt_uniform_ms,
                using_simulation=self._using_simulation, net=self.network,
                events=self._events, ret='voltages', **self.resample_data_args)
            _, uniform_fired_history = resample_data(
                times, fired_history, dt_uniform_ms=self._dt_uniform_ms,
                using_simulation=self._using_simulation, net=self.network,
                events=self._events, ret='fired', **self.resample_data_args)

            # analyze dynamics to find attractor state
            attractor_state = dynamic_attractors_pipeline(
                voltage_history=uniform_voltage_history, fired_history=uniform_fired_history,
                times_np=uniform_time_steps, variable_burn_in=self._variable_burn_in,
                fingerprint_vec=True, verbose=self.verbose, printouts=self.printouts,
                **self.dynamics_args)
            self.network_attractors[i] = attractor_state
            if self.printouts:
                print(f"Attractor state for input {i+1}: {attractor_state}")
        if self.printouts:
            print("Network run complete. Attractor states recorded.")

    def cluster_attractors(self) -> Dict[int, int]:
        """
        Cluster the attractor states using either KMeans or DBSCAN.

        If force_cluster_num is True, KMeans is used with number of clusters equal to number of
        unique outputs (capped at the number of available attractor states, since KMeans cannot
        fit more clusters than samples). Otherwise, DBSCAN is used.

        Returns:
            A dictionary mapping input index to cluster label. Inputs for which no attractor
            state is stored (missing or None) are left out of the dictionary. An empty
            dictionary is returned when no attractor state is available at all.
        """
        # Remember WHICH inputs have an attractor so the labels can be mapped back to the
        # right input index even when some attractors are missing.
        valid_indices = [i for i in range(len(self.inputs)) if self.network_attractors.get(i) is not None]
        if not valid_indices:
            return {}
        attractor_states = np.array([np.array(self.network_attractors[i]) for i in valid_indices])

        if self.force_cluster_num:
            if self.verbose:
                print("Clustering attractors using KMeans...")
            n_clusters = min(self.num_unique_outputs, len(valid_indices))
            kmeans = KMeans(n_clusters=n_clusters, random_state=self.random_state, verbose=int(self.verbose),
                            **self.kmeans_args)
            cluster_labels = kmeans.fit_predict(attractor_states)
        else:
            if self.verbose:
                print("Clustering attractors using DBSCAN...")
            dbscan = DBSCAN(eps=self.epsilon, min_samples=self.min_samples, **self.dbscan_args)
            cluster_labels = dbscan.fit_predict(attractor_states)
        if self.printouts:
            print(f"Cluster labels assigned: {cluster_labels}")

        # map input index to cluster label (plain ints rather than numpy integers)
        return {index: int(label) for index, label in zip(valid_indices, cluster_labels)}

    def map_clusters_to_outputs(self, input_to_cluster: Dict[int, int]) -> Dict[int, Union[int, float]]:
        """
        Map the clusters to the expected outputs using the Hungarian algorithm to minimize total mismatch.

        A (cluster x unique output) matrix counts how often each cluster coincides with each
        expected output; the assignment maximizing the total count is selected. Negative labels
        (DBSCAN noise) are ignored. When there are more clusters than unique outputs, only as many
        clusters as there are unique outputs receive an output value.

        Args:
            input_to_cluster (Dict[int, int]): A dictionary mapping input index to cluster label.

        Returns:
            A dictionary mapping cluster label to expected output value.
        """
        if not input_to_cluster:
            return {}

        # One row per possible cluster label. At least num_unique_outputs rows are kept so the
        # result is unchanged whenever all labels fit into the square matrix; more rows are
        # added when the clustering produced more clusters than there are unique outputs.
        real_labels = [int(label) for label in input_to_cluster.values() if label >= 0]
        num_rows = max([self.num_unique_outputs] + [label + 1 for label in real_labels])
        output_position = {value: position for position, value in enumerate(self.unique_outputs)}

        # count co-occurrences of (cluster label, expected output)
        cost_matrix = np.zeros((num_rows, self.num_unique_outputs))
        for input_index, cluster_label in input_to_cluster.items():
            if cluster_label < 0:  # ignore noise points from DBSCAN
                continue
            cost_matrix[cluster_label, output_position[self.outputs[input_index]]] += 1

        # convert counts to costs (high agreement -> low cost)
        cost_matrix = np.max(cost_matrix) - cost_matrix
        # apply Hungarian algorithm
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        # create mapping from cluster label to output value
        cluster_to_output = {int(row): self.unique_outputs[col] for row, col in zip(row_ind, col_ind)}
        if self.verbose:
            print(f"Cluster to output mapping: {cluster_to_output}")
        return cluster_to_output

    def discretize(self) -> Dict[int, Union[int, float, None]]:
        """
        Run the full discretization pipeline: run the network, cluster attractors, and map clusters to outputs.

        Returns:
            A dictionary mapping input index to predicted output value. The value is None for
            inputs without a cluster, for noise points, and for clusters that were not
            assigned an output.
        """
        self.run_network()
        input_to_cluster = self.cluster_attractors()
        cluster_to_output = self.map_clusters_to_outputs(input_to_cluster)
        # inputs without a cluster fall back to the noise label -1, which never has an output
        input_to_output = {i: cluster_to_output.get(input_to_cluster.get(i, -1), None)
                           for i in range(len(self.inputs))}
        if self.printouts:
            print("Discretization complete.")
        return input_to_output