A Differentiable Approach to Multi-scale Brain Modeling
Chaoming Wang (1), Muyang Lyu (2, 3), Tianqiu Zhang (2, 3), Sichao He (2, 3), Si Wu (1, 2, 3, 4, 5, 6)
Published at the Differentiable Almost Everything Workshop at the International Conference on Machine Learning, Vienna, Austria. July 2024. Copyright 2024 by the author(s).
Abstract
We present a multi-scale differentiable brain modeling workflow built on BrainPy \citep{wang2023brainpy,wang2024brainpy}, a differentiable brain simulator that combines accurate brain simulation with powerful gradient-based optimization. We leverage this capability of BrainPy across brain scales. At the single-neuron level, we implement differentiable neuron models and use gradient methods to fit them to electrophysiological data. At the network level, we incorporate connectomic data to construct biologically constrained network models. Finally, to replicate animal behavior, we train these models on cognitive tasks using gradient-based learning rules. Experiments show that our approach fits generalized leaky integrate-and-fire and Hodgkin-Huxley single-neuron models with the best or near-best accuracy among the tested methods at competitive speed. Additionally, training a biologically informed network of excitatory and inhibitory spiking neurons on a working memory task reproduces observed neural activity and synaptic weight distributions. Overall, our differentiable multi-scale simulation approach offers a promising tool for bridging neuroscience data across electrophysiological, anatomical, and behavioral scales.
1 Introduction
Modeling the entire human brain in a computer has been a long-standing dream \citep{amunts2024coming}. However, it is an immense challenge: constructing whole-brain models that coherently link multiple spatial scales is hampered by insufficient biological data \citep{d2022quest}. Despite numerous efforts to record and measure the brain, our observations remain partial, and the information gathered from experimental recordings falls far short of what is needed to simulate a realistic brain \citep{scheffer2021connectome}. For instance, at the single-neuron level, neurons exhibit diverse firing patterns, while their underlying ionic channels are difficult to identify. Automatic neuron fitting has therefore become a valuable tool for bridging models and recorded neuronal data, because it estimates model parameters directly from recordings \citep{rossant2011fitting,rossant2010automatic}. At the network level, neural activity has been recorded with magnetoencephalography (MEG), electroencephalography (EEG), and functional magnetic resonance imaging (fMRI) under diverse conditions. However, even with connectome data now available \citep{shiu2023leaky,dorkenwald2023neuronal,winding2023connectome}, we still do not fully understand how the underlying neuronal circuits produce these activities. At the behavioral level, network models used in brain simulation still struggle to reproduce how animals perform cognitive tasks \citep{potjans2014cell,schmidt2018multi,billeh2020systematic}.
Consequently, accurate multi-scale brain modeling requires highly efficient optimization methods that can integrate and reconcile data across scales, from individual neurons to large-scale neural networks and cognitive processes. However, conventional brain simulators such as NEURON \citep{awile2022modernizing}, NEST \citep{gewaltig2007nest}, and Brian2 \citep{stimberg2019brian} are difficult to use for gradient-based optimization because they operate as black boxes and are not differentiable. Without differentiability, researchers are restricted to slower, less efficient optimization techniques, or even to manual heuristic parameter searches \citep{billeh2020systematic}. The inability to use gradient-based optimization usually leads to longer computation times and poorer model fits, which impedes scaling to larger and more complex systems and exacerbates the challenges of multi-scale brain modeling.
To overcome these limitations, we need brain simulation frameworks that are natively differentiable and therefore support efficient gradient-based optimization. BrainPy \citep{wang2023brainpy,wang2024brainpy} was recently proposed as a differentiable brain simulator to bridge this gap. By bringing the fundamental features of a brain simulator, such as event-driven computation, sparse operators, numerical integrators, and a multi-scale model-building interface, into the numerical computing framework JAX \citep{frostig2018compiling}, BrainPy enables faithful brain simulation while inheriting JAX's automatic differentiation (autograd) capabilities.
Leveraging BrainPy's strengths, we propose a workflow for differentiable multi-scale brain modeling (Figure 1). The workflow uses gradient-based optimization to fit differentiable models of single neurons and synapses. We then incorporate connectomic data from neuroscience experiments to construct data-driven, biologically constrained spiking neural networks (SNNs). Finally, to replicate animal behavior, we train these biologically informed models on cognitive tasks using gradient-based online learning rules. Our experiments demonstrate the feasibility of this approach for accurate multi-scale brain modeling.
2 Methods
We first describe the design choices that make the entire workflow differentiable.
2.1 Differentiable neuron models with surrogate gradients
Biological neurons emit binary, non-differentiable spike events. The spiking operation is discontinuous: it is represented by the Heaviside function $S = H(V - V_{th})$, where $V$ is the membrane potential and $V_{th}$ the firing threshold. Its derivative is the Dirac delta function $\partial S/\partial V = \delta(V - V_{th})$, which is zero everywhere except at the threshold, so gradient-based optimization cannot be applied to SNN models directly. In practice, surrogate gradients, which replace the delta function with a smooth surrogate such as a Gaussian \citep{yin2021accurate}, linear \citep{bellec2018long}, SLayer \citep{shrestha2018slayer}, or multi-Gaussian function \citep{yin2021accurate}, have proven effective for training SNNs with gradient descent \citep{bohte2011error,Neftci2019Surrogate,bellec2020solution}. We adopt this approach in our workflow to enable gradient-based optimization, and we provide a suite of surrogate gradient functions (listed in Appendix C) to help select the most suitable function for a given task.
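A minimal JAX sketch of the surrogate-gradient idea follows. It is not BrainPy's implementation: the names `spike_fn` and `spike_count` and the sigmoid surrogate with slope `ALPHA = 4.0` are illustrative assumptions. The forward pass applies the Heaviside step to $V - V_{th}$, and `jax.custom_vjp` swaps the Dirac delta for the derivative of a sigmoid in the backward pass.

```python
import jax
import jax.numpy as jnp

ALPHA = 4.0  # slope of the sigmoid surrogate (hypothetical choice)


@jax.custom_vjp
def spike_fn(x):
    """Forward pass: Heaviside step H(x), with x = V - V_th."""
    return (x >= 0.0).astype(x.dtype)


def _spike_fwd(x):
    # Keep x as the residual so the backward pass can evaluate the surrogate there.
    return spike_fn(x), x


def _spike_bwd(x, g):
    # Replace the Dirac delta dH/dx with d/dx sigmoid(ALPHA * x).
    sig = jax.nn.sigmoid(ALPHA * x)
    return (g * ALPHA * sig * (1.0 - sig),)


spike_fn.defvjp(_spike_fwd, _spike_bwd)


def spike_count(v, v_th):
    """Number of neurons whose membrane potential v reaches the threshold v_th."""
    return jnp.sum(spike_fn(v - v_th))


v = jnp.array([-1.0, 0.5, 2.0])
print(spike_count(v, 0.0))            # two neurons spike in the forward pass
print(jax.grad(spike_count)(v, 0.0))  # smooth, non-zero surrogate gradient
```

Because only the backward rule changes, the simulated spike trains are identical to those of the exact model; the surrogate affects only the gradient.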
2.2 Event-driven differentiable synaptic operators
To mimic the brain's efficient communication, traditional brain simulators rely on custom data structures for event-driven computation and spike communication \citep{kunkel2012meeting,kunkel2014spiking,stimberg2019brian}. However, these approaches often clash with autograd systems, hindering gradient-based optimization of synaptic computations. We address this challenge with differentiable event-driven synaptic operators that are compatible with autograd frameworks. Synaptic connections are stored in the compressed sparse row (CSR) format, the event-driven operations act directly on the CSR arrays, and the operators provide both forward and backward differentiation rules so that they can be used inside differentiable computations (details and pseudo-code in Appendix D). Furthermore, BrainPy's event-driven operators are one to two orders of magnitude faster than conventional sparse and dense alternatives \citep{wang2023brainpy,wang2024brainpy}, for both the forward state computation and the backward gradient calculation.
3 Workflow for multi-scale differentiable brain modeling
Based on the differentiable neuronal and synaptic building blocks described earlier, we present a workflow for multi-scale differentiable brain modeling. This approach seamlessly integrates microscopic neuron models, mesoscopic neural circuit connectivity, and macroscopic computational tasks through gradient-based optimization algorithms (see Figure 1).
At the single-neuron and synapse level (Figure 1A), existing knowledge and experimental techniques in cellular biophysics make accurate modeling of individual neurons and synaptic currents possible \citep{koch1998methods,teeter2018generalized}. To facilitate large-scale network simulation and training, we use point neuron models, which capture diverse cellular behaviors while remaining differentiable and computationally efficient. To construct large sets of neuron models from empirical datasets quickly and conveniently, we fit them with gradient-based optimization (e.g., the L-BFGS-B algorithm \citep{liu1989limited,byrd1995limited}). Because this fitting procedure is built on JAX \citep{frostig2018compiling}, it scales and parallelizes easily through jax.vmap or jax.pmap.
At the network level (Figure 1B), we incorporate brain structure and connectome information to construct realistic brain models. The universal function approximation theorem \citep{cybenko1989approximation,leshno1993multilayer} and the Kolmogorov-Arnold representation theorem \citep{kolmogorov1961representation,braun2009constructive} suggest that neural networks with completely different connectivity can perform the same computational task. Task performance alone therefore cannot pin down the circuit that implements it, which makes connectome-constrained neural models an essential step in linking the organizational features of neuronal networks to the spectrum of cortical functions. Recent quantitative connectome databases for various animals (e.g., Drosophila \citep{winding2023connectome}, zebrafish \citep{kunst2019cellular}, macaque \citep{markov2014weighted,markov2014anatomy}, mouse \citep{oh2014mesoscale,zingg2014neural}, and marmoset \citep{majka2020open}) provide rich resources for this purpose.
At the behavioral level (Figure 1C), we use gradient-based optimization to train the data-constrained networks described above on computational tasks. While handcrafted tuning and manually engineered connectivity can implement specific functions, they fall short of producing brain-scale intelligence. Here, we optimize the unknown network parameters with deep learning techniques \citep{saxe2021if}, so that the model learns to perform tasks as an animal does. Specifically, we use the online learning method of BrainScale \citep{brainscale}, which provides an online approximation of backpropagation with low computational complexity and high training performance.
4 Training biologically-informed spiking networks on cognitive tasks
To illustrate the proposed workflow, we trained a biologically informed excitatory-inhibitory (EI) spiking network on a working memory task. The spiking neurons in our EI network follow generalized integrate-and-fire (GIF) dynamics \citep{jolivet2004generalized}, and the synapses follow an exponential synapse model. To capture the characteristic tonic spiking and adaptation \citep{staining2015allen,teeter2018generalized}, the firing pattern of each GIF neuron is optimized with the L-BFGS-B algorithm (Section 4.1). Neurons are divided into excitatory and inhibitory populations with a 4:1 E:I ratio, and the connectivity between them follows principles derived from the neocortical connectome \citep{theodoni2022structural}. We trained both the excitatory and inhibitory weights on the working memory task (Section 4.2) with the online learning framework BrainScale \citep{brainscale}. Complete details of the EI model are given in Appendix F.
4.1 Neuron fitting
Our neuron fitting procedure is depicted in Figure 2. The experimental data come from current-clamp recordings in which the injected currents mimic synaptic activity observed in vivo (Figure 2A). The neuron is defined in BrainPy as a differentiable model, and a loss function quantifies the discrepancy between the model's predictions and the experimental data: the mean square error can be used to fit the membrane potential, and the gamma factor \citep{jolivet2008benchmark} to fit spike trains (Appendix E). Gradients of this loss with respect to the model parameters are then used to update the parameter values. By iterating gradient estimation and parameter updates, the procedure seeks the parameters that best match the experimental recordings (Figure 2B).
We first test our fitting method on GIF neuron models, which should capture characteristic cortical firing patterns such as spike-frequency adaptation, phasic spiking, and rebound spiking. We compare the accuracy and speed of our gradient-based method with conventional optimization algorithms: differential evolution (DE), DE with two-point crossover (TwoPointsDE), and particle swarm optimization (PSO) from Nevergrad \citep{bennet2021nevergrad}, as well as Bayesian optimization from scikit-optimize \citep{louppe2017bayesian}. L-BFGS-B and Bayesian optimization achieve the best fits (Figure S7), whereas DE, TwoPointsDE, and PSO are faster (Table S2). Bayesian optimization fits well but converges slowly. Overall, the gradient-based L-BFGS-B method offers a good trade-off between fitting accuracy and speed.
We further evaluate our fitting method on Hodgkin-Huxley (HH) neuron models using real electrophysiological recordings. Figure 2C and Figure S8 show the five fitting methods applied to an in vitro intracellular recording of a cortical pyramidal cell. L-BFGS-B gives the best fit (Figure 2C), reproducing the membrane potential almost perfectly with a loss close to zero (Table S3). Moreover, it runs at a speed comparable to the evolutionary algorithms and is significantly faster than Bayesian optimization (Table S3). These findings underscore the promise of differentiable optimization for neuron fitting.
4.2 Task training
Understanding how the brain performs complex computations remains a challenge. Recent advances in training recurrent neural networks have achieved high performance across various tasks, offering a promising avenue for uncovering the underlying dynamical and computational mechanisms \citep{song2016training,barak2017recurrent}. However, these networks often lack essential biological constraints, such as spike-based communication, structural connectivity, and the distinction between excitatory and inhibitory neurons. Here we train biologically constrained SNNs while explicitly respecting electrophysiological, anatomical, and structural constraints. Specifically, we build a base EI network from GIF neurons fitted to data (Section 4.1), with connectome-derived connectivity \citep{theodoni2022structural} and conductance-based synaptic dynamics \citep{vogels2005signal}, and train it on the working memory task with gradient-based optimization \citep{brainscale}.
To generate the training data, we followed the experimental setup of an evidence accumulation task \citep{morcos2016history}. The input spike trains come from four groups of neurons encoding the left stimuli, the right stimuli, the recall cue, and background noise (Figure 3A). The network was trained to count the left and right cues separately and, after a prolonged delay, to report which side received more cues. We recorded the responses of both excitatory and inhibitory neurons in the recurrent layer (Figure 3B). During the evidence accumulation period, inhibitory neurons responded strongly after each stimulus presentation, whereas excitatory neurons fired at lower rates; during the recall period, inhibitory neurons rarely spiked. We also examined the membrane potentials of all neurons (Figure 3C and D). Unlike the current-based synapse models commonly used in deep learning applications, our conductance-based synapses keep the membrane potential between the inhibitory and excitatory reversal potentials, which removes the need for voltage regularization. Finally, we analyzed the synaptic weight distributions before and after training (Figure 3E and F). Excitatory and inhibitory weights were initialized as the absolute values of normally distributed samples (Figure 3E). After training, the weight distributions resembled those observed in biological measurements: excitatory weights followed the tail of a Gaussian distribution \citep{barbour2007can}, while inhibitory weights followed a log-normal distribution \citep{loewenstein2011multiplicative,buzsaki2014log}.
5 Conclusion and discussion
We proposed a novel workflow for differentiable multi-scale brain modeling by integrating various levels of information and constraints to build brain models that can reproduce cognitive behaviors observed in humans or animals. We demonstrated this workflow by training a biologically informed GIF network to accomplish an evidence accumulation task. Although the current illustration utilizes a network with hundreds of neurons, the online learning algorithms employed are readily scalable to much larger models (refer to Appendix G for scalability analysis). Overall, our proposed differentiable approach has the potential to accelerate progress in developing accurate and biologically plausible multi-scale brain models, ultimately leading to a deeper understanding of the brain.
However, many important challenges remain to be addressed in the future (see Appendix H for details). These challenges include balancing data quality and availability with model realism, determining the appropriate granularity for simplifying and approximating biological processes within the model, and ensuring the interpretability and theoretical grounding of the derived models. Additionally, addressing the computational efficiency of handling large-scale networks with high-dimensional parameter spaces is crucial.
Appendix A Software and Data
The in vitro intracellular recording of a cortical pyramidal cell is available in brian2modelfitting \citep{teska2020brian2modelfitting}. BrainPy is publicly available on GitHub at https://github.com/brainpy/BrainPy. BrainScale \citep{brainscale} is currently under review and temporarily unavailable; it is expected to be released in the future. Additional packages in the BrainPy ecosystem are available through the BrainPy GitHub organization at https://github.com/brainpy. The code needed to reproduce the results in this paper is available at https://github.com/chaoming0625/differentiable-brain-modeling-workflow.
Appendix B Acknowledgements
This work was supported by Science and Technology Innovation 2030-Brain Science and Brain-inspired Intelligence Project (No. 2021ZD0200204).
Appendix C Surrogate gradient functions
In recent years, spiking neural networks (SNNs) have attracted attention for their potential advantages in energy efficiency, fault tolerance, and biological plausibility. However, training SNNs with standard gradient descent is challenging because their activation function is discontinuous and its gradient is zero almost everywhere. A common remedy is to replace the derivative of the non-differentiable spiking function with a surrogate gradient function \citep{Neftci2019Surrogate}: a smooth approximation of the derivative of the activation function that makes gradient-based learning algorithms applicable to SNNs. The BrainPy library provides a variety of surrogate gradient functions that differ in properties such as smoothness, boundedness, and biological plausibility. They are listed in Table S1, and an example is shown in Figure S4.
![Figure S4: example of a surrogate gradient function](https://cdn.statically.io/img/arxiv.org/extracted/5701757/figs/surrogate_gradient_funcs.png)
In practical applications, users can employ these surrogate gradient functions to determine whether a spike is generated at the current time step, as demonstrated in the following Python code:
By utilizing the appropriate surrogate gradient function, the code above allows users to assess whether a spike occurs based on the comparison between the membrane potential (V) and the threshold (V_th).
Table S1: Surrogate gradient functions provided by BrainPy.

Surrogate Gradient Function | Implementation |
---|---|
Sigmoid function | brainpy.math.surrogate.sigmoid |
Piecewise quadratic function \citep{Esser2016Convolutional,Yujie2018Spatio} | brainpy.math.surrogate.piecewise_quadratic |
Piecewise exponential function \citep{Neftci2019Surrogate} | brainpy.math.surrogate.piecewise_exp |
Soft sign function | brainpy.math.surrogate.soft_sign |
Arctan function | brainpy.math.surrogate.arctan |
Nonzero sign log function | brainpy.math.surrogate.nonzero_sign_log |
ERF function \citep{Esser2015Backpropagation,Yujie2018Spatio,Yin2020Effective} | brainpy.math.surrogate.erf |
Piecewise leaky ReLU function \citep{Shihui2018Algorithm,Yujie2018Spatio} | brainpy.math.surrogate.piecewise_leaky_relu |
Square-wave Fourier series | brainpy.math.surrogate.squarewave_fourier_series |
S2NN function \citep{suetake2023s3nn} | brainpy.math.surrogate.s2nn |
q-PseudoSpike function \citep{herranzcelotti2022surrogate} | brainpy.math.surrogate.q_pseudo_spike |
Leaky ReLU function | brainpy.math.surrogate.leaky_relu |
Log-tailed ReLU function \citep{Zhaowei2017Deep} | brainpy.math.surrogate.log_tailed_relu |
ReLU function \citep{Neftci2019Surrogate} | brainpy.math.surrogate.relu_grad |
Gaussian function \citep{yin2021accurate} | brainpy.math.surrogate.gaussian_grad |
Multi-Gaussian function \citep{yin2021accurate} | brainpy.math.surrogate.multi_gaussian_grad |
Inverse-square function | brainpy.math.surrogate.inv_square_grad |
SLayer function \citep{shrestha2018slayer} | brainpy.math.surrogate.slayer_grad |
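To make the differences between entries concrete, the sketch below writes out three common surrogate derivatives (sigmoid, arctan, Gaussian) in plain JAX and plugs any of them into a Heaviside spike function through `jax.custom_vjp`. The parameterizations (`alpha`, `sigma`) and the function names are assumptions chosen for illustration; they may not match the defaults or signatures of `brainpy.math.surrogate`. Each surrogate integrates to one, as the Dirac delta it replaces does.

```python
import math

import jax
import jax.numpy as jnp


def sigmoid_grad(x, alpha=4.0):
    # d/dx sigmoid(alpha * x)
    s = jax.nn.sigmoid(alpha * x)
    return alpha * s * (1.0 - s)


def arctan_grad(x, alpha=2.0):
    # d/dx [arctan(pi/2 * alpha * x) / pi + 1/2]
    return (alpha / 2.0) / (1.0 + (math.pi / 2.0 * alpha * x) ** 2)


def gaussian_grad(x, sigma=0.5):
    # Zero-mean Gaussian density with standard deviation sigma.
    return jnp.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def make_spike_fn(surrogate):
    """Build a Heaviside spike function whose backward pass uses `surrogate`."""
    @jax.custom_vjp
    def spike(x):
        return (x >= 0.0).astype(x.dtype)

    def fwd(x):
        return spike(x), x          # save x for the backward pass

    def bwd(x, g):
        return (g * surrogate(x),)  # surrogate replaces the Dirac delta

    spike.defvjp(fwd, bwd)
    return spike


spike = make_spike_fn(gaussian_grad)
v_minus_th = jnp.array([-0.3, 0.0, 0.8])
print(spike(v_minus_th))                                        # binary spikes
print(jax.grad(lambda u: jnp.sum(spike(u)))(v_minus_th))        # Gaussian surrogate
```

Swapping `gaussian_grad` for `sigmoid_grad` or `arctan_grad` changes only how widely, and how heavily, gradient flows around the threshold; the forward spikes are unchanged.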
Appendix D Event-driven synaptic operators
Synaptic computation usually requires an event-driven matrix-vector multiplication $\mathbf{y} = \mathbf{W}^{\top}\mathbf{s}$, where $\mathbf{s}$ is the vector of presynaptic spikes, $\mathbf{W}$ the synaptic connection matrix (rows index presynaptic and columns postsynaptic neurons), and $\mathbf{y}$ the postsynaptic current. The operation exploits the event property of the input vector $\mathbf{s}$ to compute the product sparsely and efficiently. Instead of multiplying the entire matrix by $\mathbf{s}$, which is wasteful when $\mathbf{s}$ has many zero elements, BrainPy's event-driven matrix-vector multiplication only performs multiplications for the non-zero elements of the vector, called events. This reduces the number of operations and memory accesses and improves the running performance of the matrix-vector multiplication.
Specifically, we implement event-driven operators on arrays in the compressed sparse row (CSR) format and provide both forward and backward differentiation rules. The CSR format represents the synaptic connectivity between pre- and postsynaptic neuron populations with three arrays, <val, col_ind, row_ptr>: val stores the non-zero synaptic weights, col_ind stores the postsynaptic indices of the corresponding non-zero weights, and row_ptr stores the starting index of each presynaptic neuron in the val and col_ind arrays.
To perform the event-driven linear transformation $\mathbf{y} = \mathbf{W}^{\top}\mathbf{s}$, where $\mathbf{W}$ is the CSR-formatted connectivity, the pseudo-code is implemented as:
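A minimal NumPy sketch of this event-driven loop, assuming binary spikes and the CSR layout described above (one row per presynaptic neuron). The function names are hypothetical, and a pure-Python loop stands in for BrainPy's compiled kernels.

```python
import numpy as np


def dense_to_csr(w):
    """Convert a dense (n_pre, n_post) weight matrix to CSR arrays <val, col_ind, row_ptr>."""
    rows, cols = np.nonzero(w)  # np.nonzero returns entries in row-major order, as CSR needs
    counts = np.bincount(rows, minlength=w.shape[0])
    row_ptr = np.concatenate([[0], np.cumsum(counts)])
    return w[rows, cols], cols, row_ptr


def event_csr_matvec(val, col_ind, row_ptr, spikes, n_post):
    """Event-driven y = W^T s for binary spikes s and CSR-stored W (rows = presynaptic)."""
    y = np.zeros(n_post, dtype=val.dtype)
    for i in np.flatnonzero(spikes):        # visit only the neurons that spiked (events)
        start, end = row_ptr[i], row_ptr[i + 1]
        # Add the outgoing weights of neuron i to its postsynaptic targets.
        np.add.at(y, col_ind[start:end], val[start:end])
    return y


w = np.array([[0.0, 0.5, 0.0],
              [0.2, 0.0, 0.3],
              [0.0, 0.0, 0.7]])
val, col_ind, row_ptr = dense_to_csr(w)
spikes = np.array([1, 0, 1], dtype=bool)
y = event_csr_matvec(val, col_ind, row_ptr, spikes, n_post=3)  # y == [0.0, 0.5, 0.7]
```

The cost of the loop scales with the number of spikes times the average out-degree, not with the size of $\mathbf{W}$, which is where the event-driven speedup comes from.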
To compute the gradients with respect to $\mathbf{s}$ and $\mathbf{W}$ efficiently, we implement the event-driven gradient computation as follows:
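For $\mathbf{y} = \mathbf{W}^{\top}\mathbf{s}$ and an upstream gradient $\mathbf{g} = \partial\mathcal{L}/\partial\mathbf{y}$, the chain rule gives $\partial\mathcal{L}/\partial s_i = \sum_j W_{ij} g_j$ and $\partial\mathcal{L}/\partial W_{ij} = s_i g_j$ for every stored synapse. The second gradient is non-zero only in the rows of neurons that spiked, so it can be computed in an event-driven way as well. The NumPy sketch below assumes binary spikes and the CSR layout above (rows index presynaptic neurons); the function name is hypothetical.

```python
import numpy as np


def event_csr_matvec_grads(val, col_ind, row_ptr, spikes, g):
    """Gradients of L w.r.t. the spikes s and the CSR values of W, for y = W^T s."""
    n_pre = len(row_ptr) - 1
    grad_s = np.zeros(n_pre, dtype=g.dtype)
    for i in range(n_pre):
        start, end = row_ptr[i], row_ptr[i + 1]
        # dL/ds_i = sum_j W_ij g_j: a plain CSR row-times-vector product (g is dense).
        grad_s[i] = val[start:end] @ g[col_ind[start:end]]
    grad_val = np.zeros_like(val)
    for i in np.flatnonzero(spikes):
        start, end = row_ptr[i], row_ptr[i + 1]
        # dL/dW_ij = s_i g_j with s_i = 1; rows of silent neurons stay zero.
        grad_val[start:end] = g[col_ind[start:end]]
    return grad_s, grad_val
```

Note that the gradient with respect to the spikes needs every row because $\mathbf{g}$ is dense; only the weight gradient benefits from event sparsity.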
Appendix E Loss functions for neuron fitting
E.1 Mean square error for fitting membrane potentials
To align the simulated membrane potential with the recorded one, we compute the mean squared difference between the recorded trace $V_{\mathrm{data}}$ and the simulated trace $V_{\mathrm{model}}$:

$$\mathcal{L}_{\mathrm{MSE}} = \frac{1}{T}\sum_{t=1}^{T}\left(V_{\mathrm{data}}(t) - V_{\mathrm{model}}(t)\right)^{2} \tag{1}$$

where $T$ is the total number of time steps.
E.2 Gamma factor for fitting spike trains
The gamma factor \citep{jolivet2008benchmark} measures the agreement between spike timings in the simulated and target traces. It is commonly used to evaluate spiking neuron models fitted to electrophysiological recordings of individual neurons. The gamma factor focuses on the proportion of predicted spikes that coincide with spikes in the recording; in essence, it quantifies how accurately the model reproduces the timing of the neuron's firing events. It is computed as:
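$$\Gamma = \frac{2}{1 - 2\delta r_{\mathrm{exp}}}\cdot\frac{N_{\mathrm{coinc}} - 2\delta N_{\mathrm{exp}} r_{\mathrm{exp}}}{N_{\mathrm{exp}} + N_{\mathrm{model}}} \tag{2}$$

where

- $N_{\mathrm{coinc}}$: number of coincidences, i.e., spikes matched within a coincidence window of $\pm\delta$
- $N_{\mathrm{exp}}$ and $N_{\mathrm{model}}$: number of spikes in the experimental and model spike trains
- $r_{\mathrm{exp}}$: average firing rate of the experimental train
- $2\delta N_{\mathrm{exp}} r_{\mathrm{exp}}$: expected number of coincidences with a Poisson process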
The gamma factor equals 1 when the two spike trains match perfectly and decreases for less precise matches. It reaches 0 when the number of coincidences matches the expected count from two homogeneous Poisson processes with the same firing rate.
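To turn the gamma factor into a loss function, we add a rate-correction term:

$$\mathcal{L}_{\Gamma} = 1 + 2\,\frac{\lvert r_{\mathrm{data}} - r_{\mathrm{model}}\rvert}{r_{\mathrm{data}}} - \Gamma \tag{3}$$

where $r_{\mathrm{data}}$ and $r_{\mathrm{model}}$ are the firing rates measured in the data and the model, respectively.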
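The gamma factor, $\Gamma = \frac{2}{1 - 2\delta r_{\mathrm{exp}}}\cdot\frac{N_{\mathrm{coinc}} - 2\delta N_{\mathrm{exp}} r_{\mathrm{exp}}}{N_{\mathrm{exp}} + N_{\mathrm{model}}}$, and the loss $1 + 2\lvert r_{\mathrm{data}} - r_{\mathrm{model}}\rvert / r_{\mathrm{data}} - \Gamma$ translate directly into NumPy. In this sketch, spike times are in seconds, a recorded spike counts as a coincidence if any model spike falls within $\pm\delta$ of it (a simplification; toolboxes such as brian2modelfitting use their own matching rule), and the recording is assumed to contain at least one spike. The function names are hypothetical.

```python
import numpy as np


def gamma_factor(data, model, delta, duration):
    """Coincidence factor Gamma between a recorded and a model spike train (times in s)."""
    data, model = np.asarray(data, float), np.asarray(model, float)
    n_exp, n_model = len(data), len(model)
    r_exp = n_exp / duration                    # average firing rate of the recording
    if n_model == 0:
        n_coinc = 0
    else:
        # A recorded spike is a coincidence if some model spike lies within +/- delta.
        nearest = np.min(np.abs(data[:, None] - model[None, :]), axis=1)
        n_coinc = int(np.sum(nearest <= delta))
    expected = 2.0 * delta * n_exp * r_exp      # coincidences expected by chance (Poisson)
    return (2.0 / (1.0 - 2.0 * delta * r_exp)) * (n_coinc - expected) / (n_exp + n_model)


def gamma_loss(data, model, delta, duration):
    """Rate-corrected loss: 1 + 2|r_data - r_model| / r_data - Gamma (0 for a perfect match)."""
    r_data = len(data) / duration
    r_model = len(model) / duration
    return 1.0 + 2.0 * abs(r_data - r_model) / r_data - gamma_factor(data, model, delta, duration)
```

For identical trains every spike coincides, the numerator reduces to $N(1 - 2\delta r_{\mathrm{exp}})$, and $\Gamma = 1$; the rate term then vanishes as well, so the loss is zero.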
Appendix F Excitatory and inhibitory spiking network models
F.1 Network structure
The architecture of recurrent excitatory and inhibitory spiking networks used here is shown in Figure S5, where the recurrent layer consists of excitatory and inhibitory spiking units that receive and process the time-varying inputs from the input layer, and generate the desired time-varying outputs. The input layer encodes the sensory information relevant to the task, while the readout layer produces a decision in terms of an abstract decision variable.
![Figure S5: architecture of the recurrent excitatory and inhibitory spiking network](https://cdn.statically.io/img/arxiv.org/x3.png)
F.2 Input layer
The input layer of our spiking neural network is designed for the evidence accumulation task and comprises 100 spiking neurons (Figure 3A). These neurons are divided into four functionally distinct groups:

- Left and right stimuli: the first two groups represent the left and right stimuli, respectively. Each group contains 25 neurons that fire at up to 40 Hz during the evidence accumulation period (the first 1100 ms), encoding the sensory evidence for each side.
- Recall signals: the third group consists of 25 neurons that generate recall signals during the output period (the last 150 ms). These neurons cue the network to retrieve the accumulated information from memory for the decision.
- Background noise: the fourth group comprises 25 neurons that fire at a constant 10 Hz throughout the simulation, mimicking background activity from other cortical areas and adding realism to the network.
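A sketch of how such input spike trains could be generated with NumPy. Only the group sizes, the 40 Hz and 10 Hz rates, the 1100 ms evidence and 150 ms recall periods, and the 1 ms step come from the text; the number, length, and placement of the cues, the delay length, and the recall-group rate are hypothetical choices. Spikes are drawn per time bin as a Bernoulli approximation of a Poisson process.

```python
import numpy as np

DT = 1.0                    # ms, the 1 ms integration step of Appendix F.9
N_PER_GROUP = 25            # neurons per group (left, right, recall, noise)
T_EVIDENCE, T_RECALL = 1100, 150   # ms, from the text
T_DELAY = 500               # ms between the last cue slot and recall (hypothetical)
CUE_LEN, N_CUES = 100, 7    # cue duration in ms and number of cues (hypothetical)
CUE_RATE, NOISE_RATE = 40.0, 10.0  # Hz, from the text
RECALL_RATE = 40.0          # Hz (hypothetical)


def make_trial(rng):
    """Return (spikes[T, 100], label); label 1 means more right cues than left cues."""
    T = T_EVIDENCE + T_DELAY + T_RECALL
    rates = np.zeros((T, 4 * N_PER_GROUP))
    sides = rng.integers(0, 2, size=N_CUES)        # 0 = left cue, 1 = right cue
    slot = T_EVIDENCE // N_CUES                     # evenly spaced cue slots
    for k, side in enumerate(sides):
        cols = slice(side * N_PER_GROUP, (side + 1) * N_PER_GROUP)
        rates[k * slot:k * slot + CUE_LEN, cols] = CUE_RATE
    rates[T - T_RECALL:, 2 * N_PER_GROUP:3 * N_PER_GROUP] = RECALL_RATE
    rates[:, 3 * N_PER_GROUP:] = NOISE_RATE         # background noise for the whole trial
    # Bernoulli approximation of a Poisson process: P(spike in a bin) = rate * dt.
    spikes = rng.random(rates.shape) < rates * DT / 1000.0
    label = int(sides.sum() > N_CUES - sides.sum())  # odd N_CUES avoids ties
    return spikes, label


spikes, label = make_trial(np.random.default_rng(0))
```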
F.3 Recurrent spiking units
The recurrent cells are spiking neurons adapted from the generalized integrate-and-fire (GIF) model \citep{jolivet2004generalized}. In particular, the model has two internal currents, one fast and one slow, and its dynamics are given by
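$$\tau_{I_1}\frac{dI_1}{dt} = -I_1 \qquad \text{(fast internal current)} \tag{4}$$

$$\tau_{I_2}\frac{dI_2}{dt} = -I_2 \qquad \text{(slow internal current)} \tag{5}$$

$$\tau_V\frac{dV}{dt} = -(V - V_{\mathrm{rest}}) + R\,(I_1 + I_2 + I_{\mathrm{ext}}) \qquad \text{(membrane potential)} \tag{6}$$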
When the membrane potential $V_i$ of the $i$-th neuron reaches the threshold $V_{th}$, the modified GIF model fires:
(7) | |||
(8) | |||
(9) |
where $\tau_{I_1}$ denotes the time constant of the fast internal current, $\tau_{I_2}$ the time constant of the slow internal current, $\tau_V$ the membrane time constant, $R$ the resistance, $I_{\mathrm{ext}}$ the external input, $V_{\mathrm{rest}}$ the resting potential, and $A_1$ and $A_2$ the spike-triggered currents.
To match the firing patterns observed in electrophysiological experiments, particularly the tonic spiking and adaptation, we fit the neuron parameters using our gradient-based optimization methods (Section 4.1).
For the forward spiking operation, we use the Heaviside function to generate the spike:
(10) |
where is used to represent .
To make the non-differentiable spiking activation compatible with the gradient-based algorithm, we considered a surrogate gradient (Appendix C):
(11) |
where , and . is the parameter that controls the altitude of the gradient, and is the parameter that controls the width of the gradient.
F.4 Synapse dynamics
To model the synaptic connections between these neurons, we use conductance-based synapses. The synaptic input current $I_i$ of neuron $i$ is

$$I_i = g_i^{E}\,(E_E - V_i) + g_i^{I}\,(E_I - V_i) \tag{12}$$

where $E_E$ and $E_I$ are the excitatory and inhibitory reversal potentials (in mV).
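The synaptic conductances follow exponential synapse dynamics,

$$\frac{dg_i^{X}}{dt} = -\frac{g_i^{X}}{\tau} + \sum_{j}\sum_{k} w_{ij}\,\delta(t - t_j^{k}), \qquad X \in \{E, I\} \tag{13}$$

where $\tau$ is the decay time constant of the synaptic state, $t_j^{k}$ is the $k$-th spike time of presynaptic neuron $j$, and the sum over $j$ runs over excitatory presynaptic neurons for $X = E$ and over inhibitory ones for $X = I$. In other words, the corresponding postsynaptic conductance jumps whenever a presynaptic neuron $j$ fires. For an excitatory presynaptic neuron,

$$g_i^{E} \leftarrow g_i^{E} + w_{ij}, \tag{14}$$

and for an inhibitory presynaptic neuron,

$$g_i^{I} \leftarrow g_i^{I} + w_{ij}. \tag{15}$$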
Typically, we set ms.
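The claim in Section 4.2 that conductance-based synapses keep the membrane potential between the reversal potentials can be checked numerically. The sketch below works in the rescaled units of Appendix F.5 ($E_E = 3$, $E_I = -3$), absorbs the resistance into dimensionless conductances, and omits spiking, reset, and the internal currents; the resting potential and the time constants are hypothetical. With the conductances held fixed within a step, the membrane relaxes exponentially toward a conductance-weighted average of $V_{\mathrm{rest}}$, $E_E$, and $E_I$, which can never leave $[E_I, E_E]$.

```python
import numpy as np

E_EXC, E_INH = 3.0, -3.0            # rescaled reversal potentials (Appendix F.5)
V_REST = 0.0                        # rescaled resting potential (hypothetical)
TAU_V, TAU_SYN, DT = 20.0, 5.0, 1.0  # ms (hypothetical time constants; 1 ms step)


def step(v, g_e, g_i, pre_e, pre_i, w_e, w_i):
    """One step of exponential conductances and a conductance-based membrane."""
    decay = np.exp(-DT / TAU_SYN)
    g_e = g_e * decay + w_e @ pre_e   # jump on excitatory presynaptic spikes, cf. Eq. (14)
    g_i = g_i * decay + w_i @ pre_i   # jump on inhibitory presynaptic spikes, cf. Eq. (15)
    # tau dV/dt = -(V - V_rest) + g_e (E_exc - V) + g_i (E_inh - V), solved exactly
    # for fixed conductances: V relaxes toward v_inf with rate (1 + g_e + g_i) / tau.
    g_tot = 1.0 + g_e + g_i
    v_inf = (V_REST + g_e * E_EXC + g_i * E_INH) / g_tot
    v = v_inf + (v - v_inf) * np.exp(-DT * g_tot / TAU_V)
    return v, g_e, g_i
```

Because each update is a convex combination of the previous potential and a target inside $[E_I, E_E]$, no voltage regularization is needed to keep the trajectory bounded, however strong the input.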
F.5 Scaled membrane potential
The subthreshold dynamics of biological neurons and synapses usually operate with very negative membrane potentials, often in the range of -50 mV to -80 mV. These large negative values can create challenges during training, especially when using low-precision computing. To address this, we employ a rescaling approach that normalizes the membrane potential.
The rescaling uses an offset $V_{\mathrm{off}}$ and a scaling factor $V_{\mathrm{scale}}$. Every membrane potential $V$ is transformed as

$$\hat{V} = \frac{V - V_{\mathrm{off}}}{V_{\mathrm{scale}}} \tag{16}$$

where $\hat{V}$ is the rescaled membrane potential. We choose $V_{\mathrm{off}}$ and $V_{\mathrm{scale}}$ (in mV) so that the firing threshold becomes 1 after rescaling and the excitatory and inhibitory reversal potentials become 3.0 and -3.0, respectively. This keeps the values in a range that is manageable for training, particularly with limited numerical precision.
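To see why the rescaling leaves the conductance-based dynamics intact, write $V_{\mathrm{off}}$ for the offset and $V_{\mathrm{scale}}$ for the scaling factor, and assume the synaptic current enters the membrane equation additively, $\tau_V\,dV/dt = -(V - V_{\mathrm{rest}}) + R\,\big(I_{\mathrm{ext}} + g^{E}(E_E - V) + g^{I}(E_I - V)\big)$ (the internal currents $I_1, I_2$ behave like $I_{\mathrm{ext}}$ and are omitted for brevity). Substituting $V = V_{\mathrm{off}} + V_{\mathrm{scale}}\hat V$, and likewise for $V_{\mathrm{rest}}$, $E_E$, and $E_I$, gives $dV/dt = V_{\mathrm{scale}}\,d\hat V/dt$, and every potential difference scales by $V_{\mathrm{scale}}$:

```latex
\begin{aligned}
\tau_V V_{\mathrm{scale}}\frac{d\hat V}{dt}
  &= -V_{\mathrm{scale}}\,(\hat V - \hat V_{\mathrm{rest}})
     + R\,I_{\mathrm{ext}}
     + R\,V_{\mathrm{scale}}\left[g^{E}(\hat E_E - \hat V) + g^{I}(\hat E_I - \hat V)\right] \\
\tau_V\frac{d\hat V}{dt}
  &= -(\hat V - \hat V_{\mathrm{rest}})
     + \frac{R\,I_{\mathrm{ext}}}{V_{\mathrm{scale}}}
     + R\left[g^{E}(\hat E_E - \hat V) + g^{I}(\hat E_I - \hat V)\right]
\end{aligned}
```

The conductance terms keep exactly their original form, now with $\hat E_E = 3$ and $\hat E_I = -3$, while current inputs are simply divided by $V_{\mathrm{scale}}$.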
F.6 Network connectivity
We designed a network comprising 400 cells, where the excitatory to inhibitory neuron ratio was set to 4:1 (Figure S5). To establish connectivity within the network, we randomly interconnected these neurons with a connection probability of 10%. This choice of connection probability takes into account the observed higher values for neighboring neurons in the cortex, as well as the lower values for neurons that are more distant \citeptheodoni2022structural.
Apart from the recurrent connections, the input neurons project excitatory synapses to all the recurrent neurons. As for the readout layer, since there is no known biological correspondence for how the brain interprets the recurrent signals, we opted for a linear projection with leaky dynamics.
F.7 Readout layer
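The spiking activity of the recurrent neurons is read out by a linear layer with leaky dynamics:

$$\tau_{\mathrm{out}}\frac{d\mathbf{y}}{dt} = -\mathbf{y} + \mathbf{W}^{\mathrm{out}}\mathbf{z}(t) + \mathbf{b} \tag{17}$$

where $\tau_{\mathrm{out}}$ is the time constant of the output neurons, $\mathbf{z}$ the recurrent spikes, $\mathbf{W}^{\mathrm{out}}$ the synaptic weights between the recurrent and output neurons, and $\mathbf{b}$ the bias. In discrete time, the output dynamics read

$$\mathbf{y}_t = \kappa\,\mathbf{y}_{t-1} + (1 - \kappa)\left(\mathbf{W}^{\mathrm{out}}\mathbf{z}_t + \mathbf{b}\right) \tag{18}$$

where $\kappa = e^{-\Delta t/\tau_{\mathrm{out}}}$.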
F.8 Weight initialization
Initial input and recurrent weights were drawn from a Gaussian distribution and their absolute values taken, $w = \frac{\omega}{\sqrt{n}}\lvert\xi\rvert$ with $\xi \sim \mathcal{N}(0, 1)$, where $n$ is the number of afferent neurons, $\mathcal{N}(0, 1)$ is the zero-mean unit-variance Gaussian distribution, and $\omega$ is the weight scale. Excitatory weights (from both the input neurons and the recurrent excitatory neurons) use the scale $\omega_E$, and inhibitory weights the scale $\omega_I$. The readout weights are drawn from a zero-mean Gaussian distribution whose scale is set by $n_{\mathrm{rec}}$, the number of neurons in the recurrent layer.
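A NumPy sketch of this initialization for the 400-neuron, 4:1 network with 10% connection probability described in Appendix F.6. The weight scales `OMEGA_E` and `OMEGA_I`, the $1/\sqrt{n}$ normalization, and the exclusion of self-connections are assumptions. Because the synapses are conductance-based, all weights stay non-negative; Dale's law is enforced by routing the rows of excitatory presynaptic neurons to $g^E$ and those of inhibitory ones to $g^I$, rather than by flipping signs.

```python
import numpy as np

N_REC, N_IN = 400, 100          # recurrent cells; input neurons (4 groups of 25)
P_CONN = 0.1                    # 10% connection probability
OMEGA_E, OMEGA_I = 1.0, 1.0     # weight scales (hypothetical values)


def init_weights(rng):
    """Return non-negative input, E->all, and I->all weights (rows = presynaptic)."""
    n_exc = N_REC * 4 // 5                          # 4:1 E:I ratio -> 320 excitatory
    mask = rng.random((N_REC, N_REC)) < P_CONN      # independent random connections
    np.fill_diagonal(mask, False)                   # no self-connections (assumption)
    mag = np.abs(rng.standard_normal((N_REC, N_REC))) / np.sqrt(N_REC)  # |N(0,1)| / sqrt(n)
    w_exc = OMEGA_E * (mask * mag)[:n_exc]          # feeds the excitatory conductance g^E
    w_inh = OMEGA_I * (mask * mag)[n_exc:]          # feeds the inhibitory conductance g^I
    # Input neurons are excitatory and project to every recurrent neuron.
    w_in = OMEGA_E * np.abs(rng.standard_normal((N_IN, N_REC))) / np.sqrt(N_IN)
    return w_in, w_exc, w_inh


w_in, w_exc, w_inh = init_weights(np.random.default_rng(0))
```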
F.9 Training methods
The neuron fitting was performed with the L-BFGS-B algorithm \citep{liu1989limited,byrd1995limited}. For task training, we used the online learning algorithm in BrainScale \citep{brainscale}, with parameters updated by the Adam optimizer \citep{Kingma2014AdamAM}. The integration time step of the spiking neural network is 1 ms. The training objective was to minimize the cross-entropy between the output activity and the target output during the recall period.
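The training loss can be sketched as a per-time-step softmax cross-entropy, averaged over the steps that fall within the recall period. This is an illustrative assumption about how "cross-entropy during the recall period" is computed. The helper `recall_cross_entropy` and the boolean `recall_mask` are hypothetical, and neither the Adam updates nor BrainScale's online gradient computation is shown.

```python
import math


def recall_cross_entropy(logits, targets, recall_mask):
    """Mean softmax cross-entropy over time steps flagged as recall period.

    logits:      per-step readout outputs, nested list [T][n_classes]
    targets:     per-step integer class labels, length T
    recall_mask: per-step booleans; only True steps contribute to the loss
    """
    total, count = 0.0, 0
    for z, label, in_recall in zip(logits, targets, recall_mask):
        if not in_recall:
            continue  # steps outside the recall period are ignored
        m = max(z)  # subtract the max for numerical stability
        log_norm = m + math.log(sum(math.exp(v - m) for v in z))
        total += log_norm - z[label]  # equals -log softmax(z)[label]
        count += 1
    return total / count if count else 0.0


# Hypothetical trial: two stimulus/delay steps, then two recall steps (class 1)
logits = [[0.0, 0.0], [5.0, 0.0], [0.0, 2.0], [0.0, 3.0]]
targets = [0, 0, 1, 1]
recall = [False, False, True, True]
print(recall_cross_entropy(logits, targets, recall))
```

Masking rather than slicing keeps the loss definition independent of where the recall period falls within each trial.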
Appendix G Scalability analysis
The large scale of the brain raises an important question: can our differentiable approach scale up to brain-scale spiking neural networks? We cannot yet give a definitive answer, because the answer depends not only on computational resource bottlenecks but also on whether performance generalizes to high-dimensional parameter spaces. From a computational resource perspective, however, we have evaluated how well our approach scales when training large spiking networks on long sequences (see Figure S6 for details).
Figure S6: Memory usage (A), computation time per batch (B), and maximum test accuracy (C) of BPTT and BrainScale across sequence lengths.
Specifically, we analyzed the memory and computational cost of training a three-layer recurrent spiking neural network on the IBM DVS Gesture dataset \citep{amir2017low}. With a batch size of 128 and 512 hidden neurons per layer, we compared the average memory usage and computation time per batch across various sequence lengths. We trained with backpropagation through time (BPTT) \citep{Werbos1988GeneralizationOB,Mozer1989AFB,Williams1990AnEG} and with the BrainScale online learning framework \citep{brainscale}.
As shown in Figure S6A, the training memory required by BPTT increases linearly with the number of time steps, causing out-of-memory errors for sequences longer than 600 steps. In contrast, BrainScale maintains a constant memory footprint for its eligibility traces, regardless of sequence length. BrainScale consumes less than 0.5 GB of GPU memory throughout training, hundreds of times less than BPTT.
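The memory argument can be illustrated on a single scalar leaky unit $y_t = a\,y_{t-1} + w\,x_t$ with loss $\sum_t \tfrac{1}{2}(y_t - c_t)^2$. BPTT must keep all $T$ states for its backward sweep. A forward eligibility trace $e_t = \partial y_t / \partial w = a\,e_{t-1} + x_t$ needs only a constant number of variables. This toy example is not BrainScale's algorithm. For this linear unit the trace gives the exact gradient, whereas BrainScale's online rules for spiking networks are approximations (Appendix H).

```python
def loss(xs, targets, a, w):
    """Sum of squared errors for y_t = a * y_{t-1} + w * x_t with y_0 = 0."""
    y, total = 0.0, 0.0
    for x, c in zip(xs, targets):
        y = a * y + w * x
        total += 0.5 * (y - c) ** 2
    return total


def grad_bptt(xs, targets, a, w):
    """BPTT: the forward pass stores all T states; the backward pass reuses them."""
    ys, y = [], 0.0
    for x in xs:
        y = a * y + w * x
        ys.append(y)  # stored history grows linearly with sequence length
    grad, delta = 0.0, 0.0
    for t in reversed(range(len(xs))):
        # delta = dL/dy_t, including the effect of y_t on all later steps
        delta = (ys[t] - targets[t]) + a * delta
        grad += delta * xs[t]
    return grad


def grad_online(xs, targets, a, w):
    """Online: eligibility trace e_t = dy_t/dw, constant memory in T."""
    y, e, grad = 0.0, 0.0, 0.0
    for x, c in zip(xs, targets):
        y = a * y + w * x
        e = a * e + x  # update the trace; no history is kept
        grad += (y - c) * e
    return grad


xs, cs = [1.0, 0.0, 0.5], [0.0, 1.0, 0.0]
print(grad_bptt(xs, cs, 0.9, 0.5), grad_online(xs, cs, 0.9, 0.5))  # both ≈ 0.86305
```

Both functions take the same number of steps per update. Only the BPTT version keeps a list whose length equals the sequence length, which is the source of the linear memory growth described above.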
The computation time of both BPTT and BrainScale scales linearly with the number of time steps (Figure S6B). However, BrainScale trains approximately twice as fast as BPTT, and this speedup grows with sequence length, owing to its event-driven, low-complexity computation.
We also compared the training performance of BPTT and BrainScale under identical hyperparameter settings for training and the network model (Figure S6C). Measured by maximum test accuracy, BrainScale performed comparably to BPTT, even as the sequence length increased.
These results highlight the potential scalability of our differentiable approach, which leverages the computational efficiency and memory advantages of the online training system BrainScale. Although further research is needed to extend it to larger networks and more complex tasks, our approach is a step toward training brain-scale spiking neural networks efficiently.
Appendix H Limitations and potential challenges
The multi-scale differentiable brain modeling workflow described above integrates information and constraints from multiple levels to build brain models that can reproduce cognitive behaviors observed in humans or animals. However, this approach faces several potential limitations and challenges:
1. Data availability and quality: The accuracy of brain models at each scale depends heavily on the quality and availability of the underlying data. At the microscale level, the accuracy of single neuron and synapse models relies on the quality and completeness of electrophysiological recording data. In practice, however, it is often challenging to obtain comprehensive electrophysiological data for all neurons. Similarly, at the mesoscopic level, the accuracy of connectome constraints depends on the quality and resolution of structural connectivity data obtained from techniques such as diffusion tensor imaging (DTI) or electron microscopy. These techniques face significant challenges in acquiring precise connectomic information in humans and animals. For example, high-resolution electron microscopy can generate detailed 3D maps of neuronal connections from thin brain sections. However, the process is extremely labor-intensive and requires specialized equipment. It is also currently limited to small tissue volumes because reconstructing large-scale connectomes is so computationally demanding.
2. Biological realism and simplifications: While our approach aims to incorporate biological constraints, it necessarily involves simplifications and approximations of the underlying biological processes. In particular, our current model uses simplified point-neuron models and considers only excitatory and inhibitory synaptic connections. This level of abstraction cannot fully capture the complexity and dynamics of real biological systems, especially for intricate phenomena such as nonlinear dendritic effects, neuromodulation, and synaptic plasticity.
3. Complexity and computational resources: Building and training multi-scale brain models can be computationally intensive, particularly for large-scale networks and high-dimensional parameter spaces. To address this, we use the online learning framework BrainScale \citep{brainscale}, which substantially reduces the computational resources required for long-sequence learning tasks (see Appendix G). However, although the online learning algorithms in BrainScale are memory-efficient, they approximate the full gradient and may become less effective when training very high-dimensional parameter sets. Moreover, multi-scale brain models demand not only powerful GPU hardware and efficient distributed computing algorithms but also optimized software frameworks to manage and streamline the extensive computations involved.
4. Integration across scales: Seamlessly integrating information and constraints from different spatial and temporal scales is non-trivial. Adding biological details or constraints usually degrades training performance. As more details are incorporated, the number of model parameters grows, making the optimization problem more complex and prone to issues such as vanishing or exploding gradients, local minima, and slow convergence.
5. Interpretability and theoretical insights: Although gradient-based optimization can produce models that fit the data, deriving theoretical insights or interpretable mechanisms from these models may be difficult, especially for highly complex, nonlinear brain dynamics models.
Despite these potential limitations and challenges, the multi-scale differentiable brain modeling workflow still represents a promising approach to integrate various levels of information and constraints to build more realistic and accurate brain models.
Appendix I Supplementary data
Fitting Method | Fitting Loss | Fitting Time (s) |
---|---|---|
L-BFGS-B (Ours) | 6.799 ± 4.623 | 5.404 ± 0.294 |
DE (Nevergrad) | 9.966 ± 2.281 | 1.172 ± 0.09 |
PSO (Nevergrad) | 14.034 ± 2.857 | 1.151 ± 0.150 |
TwoPointsDE (Nevergrad) | 9.702 ± 5.115 | 1.161 ± 0.172 |
Bayesian (scikit-optimize) | 3.511 ± 1.741 | 62.330 ± 10.917 |

Fitting Method | Fitting Loss | Fitting Time (s) |
---|---|---|
L-BFGS-B (Ours) | 2.3e-08 ± 1.55e-08 | 3.818 ± 0.725 |
DE (Nevergrad) | 24.42 ± 22.55 | 0.619 ± 0.139 |
PSO (Nevergrad) | 35.72 ± 24.06 | 0.663 ± 0.168 |
TwoPointsDE (Nevergrad) | 32.31 ± 22.60 | 0.658 ± 0.175 |
Bayesian (scikit-optimize) | 27.95 ± 25.76 | 55.26 ± 12.39 |