\addbibresource{main.bib}

 

A Differentiable Approach to Multi-scale Brain Modeling

 

Chaoming Wang1  Muyang Lyu2 3  Tianqiu Zhang2 3  Sichao He2 3  Si Wu1 2 3 4 5 6 


footnotetext: 1School of Psychological and Cognitive Sciences, 2Academy for Advanced Interdisciplinary Studies, 3Peking-Tsinghua Center for Life Sciences, 4IDG/McGovern Institute for Brain Research, 5Center of Quantitative Biology, 6Beijing Key Laboratory of Behavior and Mental Health, Peking University, Beijing, China. Correspondence to: Chaoming Wang <wangchaoming@pku.edu.cn>, Si Wu <siwu@pku.edu.cn>.  
Published at the $\mathit{2}^{nd}$ Differentiable Almost Everything Workshop at the $\mathit{41}^{st}$ International Conference on Machine Learning, Vienna, Austria. July 2024. Copyright 2024 by the author(s).
Abstract

We present a multi-scale differentiable brain modeling workflow utilizing BrainPy \citepwang2023brainpy,wang2024brainpy, a unique differentiable brain simulator that combines accurate brain simulation with powerful gradient-based optimization. We leverage this capability of BrainPy across different brain scales. At the single-neuron level, we implement differentiable neuron models and employ gradient methods to optimize their fit to electrophysiological data. On the network level, we incorporate connectomic data to construct biologically constrained network models. Finally, to replicate animal behavior, we train these models on cognitive tasks using gradient-based learning rules. Experiments demonstrate that our approach achieves superior performance and speed in fitting generalized leaky integrate-and-fire and Hodgkin-Huxley single neuron models. Additionally, training a biologically-informed network of excitatory and inhibitory spiking neurons on working memory tasks successfully replicates observed neural activity and synaptic weight distributions. Overall, our differentiable multi-scale simulation approach offers a promising tool to bridge neuroscience data across electrophysiological, anatomical, and behavioral scales.

1 Introduction

Modeling the entire human brain within a computer has been a long-standing dream for humanity \citepamunts2024coming. However, it represents an immense challenge, as the accurate construction of whole-brain models that coherently link multiple spatial scales faces the obstacle of insufficient biological data collection \citepd2022quest. Despite the numerous efforts dedicated to recording and measuring the brain, our observations remain partial, and the information gathered from experimental recordings falls far short of what is necessary to simulate a realistic brain \citepscheffer2021connectome. For instance, at the single neuron level, neurons exhibit diverse firing patterns, while their underlying ionic channels are difficult to discern. Automatic neuron fitting has therefore become a valuable tool for bridging the gap between models and recorded neuronal data, as it can estimate the parameters of these models \citeprossant2011fitting,rossant2010automatic. At the network level, neural activity has been recorded under diverse conditions with techniques such as magnetoencephalography (MEG), electroencephalography (EEG), and functional magnetic resonance imaging (fMRI). However, we still do not fully understand why the underlying neuronal circuits produce such activity, despite the availability of connectome data \citepshiu2023leaky,dorkenwald2023neuronal,winding2023connectome. At the behavioral level, brain simulation network models still struggle to replicate how animals perform cognitive tasks \citeppotjans2014cell,schmidt2018multi,billeh2020systematic.

Consequently, achieving accurate multi-scale brain modeling necessitates the development of highly efficient optimization methods capable of seamlessly integrating and reconciling data across multiple scales, spanning from individual neurons to large-scale neural networks and cognitive processes. However, conventional brain simulators, such as NEURON \citepawile2022modernizing, NEST \citepgewaltig2007nest, and Brian2 \citepstimberg2019brian, pose significant challenges for high-order optimization due to their inherent black-box nature and lack of differentiability. The absence of differentiability restricts researchers to slower and less efficient optimization techniques, or even to manual heuristic parameter searches \citepbilleh2020systematic. Moreover, the inability to leverage powerful gradient-based optimization techniques usually leads to longer computation times and suboptimal model fits, which further impedes scaling to larger and more complex systems and exacerbates the challenges of multi-scale brain modeling.

To overcome these limitations, there is a pressing need for brain simulation frameworks that are natively differentiable, enabling efficient gradient-based optimizations. Recently, BrainPy \citepwang2023brainpy,wang2024brainpy has been proposed as a differentiable brain simulator to bridge this gap. By introducing fundamental features of a brain simulator, such as event-driven computation, sparse operators, numerical integrators, and a multi-scale model building interface, into the numerical computing framework JAX \citepfrostig2018compiling, BrainPy enables faithful brain simulation while inheriting the automatic differentiation (autograd) capabilities of JAX.

Leveraging BrainPy’s strengths, we propose a workflow for differentiable multi-scale brain modeling (Figure 1). This workflow utilizes gradient-based optimization to fit differentiable models of single neurons and synapses. We then incorporate connectomic data from neuroscience experiments to construct data-driven, biologically constrained spiking neural networks (SNNs). Finally, to replicate animal behavior, we train these biologically informed models on cognitive tasks using gradient-based online learning rules. Our experiments demonstrate the feasibility of our approach for achieving accurate multi-scale brain modeling.

2 Methods

We first present the designs that make our entire workflow differentiable.

2.1 Differentiable neuron models with surrogate gradients

Biological neurons generate non-differentiable binary spike events. The discontinuous nature of the spiking operation, represented by the Heaviside function $\mathcal{H}(\mathbf{v})$, where $\mathbf{v}$ is the membrane potential, poses a challenge in applying gradient-based optimizations to SNN models. This is because the derivative of the spiking operation is a Dirac delta function $\delta(\mathbf{v})$. In practice, surrogate gradients, which replace the delta gradient function with a smooth surrogate function, such as a Gaussian \citepyin2021accurate, linear \citepbellec2018long, SLayer \citepshrestha2018slayer, or multi-Gaussian function \citepyin2021accurate, have demonstrated their efficacy in training SNNs using gradient descent \citepbohte2011error,Neftci2019Surrogate,bellec2020solution. We apply this approach in our workflow to enable gradient-based optimization. Moreover, we provide a suite of surrogate gradient functions (listed in Appendix C) to facilitate the selection of the most suitable function for a given task.
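The idea can be illustrated with a minimal sketch in plain Python. It is not BrainPy's implementation: the function names and the sharpness parameter `alpha` are assumptions made for illustration, and the surrogate is taken to be the derivative of the smooth function $g(x)=\frac{1}{\pi}\arctan\left(\frac{\pi}{2}\alpha x\right)+\frac{1}{2}$. The forward pass keeps the exact Heaviside step, while the backward pass substitutes $g'(x)$ for the Dirac delta.

```python
import math


def heaviside(x):
    """Forward pass: emit a spike (1.0) when x = V - V_th is non-negative."""
    return 1.0 if x >= 0.0 else 0.0


def arctan_surrogate_grad(x, alpha=1.0):
    """Backward pass: g'(x) for g(x) = arctan(pi/2 * alpha * x) / pi + 1/2.

    This bell-shaped curve stands in for the Dirac delta; `alpha` is an
    illustrative sharpness parameter (larger values give a narrower peak).
    """
    return 0.5 * alpha / (1.0 + (0.5 * math.pi * alpha * x) ** 2)


def spike_with_surrogate(v, v_th, alpha=1.0):
    """Return the binary spike and the surrogate derivative d(spike)/dv."""
    x = v - v_th
    return heaviside(x), arctan_surrogate_grad(x, alpha)


# Membrane potential 1 mV above threshold: a spike is emitted, and the
# surrogate gradient is finite instead of zero or infinite.
spike, grad = spike_with_surrogate(v=-49.0, v_th=-50.0)
```

The surrogate peaks at the threshold and decays symmetrically on both sides, so neurons close to firing receive the largest learning signal.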

2.2 Event-driven differentiable synaptic operators

To mimic the brain’s efficient communication, traditional brain simulators leverage custom data structures for event-driven computations and spike communication \citepkunkel2012meeting,kunkel2014spiking,stimberg2019brian. However, these approaches often clash with autograd systems, hindering gradient-based optimization of synaptic computations. We address this challenge by introducing differentiable event-driven synaptic operators compatible with autograd frameworks (details in Appendix D). We utilize the compressed sparse row (CSR) format for storing synaptic connections and implement event-driven operations based on CSR arrays (Listing S2). Notably, these operators provide both forward and backward differentiation rules for differentiable computations (Listing S3). Furthermore, BrainPy’s event-driven operators achieve significant speedups (one to two orders of magnitude) compared to traditional sparse and dense alternatives \citepwang2023brainpy,wang2024brainpy. This efficiency benefit applies to both forward state computations and backward gradient calculations.

3 Workflow for multi-scale differentiable brain modeling

Based on the differentiable neuronal and synaptic building blocks described earlier, we present a workflow for multi-scale differentiable brain modeling. This approach seamlessly integrates microscopic neuron models, mesoscopic neural circuit connectivity, and macroscopic computational tasks through gradient-based optimization algorithms (see Figure 1).


Figure 1: Multi-scale differentiable brain modeling workflow. The entire workflow is executed using the differentiable brain simulator BrainPy \citepwang2023brainpy,wang2024brainpy. (A) At the microscale level, single neuron and synapse models are fitted to electrophysiological recording data using gradient-based optimization. (B) At the mesoscopic level, connectome constraints are incorporated into the network construction, facilitating the integration of structural connectivity information. (C) At the macroscale behavior level, gradient-based optimization methods are applied to train the above data-constrained network models to reproduce the cognitive behaviors seen in humans or animals.

At the single neuron and synapse level (Figure 1A), accurate modeling of individual neurons and synaptic currents is made possible by existing knowledge and experimental techniques in cellular biophysics \citepkoch1998methods,teeter2018generalized. To facilitate large-scale neuronal network simulation and training, we employ point models, which capture diverse cellular behaviors while remaining differentiable and computationally efficient. Furthermore, to construct large sets of neuron models from empirical datasets conveniently and quickly, we employ gradient-based optimization (e.g., L-BFGS-B algorithm \citepliu1989limited,byrd1995limited). This fitting procedure, powered by JAX \citepfrostig2018compiling, is easily scalable and parallelizable through the jax.vmap or jax.pmap semantics.

At the network level (Figure 1B), we incorporate brain structure and connectome information to construct realistic brain models. Universal function approximation \citepcybenko1989approximation,leshno1993multilayer and Kolmogorov-Arnold representation theorems \citepkolmogorov1961representation,braun2009constructive suggest that distinct neural networks with completely different connectivity can perform the same computational tasks. Therefore, utilizing brain connectome-constrained neural models is an essential step in linking the organizational features of neuronal networks and the spectrum of cortical functions. Recent quantitative databases of the connectomes of various animals (e.g., Drosophila \citepwinding2023connectome, zebrafish \citepkunst2019cellular, macaque \citepmarkov2014weighted,markov2014anatomy, mice \citepoh2014mesoscale,zingg2014neural, and marmoset \citepmajka2020open) have provided rich resources for this purpose.

At the behavioral level (Figure 1C), we utilize gradient-based optimization to train the data-constrained networks described above on computational tasks. While handcrafted tuning and manually engineered network connectivity can implement specific functions, they fall short in generating brain-scale intelligence. Here, we optimize unknown network parameters using deep learning techniques \citepsaxe2021if, enabling the model to learn and perform tasks similar to an animal. Notably, we employ the online learning method from BrainScale \citepbrainscale, which offers an online approximation of backpropagation with low computational complexity and high training performance.

4 Training biologically-informed spiking networks on cognitive tasks

To exemplify the proposed workflow, we trained a biologically informed excitatory and inhibitory (EI) spiking network on a working memory task. The dynamics of spiking neurons in our EI network are governed by generalized integrate-and-fire (GIF) neurons \citepjolivet2004generalized. The synaptic dynamics are implemented using the Exponential model. To accurately capture the characteristic tonic spiking and adaptation \citepstaining2015allen,teeter2018generalized, the firing pattern of each GIF neuron is optimized using the L-BFGS-B algorithm (Section 4.1). The $N$ neurons are divided into excitatory and inhibitory neurons with a 4:1 EI ratio. The connectivity between the $N$ neurons in the network is established based on principles derived from the neocortical connectome \citeptheodoni2022structural. For training both the excitatory and inhibitory weights to perform the working memory tasks (Section 4.2), we employed the online learning framework BrainScale \citepbrainscale. For complete details of the EI model, please see Appendix F.

4.1 Neuron fitting

Our neuron fitting procedure is depicted in Figure 2. The experimental data is obtained through current-clamp recordings, where the recorded currents mimic synaptic activity observed in vivo (Figure 2A). The neuron is defined in BrainPy as a differentiable model, and a loss function is employed to quantify the disparity between the model’s predictions and the experimental data. The mean square error can be used for fitting the membrane potential, while the gamma factor \citepjolivet2008benchmark can be employed for fitting spike trains (Appendix E). These criteria are used to calculate the gradients and subsequently update the parameter values. Through iterative gradient estimations and parameter updates, the fitting procedure aims to identify the optimal parameters that best align with the experimental recording data (Figure 2B).
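The loop of simulation, loss evaluation, gradient computation, and parameter update can be sketched in a few lines of plain Python. This is a toy illustration under stated assumptions rather than the fitting code used in this work: the model is a passive leaky integrator $\tau\,\mathrm{d}V/\mathrm{d}t=-V+RI$ with a single free parameter $R$, the gradient is obtained from a hand-derived sensitivity equation standing in for automatic differentiation, and plain gradient descent stands in for L-BFGS-B. All names and constants are hypothetical.

```python
def simulate(R, current, tau=10.0, dt=0.1):
    """Euler-integrate tau * dV/dt = -V + R * I and track S = dV/dR."""
    v, s = 0.0, 0.0
    vs, ss = [], []
    for i_t in current:
        v += dt / tau * (-v + R * i_t)
        # Sensitivity equation: the voltage update differentiated w.r.t. R.
        s += dt / tau * (-s + i_t)
        vs.append(v)
        ss.append(s)
    return vs, ss


def mse_and_grad(R, current, target):
    """Mean square error between model and target traces, and dLoss/dR."""
    vs, ss = simulate(R, current)
    n = len(target)
    loss = sum((v - y) ** 2 for v, y in zip(vs, target)) / n
    grad = sum(2.0 * (v - y) * s for v, y, s in zip(vs, target, ss)) / n
    return loss, grad


# Synthetic "recording": response to a step current with ground truth R = 2.0.
current = [1.0] * 1000
target, _ = simulate(2.0, current)

R = 0.5  # initial guess
for _ in range(50):
    loss, grad = mse_and_grad(R, current, target)
    R -= 0.5 * grad  # gradient-descent update with a fixed step size
```

Because the toy model is linear in $R$, the loss is quadratic and the iteration converges quickly. Spiking models such as GIF or HH neurons have a far less benign loss landscape, which is where automatic differentiation and quasi-Newton methods become important.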


Figure 2: Overview of the neuron fitting procedure. (A) Experimental data: Step currents are injected into the neuron, and the resultant membrane potential responses are recorded. (B) Illustration of the optimization procedure: Parameter values are initialized from a distribution (initialization). Neurons with these parameters are simulated in parallel, and their outputs are compared with the ground truth data (simulation). The prediction error is utilized to estimate gradients (gradient), which are then used to update the initialized parameters for the subsequent iteration (update). (C) Fitting results of the HH model on a cortical pyramidal cell using five different optimization methods.

Our fitting method is first tested on GIF neuron models to capture characteristic cortical firing patterns such as spike frequency adaptation, phasic spiking, and rebound spiking. We compare the performance and speed of our gradient-based methods with conventional optimization algorithms, namely differential evolution (DE), DE algorithm with two points crossover (TwoPointsDE), and particle swarm optimization (PSO) provided in Nevergrad \citepbennet2021nevergrad, as well as the Bayesian optimization method in scikit-optimize \citeplouppe2017bayesian. The experiments demonstrate that L-BFGS-B and Bayesian optimizations exhibit the best fitting performance (Figure S7), while DE, TwoPointsDE, and PSO methods demonstrate faster fitting speed (Table S2). Although Bayesian optimization shows good performance, it converges slowly. These results indicate that the gradient-based L-BFGS-B method provides a good tradeoff between fitting performance and speed.

We further evaluate our fitting method on Hodgkin-Huxley (HH) neuron models using realistic electrophysiological recording data. Figure 2C and Figure S8 demonstrate the application of five fitting methods to an in vitro intracellular recording of a cortical pyramidal cell. The fitting results reveal that L-BFGS-B exhibits the best fitting performance (Figure 2C), achieving nearly perfect fitting of membrane potentials with a loss close to zero (Table S3). Moreover, our fitting methods demonstrate comparable speed to the evolutionary algorithm while being significantly faster than the Bayesian optimization method (Table S3). These findings underscore the potential of differentiable optimization as a promising approach for neuronal fitting.

4.2 Task training

Understanding how the brain performs complex computations remains a challenge. Recent advances in training recurrent neural networks have demonstrated high performance across various tasks, offering a promising avenue for uncovering the underlying dynamical and computational mechanisms involved \citepsong2016training,barak2017recurrent. However, these networks often lack essential biological constraints, such as spike-based communication, structural connectivity, and the distinction between excitatory and inhibitory neurons. In this study, we propose training biological SNNs while explicitly considering electrophysiological, anatomical, and structural constraints. Specifically, we construct a foundational EI network using conductance-based GIF neurons fitted to data (see Section 4.1), incorporating connectomic connectivity \citeptheodoni2022structural and conductance-based synaptic dynamics \citepvogels2005signal to implement the working memory task through gradient-based optimization algorithms \citepbrainscale.



Figure 3: Training the biologically informed excitatory and inhibitory spiking network on the evidence accumulation task. (A) The input spike train. (B) The recurrent spiking dynamics. (C, D) The membrane potentials of five excitatory (C) and inhibitory (D) neurons. (E, F) The synaptic weight distribution before (E) and after (F) training.

To generate the training data, we followed the experimental setup of an evidence accumulation task \citepmorcos2016history. The input spike train was divided into four segments: the left and right stimuli, the recall cue, and the background noise (Figure 3A). Our network was trained to separately count the left and right cues and generate the correct response by comparing the resulting numbers after prolonged periods of delay. We recorded the responses of both the excitatory and inhibitory neurons in the recurrent layer (Figure 3B). During the evidence accumulation period, inhibitory neurons exhibited significant responses after each stimulus presentation, whereas excitatory neurons displayed lower firing rates. However, during the recall period, inhibitory neurons rarely spiked. We also examined the membrane potential of all neurons (Figure 3C and D). In contrast to the current-based synapse model commonly used in deep learning applications, our conductance-based synapse modeling ensured that the membrane potential remained constrained between excitatory and inhibitory reversal potentials, eliminating the need for voltage regularization. Additionally, we analyzed the synaptic weight distribution before and after training (Figure 3E and F). We initialized excitatory and inhibitory weights with a normal distribution and took their absolute values (Figure 3E). After training, synaptic weights exhibited a distribution similar to that observed in biological measurements. Specifically, excitatory weights followed the tail of a Gaussian distribution \citepbarbour2007can, while inhibitory weights showed a log-normal distribution \citeploewenstein2011multiplicative,buzsaki2014log.

5 Conclusion and discussion

We proposed a novel workflow for differentiable multi-scale brain modeling by integrating various levels of information and constraints to build brain models that can reproduce cognitive behaviors observed in humans or animals. We demonstrated this workflow by training a biologically informed GIF network to accomplish an evidence accumulation task. Although the current illustration utilizes a network with hundreds of neurons, the online learning algorithms employed are readily scalable to much larger models (refer to Appendix G for scalability analysis). Overall, our proposed differentiable approach has the potential to accelerate progress in developing accurate and biologically plausible multi-scale brain models, ultimately leading to a deeper understanding of the brain.

However, many important challenges remain to be addressed in the future (see Appendix H for details). These challenges include balancing data quality and availability with model realism, determining the appropriate granularity for simplifying and approximating biological processes within the model, and ensuring the interpretability and theoretical grounding of the derived models. Additionally, addressing the computational efficiency of handling large-scale networks with high-dimensional parameter spaces is crucial.

\printbibliography

Appendix A Software and Data

The in vitro intracellular recording of a cortical pyramidal cell can be obtained from brian2modelfitting \citepteska2020brian2modelfitting. BrainPy is publicly available on GitHub at https://github.com/brainpy/BrainPy. As of now, BrainScale \citepbrainscale is undergoing a review process and is temporarily unavailable. However, it is expected to be released in the future. Additional packages related to the BrainPy ecosystem are accessible through the BrainPy GitHub organization at https://github.com/brainpy. The code necessary to reproduce the results presented in this paper can be found in the following GitHub repository: https://github.com/chaoming0625/differentiable-brain-modeling-workflow.

Appendix B Acknowledgements

This work was supported by Science and Technology Innovation 2030-Brain Science and Brain-inspired Intelligence Project (No. 2021ZD0200204).

Appendix C Surrogate gradient functions

In recent years, spiking neural networks (SNNs) have garnered attention due to their promising advantages in energy efficiency, fault tolerance, and biological plausibility. However, training SNNs using standard gradient descent methods is challenging because their activation functions are discontinuous and have near-zero gradients across most points. To tackle this issue, a common approach is to replace the non-differentiable spiking function with a surrogate gradient function \citepNeftci2019Surrogate. A surrogate gradient function is a smooth approximation of the derivative of the activation function, enabling the application of gradient-based learning algorithms to SNNs. The BrainPy library offers a variety of surrogate gradient functions, each possessing different characteristics such as smoothness, boundedness, and biological plausibility. A comprehensive list of these functions is provided in Table S1, and an example of a surrogate gradient function is illustrated in Figure S4.

Figure S4: The collection of surrogate gradient functions $g^{\prime}(x)$ in BrainPy \citepwang2023brainpy,wang2024brainpy, where $x\geq 0$ represents the neuronal membrane potential exceeding the spiking threshold.

In practical applications, users can employ these surrogate gradient functions to determine whether a spike is generated at the current time step, as demonstrated in the following Python code:

import brainpy

# V: the membrane potential
# V_th: the threshold of the membrane potential to generate a spike
spike = brainpy.math.surrogate.arctan(V - V_th)
Listing S1: Example code that uses a surrogate function to determine whether a spike is generated.

By utilizing the appropriate surrogate gradient function, the code above allows users to assess whether a spike occurs based on the comparison between the membrane potential (V) and the threshold (V_th).

Table S1: The full list of surrogate gradient functions provided in BrainPy.
\begin{tabular}{ll}
\hline
Surrogate gradient function & Implementation \\
\hline
Sigmoid function & \verb|brainpy.math.surrogate.sigmoid| \\
Piecewise quadratic function \citepEsser2016Convolutional,Yujie2018Spatio & \verb|brainpy.math.surrogate.piecewise_quadratic| \\
Piecewise exponential function \citepNeftci2019Surrogate & \verb|brainpy.math.surrogate.piecewise_exp| \\
Soft sign function & \verb|brainpy.math.surrogate.soft_sign| \\
Arctan function & \verb|brainpy.math.surrogate.arctan| \\
Nonzero sign log function & \verb|brainpy.math.surrogate.nonzero_sign_log| \\
ERF function \citepEsser2015Backpropagation,Yujie2018Spatio,Yin2020Effective & \verb|brainpy.math.surrogate.erf| \\
Piecewise leaky ReLU function \citepShihui2018Algorithm,Yujie2018Spatio & \verb|brainpy.math.surrogate.piecewise_leaky_relu| \\
Squarewave Fourier series & \verb|brainpy.math.surrogate.squarewave_fourier_series| \\
S2NN function \citepsuetake2023s3nn & \verb|brainpy.math.surrogate.s2nn| \\
q-PseudoSpike function \citepherranzcelotti2022surrogate & \verb|brainpy.math.surrogate.q_pseudo_spike| \\
Leaky ReLU function & \verb|brainpy.math.surrogate.leaky_relu| \\
Log-tailed ReLU function \citepZhaowei2017Deep & \verb|brainpy.math.surrogate.log_tailed_relu| \\
ReLU function \citepNeftci2019Surrogate & \verb|brainpy.math.surrogate.relu_grad| \\
Gaussian function \citepyin2021accurate & \verb|brainpy.math.surrogate.gaussian_grad| \\
Multi-Gaussian function \citepyin2021accurate & \verb|brainpy.math.surrogate.multi_gaussian_grad| \\
Inverse-square function & \verb|brainpy.math.surrogate.inv_square_grad| \\
SLayer function \citepshrestha2018slayer & \verb|brainpy.math.surrogate.slayer_grad| \\
\hline
\end{tabular}

Appendix D Event-driven synaptic operators

Synaptic computation usually needs event-driven matrix-vector multiplication $\mathbf{y}=\mathbf{M}\mathbf{v}$, where $\mathbf{v}$ is the vector of presynaptic spikes, $\mathbf{M}$ the synaptic connection matrix, and $\mathbf{y}$ the postsynaptic current. Specifically, it performs matrix-vector multiplication in a sparse and efficient way by exploiting the event property of the input vector $\mathbf{v}$. Instead of multiplying the entire matrix $\mathbf{M}$ by the vector $\mathbf{v}$, which can be wasteful if $\mathbf{v}$ has many zero elements, event-driven matrix-vector multiplication in BrainPy only performs multiplications for the non-zero elements of the vector, which are called events. This can reduce the number of operations and memory accesses, and improve the running performance of matrix-vector multiplication.

Particularly, we implement event-driven operators based on arrays with the compressed sparse row (CSR) format and provide both forward and backward differentiation rules. The CSR format represents the synaptic connectivity between pre- and post-synaptic neuron populations, comprising three arrays: <val, col_ind, row_ptr>. val stores the non-zero synaptic weights, col_ind stores the postsynaptic indices of the corresponding non-zero weights, and row_ptr stores the starting indices of each presynaptic neuron in the val and col_ind arrays.
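As a small illustration of this layout, the following sketch converts a dense weight matrix into the three CSR arrays. The helper `dense_to_csr` is a hypothetical name introduced here for illustration and is not part of BrainPy; rows are assumed to index presynaptic neurons and columns postsynaptic neurons.

```python
def dense_to_csr(W):
    """Convert a dense (n_pre x n_post) weight matrix into CSR arrays."""
    val, col_ind, row_ptr = [], [], [0]
    for row in W:
        for j, w in enumerate(row):
            if w != 0.0:
                val.append(w)      # non-zero synaptic weight
                col_ind.append(j)  # index of its postsynaptic neuron
        # Position in val/col_ind at which the next presynaptic row starts.
        row_ptr.append(len(val))
    return val, col_ind, row_ptr


# Two presynaptic and three postsynaptic neurons.
W = [
    [0.0, 0.5, 0.0],
    [0.2, 0.0, 0.3],
]
val, col_ind, row_ptr = dense_to_csr(W)
# val     -> [0.5, 0.2, 0.3]
# col_ind -> [1, 0, 2]
# row_ptr -> [0, 1, 3]
```

In this sketch `row_ptr` carries one extra trailing entry equal to the number of stored weights, so the synapses of presynaptic neuron `i` occupy positions `row_ptr[i]` to `row_ptr[i+1] - 1`.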

To perform an event-driven linear transformation $\mathbf{y}=\mathbf{x}\mathbf{W}$, where $\mathbf{W}$ is the CSR-formatted connectivity, the pseudo-code is implemented as:

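def csrmv(val, col_ind, row_ptr, x, y):
    # visit only the presynaptic neurons that emitted a spike (an event)
    for i, event in enumerate(x):
        if event:
            # accumulate the weights of all outgoing synapses of neuron i
            for j in range(row_ptr[i], row_ptr[i+1]):
                y[col_ind[j]] += val[j]
Listing S2: The forward pass of event-driven sparse matrix-vector multiplication.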

To efficiently compute the gradients $\mathrm{d}\mathbf{x}$ and $\mathrm{d}\mathbf{W}$, we implement the event-driven gradient computation as follows:

# compute dx: gradient with respect to the presynaptic input x
def csrmv_dx(val, col_ind, row_ptr, dy, dx):
    # dx has one entry per presynaptic neuron, i.e., one per CSR row
    for i in range(dx.shape[0]):
        r = 0.
        for j in range(row_ptr[i], row_ptr[i+1]):
            r += val[j] * dy[col_ind[j]]
        dx[i] = r

# compute dW: gradient with respect to the stored synaptic weights
def csrmv_dW(col_ind, row_ptr, x, dy, dW):
    # only synapses whose presynaptic neuron spiked receive a gradient
    for i, event in enumerate(x):
        if event:
            for j in range(row_ptr[i], row_ptr[i+1]):
                dW[j] = dy[col_ind[j]]
Listing S3: The backward pass of event-driven sparse matrix-vector multiplication.
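The following self-contained sketch exercises the forward and backward rules on a tiny network so that the results can be checked by hand. It follows the structure of the listings above but uses plain Python lists and returns new arrays instead of writing into preallocated ones. The function signatures, the `n_post` argument, and the example values are assumptions made for illustration.

```python
def csrmv(val, col_ind, row_ptr, x, n_post):
    """Forward: y = x W, visiting only rows whose presynaptic neuron spiked."""
    y = [0.0] * n_post
    for i, event in enumerate(x):
        if event:
            for j in range(row_ptr[i], row_ptr[i + 1]):
                y[col_ind[j]] += val[j]
    return y


def csrmv_dx(val, col_ind, row_ptr, dy):
    """Backward w.r.t. x: dx[i] = sum_k W[i, k] * dy[k], one entry per row."""
    n_pre = len(row_ptr) - 1
    return [
        sum(val[j] * dy[col_ind[j]] for j in range(row_ptr[i], row_ptr[i + 1]))
        for i in range(n_pre)
    ]


def csrmv_dW(col_ind, row_ptr, x, dy):
    """Backward w.r.t. stored weights: non-zero only for rows that spiked."""
    dW = [0.0] * len(col_ind)
    for i, event in enumerate(x):
        if event:
            for j in range(row_ptr[i], row_ptr[i + 1]):
                dW[j] = dy[col_ind[j]]
    return dW


# CSR form of W = [[0.0, 0.5, 0.0], [0.2, 0.0, 0.3]].
val, col_ind, row_ptr = [0.5, 0.2, 0.3], [1, 0, 2], [0, 1, 3]
x = [False, True]     # only presynaptic neuron 1 spikes
dy = [1.0, 2.0, 3.0]  # upstream gradient dLoss/dy

y = csrmv(val, col_ind, row_ptr, x, n_post=3)
dx = csrmv_dx(val, col_ind, row_ptr, dy)
dW = csrmv_dW(col_ind, row_ptr, x, dy)
```

Here `y` receives only the second row of $\mathbf{W}$, `dx` equals the product of $\mathbf{W}$ with `dy`, and `dW` is zero for the synapse of the silent neuron 0. This matches the dense computation while skipping all work for presynaptic neurons that did not fire.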

Appendix E Loss functions for neuron fitting

E.1 Mean square error for fitting membrane potentials

In order to align the simulated membrane potential with the experimentally recorded potentials, we compute the mean squared difference between the data $\hat{Y_{i}}$ and the simulated trace $Y_{i}$ using the mean square error formula:

\[
\mathrm{MSE}=\frac{1}{T}\sum_{i=1}^{T}(Y_{i}-\hat{Y_{i}})^{2}, \tag{1}
\]

where $T$ is the total number of time points.

E.2 Gamma factor for fitting spike trains

The Gamma factor \citepjolivet2008benchmark serves as a metric for assessing the agreement between spike timings in the simulated and target traces. It is commonly employed to evaluate the performance of spiking neuron models when fitting them to electrophysiological recordings of individual neurons. The gamma factor primarily focuses on the proportion of predicted spikes that coincide with the spikes in the recording. Essentially, it quantifies how accurately the model reproduces the timing of the neuron’s firing events. The calculation of the gamma factor is as follows:

\[
\Gamma=\left(\frac{2}{1-2\Delta r_{\mathrm{exp}}}\right)\left(\frac{N_{\mathrm{coinc}}-2\Delta N_{\mathrm{exp}}r_{\mathrm{exp}}}{N_{\mathrm{exp}}+N_{\mathrm{model}}}\right) \tag{2}
\]

where

  • $N_{\mathrm{coinc}}$: number of coincidences

  • $N_{\mathrm{exp}}$ and $N_{\mathrm{model}}$: number of spikes in experimental and model spike trains

  • $r_{\mathrm{exp}}$: average firing rate in experimental train

  • $\Delta$: coincidence window, i.e., the temporal precision within which two spikes count as coincident

  • $2\Delta N_{\mathrm{exp}}r_{\mathrm{exp}}$: expected number of coincidences with a Poisson process

The gamma factor $\Gamma$ equals 1 when the two spike trains match perfectly and decreases for less precise matches. It reaches 0 when the number of coincidences matches the expected count from two homogeneous Poisson processes with the same firing rate.

To turn the Gamma factor into a loss function, we add a correction term:

\begin{equation}
\mathrm{Loss}=1+2\frac{\lvert r_{\mathrm{data}}-r_{\mathrm{model}}\rvert}{r_{\mathrm{data}}}-\Gamma, \tag{3}
\end{equation}

where $r_{\mathrm{data}}$ and $r_{\mathrm{model}}$ are the firing rates measured in the data and the model, respectively.
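As an illustration, Eqs. (2) and (3) can be evaluated on two lists of spike times as sketched below. This is a plain, non-differentiable Python sketch rather than the implementation used in our experiments. The function names, the coincidence-counting rule (each experimental spike is matched to its nearest model spike), and the example values of $\Delta$ and the recording duration are assumptions made for the example.

```python
from bisect import bisect_left


def count_coincidences(model_spikes, exp_spikes, delta):
    """Count experimental spikes that have a model spike within +/- delta."""
    model = sorted(model_spikes)
    if not model:
        return 0
    count = 0
    for t in exp_spikes:
        k = bisect_left(model, t)
        # The nearest model spike lies just before or just after t.
        candidates = model[max(k - 1, 0):k + 1]
        if min(abs(t - m) for m in candidates) <= delta:
            count += 1
    return count


def gamma_factor(model_spikes, exp_spikes, delta, duration):
    """Eq. (2); spike times, delta and duration share one time unit."""
    n_exp, n_model = len(exp_spikes), len(model_spikes)
    if n_exp == 0:
        raise ValueError("need at least one experimental spike")
    r_exp = n_exp / duration
    n_coinc = count_coincidences(model_spikes, exp_spikes, delta)
    expected = 2.0 * delta * n_exp * r_exp  # chance coincidences
    norm = 2.0 / (1.0 - 2.0 * delta * r_exp)
    return norm * (n_coinc - expected) / (n_exp + n_model)


def gamma_loss(model_spikes, exp_spikes, delta, duration):
    """Eq. (3): rate-corrected loss built from the gamma factor."""
    r_data = len(exp_spikes) / duration
    r_model = len(model_spikes) / duration
    gamma = gamma_factor(model_spikes, exp_spikes, delta, duration)
    return 1.0 + 2.0 * abs(r_data - r_model) / r_data - gamma


# Hypothetical example: four spikes in a 1 s recording, 4 ms window.
exp_train = [0.10, 0.30, 0.50, 0.90]
print(gamma_factor(exp_train, exp_train, delta=0.004, duration=1.0))
print(gamma_loss([], exp_train, delta=0.004, duration=1.0))
```

For identical spike trains the gamma factor is 1 and the loss is 0 (up to floating-point rounding), while a silent model is penalized both through $\Gamma$ and through the rate-correction term.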

Appendix F Excitatory and inhibitory spiking network models

F.1 Network structure

The architecture of recurrent excitatory and inhibitory spiking networks used here is shown in Figure S5, where the recurrent layer consists of excitatory and inhibitory spiking units that receive and process the time-varying inputs from the input layer, and generate the desired time-varying outputs. The input layer encodes the sensory information relevant to the task, while the readout layer produces a decision in terms of an abstract decision variable.

Figure S5: Architecture of the recurrent spiking EI network. The network consists of excitatory (E) and inhibitory (I) spiking units, denoted by $\mathbf{r}(t)$. These units are trained using the online gradient-based learning framework BrainScale \citep{brainscale}. Time-varying inputs $\mathbf{u}(t)$ are received by the network, and the recurrent activity is encoded through time-varying outputs $\mathbf{z}(t)$. The inputs represent task-relevant sensory information or internal rules, while the outputs encode a decision in the form of an abstract decision variable, probability distribution, or direct motor output. Each spiking unit exhibits its own dynamics, and the firing rate of each unit is adjusted through our differentiable fitting method (Section 4.1). The connectivity between the spiking units is determined based on connectomic measurements \citep{theodoni2022structural}.

F.2 Input layer

The input layer in our spiking neural network is designed for an evidence accumulation task and comprises $N_{\mathrm{in}}=100$ spiking neurons (Figure 3A). These neurons are divided into four functionally distinct groups:

  • Left and Right Stimuli: The first two groups represent the left and right stimuli, respectively. Each group contains 25 neurons and fires at a maximum rate of 40 Hz during the evidence accumulation period (first 1100 ms). This neural activity encodes the sensory evidence for each side.

  • Recall Signals: The third group consists of 25 neurons that generate recall signals during the output period (last 150 ms). These neurons are crucial for retrieving relevant information from memory to aid in the decision-making process.

  • Background Noise: The fourth group comprises 25 neurons that fire at a constant rate of 10 Hz throughout the entire simulation. This group simulates background activity from other cortical areas, adding a layer of realism to the network.
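The four input groups can be sketched as Bernoulli approximations of Poisson spike trains with a 1 ms bin. The sketch below is only illustrative: the total trial length, the constant left and right rates, and the recall-group rate are hypothetical values that are not specified above, and the function name is ours.

```python
import random


def input_spikes(total_ms, left_hz, right_hz, recall_hz=40.0, noise_hz=10.0,
                 accum_ms=1100, recall_ms=150, group_size=25, dt_ms=1.0,
                 seed=0):
    """Return a list of time steps, each a list of 0/1 spikes for 100 neurons.

    Neuron order: left (0-24), right (25-49), recall (50-74), noise (75-99).
    recall_hz and the constant left/right rates are assumptions.
    """
    rng = random.Random(seed)

    def rate_at(neuron, t_ms):
        group = neuron // group_size
        if group == 0:  # left stimulus, active during evidence accumulation
            return left_hz if t_ms < accum_ms else 0.0
        if group == 1:  # right stimulus, same period
            return right_hz if t_ms < accum_ms else 0.0
        if group == 2:  # recall signal, active during the last recall_ms
            return recall_hz if t_ms >= total_ms - recall_ms else 0.0
        return noise_hz  # background noise, active throughout

    spikes = []
    for step in range(int(total_ms / dt_ms)):
        t_ms = step * dt_ms
        # Spike probability per bin is rate [Hz] * bin width [s].
        row = [1 if rng.random() < rate_at(n, t_ms) * dt_ms / 1000.0 else 0
               for n in range(4 * group_size)]
        spikes.append(row)
    return spikes


# Hypothetical 2000 ms trial with stronger evidence on the left.
trial = input_spikes(total_ms=2000, left_hz=40.0, right_hz=10.0)
print(len(trial), len(trial[0]))
```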

F.3 Recurrent spiking units

Each recurrent unit is a spiking neuron modified from the generalized integrate-and-fire (GIF) neuron model \citep{jolivet2004generalized}. The model has two internal currents, one fast and one slow. Its dynamics are given by

\begin{align}
\tau_{I1}\frac{\mathrm{d}\mathbf{I_{1}}}{\mathrm{d}t} &= -\mathbf{I_{1}}, && \text{fast internal current} \tag{4}\\
\tau_{I2}\frac{\mathrm{d}\mathbf{I_{2}}}{\mathrm{d}t} &= -\mathbf{I_{2}}, && \text{slow internal current} \tag{5}\\
\tau_{V}\frac{\mathrm{d}\mathbf{V}}{\mathrm{d}t} &= -\mathbf{V}+V_{\mathrm{rest}}+R(\mathbf{I_{1}}+\mathbf{I_{2}}+\mathbf{I_{\mathrm{ext}}}), && \text{membrane potential} \tag{6}
\end{align}

When the membrane potential $V^{i}$ of the $i$-th neuron reaches the threshold $V_{th}$, the modified GIF model fires and the following updates are applied:

\begin{align}
I_{1}^{i} &\leftarrow A_{1}, \tag{7}\\
I_{2}^{i} &\leftarrow I_{2}^{i}+A_{2}, \tag{8}\\
V^{i} &\leftarrow V_{\mathrm{rest}}, \tag{9}
\end{align}

where $\tau_{I1}$ denotes the time constant of the fast internal current, $\tau_{I2}$ the time constant of the slow internal current, $\tau_{V}$ the time constant of the membrane potential, $R$ the resistance, $\mathbf{I_{\mathrm{ext}}}$ the external input, $V_{\mathrm{rest}}$ the resting potential, and $A_{1}$ and $A_{2}$ the spike-triggered currents.
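To make the update order concrete, Eqs. (4)-(9) can be integrated for a single neuron with a forward-Euler scheme, as in the sketch below. The integration scheme and all parameter values (expressed in rescaled units with $V_{\mathrm{rest}}=0$ and $V_{th}=1$) are illustrative assumptions, not the fitted values used in our experiments.

```python
def simulate_gif(i_ext, dt=1.0, tau_i1=10.0, tau_i2=100.0, tau_v=20.0,
                 r=1.0, v_rest=0.0, v_th=1.0, a1=0.0, a2=-0.2):
    """Forward-Euler simulation of one modified GIF neuron.

    i_ext: input current per time step. Returns (spike_times, v_trace).
    All default parameter values are hypothetical.
    """
    i1, i2, v = 0.0, 0.0, v_rest
    spike_times, v_trace = [], []
    for step, i_in in enumerate(i_ext):
        # Eqs. (4)-(6): derivatives evaluated at the current state.
        di1 = -i1 / tau_i1
        di2 = -i2 / tau_i2
        dv = (-v + v_rest + r * (i1 + i2 + i_in)) / tau_v
        i1 += dt * di1
        i2 += dt * di2
        v += dt * dv
        if v >= v_th:
            # Eqs. (7)-(9): spike-triggered reset.
            i1 = a1
            i2 += a2
            v = v_rest
            spike_times.append(step * dt)
        v_trace.append(v)
    return spike_times, v_trace


# Constant suprathreshold drive for 500 ms.
spikes, trace = simulate_gif([2.0] * 500)
print(len(spikes), spikes[:3])
```

With a negative $A_2$, the slow current accumulates across spikes and lengthens the inter-spike intervals, which is the spike-frequency adaptation referred to in the text.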

To match the firing patterns observed in electrophysiological experiments, particularly the tonic spiking and adaptation, we fit the neuron parameters $A_{1}, A_{2}, \tau_{I1}, \tau_{I2}$ using our gradient-based optimization methods (Section 4.1).

For the forward spiking operation, we use the Heaviside function to generate the spike:

\begin{equation}
\mathrm{spike}(\mathbf{x})=\mathcal{H}(\mathbf{V}[t]-V_{th})=\mathcal{H}(\mathbf{x}), \tag{10}
\end{equation}

where $\mathbf{x}$ denotes $\mathbf{V}[t]-V_{th}$.

To make the non-differentiable spiking activation compatible with gradient-based algorithms, we replace its derivative in the backward pass with a surrogate gradient (Appendix C):

\begin{equation}
\mathrm{spike}^{\prime}(\mathbf{x})=\text{ReLU}(\alpha*(\mathrm{width}-|\mathbf{x}|)) \tag{11}
\end{equation}

where $\mathrm{width}=1.0$ and $\alpha=0.3$. The parameter $\alpha$ controls the height of the surrogate gradient, and $\mathrm{width}$ controls its width.
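The forward and backward rules of Eqs. (10) and (11) can be written as two scalar functions, as sketched below. In an automatic differentiation framework the second function would be registered as the custom derivative of the first; this sketch only evaluates them. The convention $\mathcal{H}(0)=1$ and the function names are assumptions.

```python
def spike(x):
    """Forward pass, Eq. (10): Heaviside step of x = V[t] - V_th."""
    return 1.0 if x >= 0.0 else 0.0  # H(0) = 1 is an assumed convention


def spike_surrogate_grad(x, alpha=0.3, width=1.0):
    """Backward pass, Eq. (11): triangular pseudo-derivative."""
    return max(0.0, alpha * (width - abs(x)))


for x in (-1.5, -0.5, 0.0, 0.5, 1.5):
    print(x, spike(x), spike_surrogate_grad(x))
```

The surrogate is largest at the threshold ($x=0$), where it equals $\alpha\cdot\mathrm{width}$, and vanishes for $|x|\geq\mathrm{width}$, so only neurons close to threshold contribute to the gradient.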

F.4 Synapse dynamics

To model the synaptic connections between these neurons, we use conductance-based synapses. The input current $I^{i}_{\mathrm{ext}}$ to each neuron $i$ is calculated as

\begin{equation}
I^{i}_{\mathrm{ext}}=g_{\mathrm{exc}}^{i}(E_{\mathrm{exc}}-V^{i})+g_{\mathrm{inh}}^{i}(E_{\mathrm{inh}}-V^{i}), \tag{12}
\end{equation}

where the reversal potentials are $E_{\mathrm{exc}}=0$ mV and $E_{\mathrm{inh}}=-120$ mV.

The synaptic dynamics are characterized by exponential synapses,

\begin{equation}
\tau_{\mathrm{syn}}\frac{\mathrm{d}\mathbf{g}_{\mathrm{exc/inh}}}{\mathrm{d}t}=-\mathbf{g}_{\mathrm{exc/inh}} \tag{13}
\end{equation}

where $\tau_{\mathrm{syn}}$ is the time constant of the synaptic conductance decay. Let $t^{k}_{i}$ denote the $k$-th spike time of presynaptic neuron $i$. At each such spike, the corresponding conductance $g_{\mathrm{exc/inh}}^{j}$ of the postsynaptic neuron $j$ is incremented by the synaptic weight. For an excitatory presynaptic neuron,

\begin{equation}
g_{\mathrm{exc}}^{j}\to g_{\mathrm{exc}}^{j}+W^{\mathrm{exc}}_{ij}, \tag{14}
\end{equation}

and for an inhibitory presynaptic neuron,

\begin{equation}
g_{\mathrm{inh}}^{j}\to g_{\mathrm{inh}}^{j}+W^{\mathrm{inh}}_{ij}. \tag{15}
\end{equation}

Typically, we set $\tau_{\mathrm{syn}}=10$ ms.
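The three synaptic operations above (current computation, exponential decay, and spike-triggered increment) can be sketched as follows. The reversal potentials and $\tau_{\mathrm{syn}}$ are taken from the text; the function names, the exact-exponential decay step, and the example conductance and weight values are assumptions.

```python
import math

E_EXC, E_INH = 0.0, -120.0  # reversal potentials in mV (from the text)
TAU_SYN = 10.0              # synaptic time constant in ms (from the text)


def synaptic_current(g_exc, g_inh, v):
    """Eq. (12): conductance-based input current at membrane potential v."""
    return g_exc * (E_EXC - v) + g_inh * (E_INH - v)


def decay(g, dt=1.0, tau=TAU_SYN):
    """Exact solution of Eq. (13) over one time step dt."""
    return g * math.exp(-dt / tau)


def on_presynaptic_spike(g, w):
    """Eqs. (14)-(15): increment the conductance by the synaptic weight."""
    return g + w


# Hypothetical example: one excitatory spike with weight 0.1 at V = -60 mV.
g = on_presynaptic_spike(0.0, 0.1)
print(synaptic_current(g, 0.0, v=-60.0))  # depolarizing (positive) current
for _ in range(10):                        # 10 ms of decay
    g = decay(g)
print(g)                                   # about 0.1 / e
```

Because the current depends on $E - V$, the same conductance drives a smaller current as the membrane potential approaches the reversal potential.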

F.5 Scaled membrane potential

The subthreshold dynamics of biological neurons and synapses usually operate with very negative membrane potentials, often in the range of -50 mV to -80 mV. These large negative values can create challenges during training, especially when using low-precision computing. To address this, we employ a rescaling approach that normalizes the membrane potential.

The rescaling process involves an offset value ($V_{\mathrm{offset}}$) and a scaling factor ($V_{\mathrm{scale}}$). Every membrane potential ($V$) is transformed using the following equation:

\begin{equation}
V_{s}=\frac{V-V_{\mathrm{offset}}}{V_{\mathrm{scale}}} \tag{16}
\end{equation}

Here, $V_{s}$ is the rescaled membrane potential. Choosing $V_{\mathrm{offset}}$ and $V_{\mathrm{scale}}$ so that the firing threshold becomes 1 after rescaling gives a more manageable range of values for training, particularly with limited numerical precision. Specifically, we use $V_{\mathrm{scale}}=20$ mV and $V_{\mathrm{offset}}=-60$ mV, so that the firing threshold is normalized to 1 and the reversal potentials of excitatory and inhibitory synapses are rescaled to 3.0 and -3.0, respectively.
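As a quick check of Eq. (16), the sketch below applies the rescaling with $V_{\mathrm{scale}}=20$ mV and $V_{\mathrm{offset}}=-60$ mV, the offset for which the stated reversal potentials of 0 mV and -120 mV map to 3.0 and -3.0. The function names are ours, and the unscaled threshold of -40 mV is implied by the normalization rather than stated in the text.

```python
V_OFFSET = -60.0  # mV; maps 0 mV -> 3.0 and -120 mV -> -3.0 under Eq. (16)
V_SCALE = 20.0    # mV


def rescale(v_mv):
    """Eq. (16): map a membrane potential in mV to rescaled units."""
    return (v_mv - V_OFFSET) / V_SCALE


def unscale(v_s):
    """Inverse of Eq. (16)."""
    return v_s * V_SCALE + V_OFFSET


print(rescale(0.0))     # excitatory reversal potential -> 3.0
print(rescale(-120.0))  # inhibitory reversal potential -> -3.0
print(unscale(1.0))     # rescaled threshold 1 corresponds to -40.0 mV
```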

F.6 Network connectivity

We designed a network comprising 400 cells, where the excitatory to inhibitory neuron ratio was set to 4:1 (Figure S5). To establish connectivity within the network, we randomly interconnected these neurons with a connection probability of 10%. This choice of connection probability takes into account the observed higher values for neighboring neurons in the cortex, as well as the lower values for neurons that are more distant \citep{theodoni2022structural}.

Apart from the recurrent connections, the input neurons project excitatory synapses to all the recurrent neurons. As for the readout layer, since there is no known biological correspondence for how the brain interprets the recurrent signals, we opted for a linear projection with leaky dynamics.
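The recurrent wiring described above can be sketched as a random boolean mask over 400 neurons, of which 320 are excitatory and 80 inhibitory. The exclusion of self-connections, the ordering of excitatory before inhibitory neurons, and the function name are assumptions made for this illustration.

```python
import random


def random_ei_mask(n_total=400, exc_fraction=0.8, p_connect=0.1, seed=0):
    """Return (is_exc, mask); mask[i][j] is True if neuron i projects to j."""
    rng = random.Random(seed)
    n_exc = int(round(n_total * exc_fraction))  # 4:1 ratio -> 320 E, 80 I
    is_exc = [i < n_exc for i in range(n_total)]
    # Independent Bernoulli draw per ordered pair; no self-connections
    # (an assumption, not stated in the text).
    mask = [[(i != j) and (rng.random() < p_connect) for j in range(n_total)]
            for i in range(n_total)]
    return is_exc, mask


is_exc, mask = random_ei_mask()
n_synapses = sum(sum(row) for row in mask)
print(sum(is_exc), len(is_exc) - sum(is_exc))  # 320 excitatory, 80 inhibitory
print(n_synapses / (400 * 399))                # close to 0.1
```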

F.7 Readout layer

The spiking activity of the recurrent neurons is read out using a linear layer that incorporates the leaky dynamics of the neurons:

\begin{equation}
\tau_{\mathrm{out}}\frac{\mathrm{d}\mathbf{y}}{\mathrm{d}t}=-\mathbf{y}+W^{\mathrm{out}}\mathbf{z}+b^{\mathrm{out}}, \tag{17}
\end{equation}

where $\tau_{\mathrm{out}}$ is the time constant of the output neuron, $W^{\mathrm{out}}$ the synaptic weights between the recurrent and output neurons, and $b^{\mathrm{out}}$ the bias. In discrete time, the output dynamics are written as:

\begin{equation}
\mathbf{y}[t+\Delta t]=\alpha_{\mathrm{out}}\mathbf{y}[t]+(W^{\mathrm{out}}\mathbf{z}[t]+b^{\mathrm{out}})\Delta t, \tag{18}
\end{equation}

where $\alpha_{\mathrm{out}}=e^{-\frac{1}{\tau_{\mathrm{out}}}\Delta t}$.
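One step of the discrete readout in Eq. (18) can be sketched as below. The value of $\tau_{\mathrm{out}}$ is not specified in this appendix, so the default used here is a hypothetical placeholder, as are the function name and the toy weights.

```python
import math


def readout_step(y, z, w_out, b_out, dt=1.0, tau_out=10.0):
    """Eq. (18): leaky readout update for one time step.

    y: current outputs, z: recurrent spikes (0/1), w_out: one weight row
    per output unit, b_out: one bias per output unit.
    tau_out = 10 ms is a hypothetical value.
    """
    alpha_out = math.exp(-dt / tau_out)
    return [alpha_out * y_k
            + (sum(w * s for w, s in zip(w_row, z)) + b_k) * dt
            for y_k, w_row, b_k in zip(y, w_out, b_out)]


# Toy example: two recurrent neurons, one output unit.
y = readout_step([0.0], z=[1, 0], w_out=[[0.5, 0.2]], b_out=[0.1])
print(y)  # only the first neuron spiked: 0.5 + 0.1
y = readout_step(y, z=[0, 0], w_out=[[0.5, 0.2]], b_out=[0.0])
print(y)  # pure leak: previous value times exp(-dt / tau_out)
```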

F.8 Weight initialization

Initial input and recurrent weights were drawn from a Gaussian distribution and replaced by their absolute values, $W_{ji}\sim\left|\sqrt{\frac{s}{n_{\mathrm{in}}}}\mathscr{N}(0,1)\right|$, where $n_{\mathrm{in}}$ is the number of afferent neurons, $\mathscr{N}(0,1)$ is the zero-mean unit-variance Gaussian distribution, and $s$ is the weight scale. For excitatory neurons (including the input and recurrent excitatory neurons), $s=1.0$; for inhibitory neurons, $s=4.0$. The readout weights were drawn from a Gaussian distribution $W_{ji}^{\mathrm{out}}\sim\sqrt{\frac{2.0}{n_{\mathrm{rec}}}}\mathscr{N}(0,1)$, where $n_{\mathrm{rec}}$ is the number of neurons in the recurrent layer.
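The initialization rules above can be sketched as follows. The matrix shapes, the choice of $n_{\mathrm{in}}$ for each projection, and the number of readout units are assumptions made for this example; only the scale factors and the use of absolute values come from the text.

```python
import math
import random


def init_weights(n_in, n_out, scale, rng, nonnegative=True):
    """Draw an n_out x n_in matrix with entries sqrt(scale / n_in) * N(0, 1).

    With nonnegative=True the absolute value is taken, as for the input
    and recurrent weights.
    """
    std = math.sqrt(scale / n_in)
    rows = [[std * rng.gauss(0.0, 1.0) for _ in range(n_in)]
            for _ in range(n_out)]
    if nonnegative:
        rows = [[abs(w) for w in row] for row in rows]
    return rows


rng = random.Random(0)
# Hypothetical shapes: 320 excitatory and 80 inhibitory presynaptic neurons
# projecting to 400 recurrent neurons, and 2 readout units.
w_exc = init_weights(n_in=320, n_out=400, scale=1.0, rng=rng)
w_inh = init_weights(n_in=80, n_out=400, scale=4.0, rng=rng)
w_out = init_weights(n_in=400, n_out=2, scale=2.0, rng=rng,
                     nonnegative=False)
print(min(min(row) for row in w_exc) >= 0.0)
```

Taking the absolute value keeps every input and recurrent weight non-negative, so the sign of a synapse is set by whether it feeds the excitatory or the inhibitory conductance rather than by the weight itself.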

F.9 Training methods

Neuron fitting was performed using the L-BFGS-B algorithm \citep{liu1989limited,byrd1995limited}. For task training, we used the online learning algorithm in BrainScale \citep{brainscale}. The integration time step $\Delta t$ is 1 ms for the spiking neural network. The Adam optimizer \citep{Kingma2014AdamAM} was used to apply the gradient updates. The training objective was to minimize the cross-entropy between the output activity and the target output during the recall period.
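A possible form of the training objective is sketched below: the readout values are treated as logits, passed through a softmax, and the cross-entropy with the target class is averaged over the recall steps. The softmax, the averaging, the two-class output, and the function names are assumptions, since the exact form of the loss is not spelled out here.

```python
import math


def softmax_cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a target class index."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def recall_loss(outputs, target, recall_steps):
    """Average the cross-entropy over the last recall_steps time steps."""
    window = outputs[-recall_steps:]
    return sum(softmax_cross_entropy(y, target) for y in window) / len(window)


# Hypothetical trial: 2 output units, 150 recall steps (150 ms at 1 ms/step).
outputs = [[0.0, 0.0]] * 1000 + [[3.0, -3.0]] * 150
print(recall_loss(outputs, target=0, recall_steps=150))  # small loss
print(recall_loss(outputs, target=1, recall_steps=150))  # large loss
```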

Appendix G Scalability analysis

The large-scale nature of the brain raises an important question: can our differentiable approach scale up to brain-scale spiking neural networks? Currently, we cannot provide a definitive answer, as it involves not only computational resource bottlenecks but also performance generalization to high-dimensional parameter spaces. However, from a computational resource perspective, we have evaluated how our approach can scale up the training of large-scale spiking networks with long time sequences (see Figure S6 for details).

Figure S6: Comparative analysis of computational memory, speed, and training performance between BPTT and BrainScale \citep{brainscale} using the IBM DVS Gesture dataset \citep{amir2017low}. (A) Memory consumption per batch: GPU runtime memory used by BPTT and BrainScale to process a single batch of data. (B) Computational speed per batch: time taken by BPTT and BrainScale to process a single batch of data. (C) Maximum testing accuracy: highest testing accuracy achieved by BPTT and BrainScale on the IBM DVS Gesture dataset.

In particular, we conducted a memory and computational complexity analysis on a three-layer recurrent spiking neural network trained on the IBM DVS Gesture dataset \citep{amir2017low}. With a batch size of 128 and 512 hidden neurons per layer, we compared the average memory usage and computation time per batch when training with backpropagation through time (BPTT) \citep{Werbos1988GeneralizationOB,Mozer1989AFB,Williams1990AnEG} and with the BrainScale online learning framework \citep{brainscale} across various sequence lengths.

As shown in Figure S6A, the training memory required by BPTT increases linearly with the number of time steps, leading to an out-of-memory error for sequences longer than 600 steps. In contrast, BrainScale demonstrates remarkable memory efficiency by maintaining a constant memory footprint for the eligibility trace, regardless of the sequence length. Notably, BrainScale consumes less than 0.5 GB of GPU memory during the entire training process, reducing memory requirements by hundreds of times compared to BPTT.
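The linear growth of BPTT memory can be illustrated with a back-of-the-envelope estimate of the storage needed to keep one state variable per neuron for every time step. The sketch below uses the batch size, layer width, and depth given above; the 4-byte (float32) value size is an assumption, and the estimate is a per-variable lower bound rather than a reproduction of the measurements in Figure S6A.

```python
def bptt_state_bytes(n_steps, batch=128, hidden=512, layers=3,
                     bytes_per_value=4):
    """Bytes needed to store one state variable per neuron for n_steps.

    bytes_per_value = 4 assumes float32 storage.
    """
    return n_steps * batch * hidden * layers * bytes_per_value


MIB = 2 ** 20
print(bptt_state_bytes(1) / MIB)    # 0.75 MiB per step and state variable
print(bptt_state_bytes(600) / MIB)  # 450.0 MiB at 600 steps
```

The actual footprint is this quantity multiplied by the number of stored variables per neuron (membrane potential, internal currents, conductances, spikes) plus intermediate buffers, whereas an online method only keeps state whose size does not depend on the number of steps.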

The computational time for both BPTT and BrainScale scales linearly with the number of time steps (Figure S6B). However, BrainScale trains approximately twice as fast as BPTT, and this speedup grows for longer sequences owing to its event-driven, low-complexity computation.

We also evaluated the training performance of BPTT and BrainScale (Figure S6C) under the same hyperparameter settings for the training and network model. We calculated the maximum testing accuracy and found that BrainScale performed comparably to BPTT, even as the sequence length increased.

These results highlight the potential scalability of our differentiable approach, leveraging the computational efficiency and memory advantages of the online training system BrainScale. While further research is needed to generalize to larger networks and more complex tasks, our approach paves the way for training brain-scale spiking neural networks efficiently.

Appendix H Limitations and potential challenges

The multi-scale differentiable brain modeling workflow described above is a comprehensive approach that aims to integrate various levels of information and constraints to build brain models that can reproduce cognitive behaviors observed in humans or animals. However, there are potential limitations and challenges to this approach:

  1. Data availability and quality: The accuracy of brain models at each scale heavily depends on the quality and availability of the underlying data. At the microscale level, the accuracy of single neuron and synapse models relies on the quality and completeness of electrophysiological recording data. However, in practice, it is often challenging to obtain comprehensive electrophysiological data for all neurons. Similarly, at the mesoscopic level, the accuracy of connectome constraints depends on the quality and resolution of structural connectivity data obtained from techniques such as diffusion tensor imaging (DTI) or electron microscopy. These techniques face significant challenges in acquiring precise connectomic information in humans and animals. For example, while high-resolution electron microscopy can generate detailed 3D maps of neuronal connections from thin brain sections, the process is extremely labor-intensive, requires specialized equipment, and is currently limited to small tissue volumes due to the immense computational demands of reconstructing large-scale connectomes.

  2. Biological realism and simplifications: While our approach aims to incorporate biological constraints, it necessarily involves simplifications and approximations of the underlying biological processes. In particular, our current model utilizes point-based simplified neuron models and considers only excitatory and inhibitory synaptic connections. This level of abstraction cannot fully capture the complexity and dynamics of real biological systems, especially when addressing intricate phenomena such as nonlinear dendritic effects, neuromodulation, and synaptic plasticity.

  3. Complexity and computational resources: Building and training multi-scale brain models can be computationally intensive, particularly when dealing with large-scale networks and high-dimensional parameter spaces. To address this, we utilize the online learning framework BrainScale \citep{brainscale}, which significantly reduces computational resource requirements for long-sequence learning tasks (see Appendix G). However, while online learning algorithms in BrainScale are memory-efficient, they are still approximations of full gradients and may encounter inefficiencies when training very high-dimensional parameters. Moreover, multi-scale brain models necessitate substantial computational resources, including powerful GPU hardware and efficient distributed computing algorithms. The complexity of these models demands not only high-performance computing infrastructure but also optimized software frameworks to manage and streamline the extensive computations involved.

  4. Integration across scales: Seamlessly integrating information and constraints from different spatial and temporal scales is a non-trivial task. Adding biological details or constraints usually leads to poorer training performance. As more biological details are incorporated, the number of parameters in the model grows, making the optimization problem more complex and prone to issues such as vanishing/exploding gradients, local minima, and slow convergence.

  5. Interpretability and theoretical insights: While gradient-based optimization methods can produce models that fit the data, it may be challenging to derive theoretical insights or interpretable mechanisms from these models, especially when dealing with highly complex and non-linear brain dynamics models.

Despite these potential limitations and challenges, the multi-scale differentiable brain modeling workflow still represents a promising approach to integrate various levels of information and constraints to build more realistic and accurate brain models.

Appendix I Supplementary data


Figure S7: Fitting the GIF neuron model to the membrane potential data. (A) Synaptic current synthesized in vitro for fitting the GIF neuron model with membrane potential data. (B) Fitting results of the GIF dynamics using the L-BFGS-B algorithm. (C) Fitting results of the GIF dynamics using the differential evolution (DE) algorithm provided in Nevergrad. (D) Fitting results of the GIF dynamics using the particle swarm optimization (PSO) algorithm provided in Nevergrad. (E) Fitting results of the GIF dynamics using the DE optimization with two points crossover (TwoPointsDE) algorithm provided in Nevergrad. (F) Fitting results of the GIF dynamics using the Bayesian optimization algorithm provided in scikit-optimize.

Figure S8: Fitting the HH neuron model to the membrane potential data. (A) Synaptic current synthesized in vitro for fitting the HH neuron model with membrane potential data. (B) Fitting results of the HH dynamics using the L-BFGS-B algorithm. (C) Fitting results of the HH dynamics using the differential evolution (DE) algorithm provided in Nevergrad. (D) Fitting results of the HH dynamics using the particle swarm optimization (PSO) algorithm provided in Nevergrad. (E) Fitting results of the HH dynamics using the DE optimization with two points crossover (TwoPointsDE) algorithm provided in Nevergrad. (F) Fitting results of the HH dynamics using the Bayesian optimization algorithm provided in scikit-optimize.

Table S2: The loss and speed comparison among five optimization methods, including L-BFGS-B, DE, PSO, TwoPointsDE, and Bayesian optimizations, when fitting the GIF neuron dynamics on the membrane potential data.

\begin{tabular}{lcc}
\hline
Fitting Method & Fitting Loss & Fitting Time (s) \\
\hline
L-BFGS-B (Ours) & $6.799 \pm 4.623$ & $5.404 \pm 0.294$ \\
DE (Nevergrad) & $9.966 \pm 2.281$ & $1.172 \pm 0.09$ \\
PSO (Nevergrad) & $14.034 \pm 2.857$ & $1.151 \pm 0.150$ \\
TwoPointsDE (Nevergrad) & $9.702 \pm 5.115$ & $1.161 \pm 0.172$ \\
Bayesian (scikit-optimize) & $3.511 \pm 1.741$ & $62.330 \pm 10.917$ \\
\hline
\end{tabular}

Table S3: The loss and speed comparison among five optimization methods, including L-BFGS-B, DE, PSO, TwoPointsDE, and Bayesian optimizations, when fitting the HH neuron dynamics on the membrane potential data.

\begin{tabular}{lcc}
\hline
Fitting Method & Fitting Loss & Fitting Time (s) \\
\hline
L-BFGS-B (Ours) & $2.3\times10^{-8} \pm 1.55\times10^{-8}$ & $3.818 \pm 0.725$ \\
DE (Nevergrad) & $24.42 \pm 22.55$ & $0.619 \pm 0.139$ \\
PSO (Nevergrad) & $35.72 \pm 24.06$ & $0.663 \pm 0.168$ \\
TwoPointsDE (Nevergrad) & $32.31 \pm 22.60$ & $0.658 \pm 0.175$ \\
Bayesian (scikit-optimize) & $27.95 \pm 25.76$ & $55.26 \pm 12.39$ \\
\hline
\end{tabular}